Plumb an optional default Executor and set of input side packets
through TaskApiFactory::Create so that consumers of that API can provide these inputs to their underlying graph. PiperOrigin-RevId: 574503266
This commit is contained in:
		
							parent
							
								
									e27bbf15dc
								
							
						
					
					
						commit
						2bd6726c89
					
				|  | @ -261,6 +261,7 @@ cc_library_with_tflite( | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/framework:calculator_cc_proto", |         "//mediapipe/framework:calculator_cc_proto", | ||||||
|         "//mediapipe/framework:calculator_framework", |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework:executor", | ||||||
|         "//mediapipe/framework/port:status", |         "//mediapipe/framework/port:status", | ||||||
|         "//mediapipe/framework/tool:name_util", |         "//mediapipe/framework/tool:name_util", | ||||||
|         "//mediapipe/tasks/cc:common", |         "//mediapipe/tasks/cc:common", | ||||||
|  | @ -319,6 +320,7 @@ cc_library( | ||||||
|         ":task_runner", |         ":task_runner", | ||||||
|         ":utils", |         ":utils", | ||||||
|         "//mediapipe/framework:calculator_cc_proto", |         "//mediapipe/framework:calculator_cc_proto", | ||||||
|  |         "//mediapipe/framework:executor", | ||||||
|         "//mediapipe/framework/port:requires", |         "//mediapipe/framework/port:requires", | ||||||
|         "//mediapipe/framework/port:status", |         "//mediapipe/framework/port:status", | ||||||
|         "//mediapipe/tasks/cc:common", |         "//mediapipe/tasks/cc:common", | ||||||
|  |  | ||||||
|  | @ -17,6 +17,7 @@ limitations under the License. | ||||||
| #define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_ | #define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_ | ||||||
| 
 | 
 | ||||||
| #include <memory> | #include <memory> | ||||||
|  | #include <optional> | ||||||
| #include <string> | #include <string> | ||||||
| #include <type_traits> | #include <type_traits> | ||||||
| #include <utility> | #include <utility> | ||||||
|  | @ -26,6 +27,7 @@ limitations under the License. | ||||||
| #include "absl/strings/match.h" | #include "absl/strings/match.h" | ||||||
| #include "absl/strings/str_cat.h" | #include "absl/strings/str_cat.h" | ||||||
| #include "mediapipe/framework/calculator.pb.h" | #include "mediapipe/framework/calculator.pb.h" | ||||||
|  | #include "mediapipe/framework/executor.h" | ||||||
| #include "mediapipe/framework/port/requires.h" | #include "mediapipe/framework/port/requires.h" | ||||||
| #include "mediapipe/framework/port/status_macros.h" | #include "mediapipe/framework/port/status_macros.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
|  | @ -56,7 +58,9 @@ class TaskApiFactory { | ||||||
|   static absl::StatusOr<std::unique_ptr<T>> Create( |   static absl::StatusOr<std::unique_ptr<T>> Create( | ||||||
|       CalculatorGraphConfig graph_config, |       CalculatorGraphConfig graph_config, | ||||||
|       std::unique_ptr<tflite::OpResolver> resolver, |       std::unique_ptr<tflite::OpResolver> resolver, | ||||||
|       PacketsCallback packets_callback = nullptr) { |       PacketsCallback packets_callback = nullptr, | ||||||
|  |       std::shared_ptr<Executor> default_executor = nullptr, | ||||||
|  |       std::optional<PacketMap> input_side_packets = std::nullopt) { | ||||||
|     bool found_task_subgraph = false; |     bool found_task_subgraph = false; | ||||||
|     // This for-loop ensures there's only one subgraph besides
 |     // This for-loop ensures there's only one subgraph besides
 | ||||||
|     // FlowLimiterCalculator.
 |     // FlowLimiterCalculator.
 | ||||||
|  | @ -77,7 +81,9 @@ class TaskApiFactory { | ||||||
|     MP_ASSIGN_OR_RETURN( |     MP_ASSIGN_OR_RETURN( | ||||||
|         auto runner, |         auto runner, | ||||||
|         core::TaskRunner::Create(std::move(graph_config), std::move(resolver), |         core::TaskRunner::Create(std::move(graph_config), std::move(resolver), | ||||||
|                                  std::move(packets_callback))); |                                  std::move(packets_callback), | ||||||
|  |                                  std::move(default_executor), | ||||||
|  |                                  std::move(input_side_packets))); | ||||||
|     return std::make_unique<T>(std::move(runner)); |     return std::make_unique<T>(std::move(runner)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -21,6 +21,7 @@ limitations under the License. | ||||||
| #include <iterator> | #include <iterator> | ||||||
| #include <map> | #include <map> | ||||||
| #include <memory> | #include <memory> | ||||||
|  | #include <optional> | ||||||
| #include <string> | #include <string> | ||||||
| #include <utility> | #include <utility> | ||||||
| #include <vector> | #include <vector> | ||||||
|  | @ -33,6 +34,7 @@ limitations under the License. | ||||||
| #include "absl/synchronization/mutex.h" | #include "absl/synchronization/mutex.h" | ||||||
| #include "mediapipe/framework/calculator.pb.h" | #include "mediapipe/framework/calculator.pb.h" | ||||||
| #include "mediapipe/framework/calculator_framework.h" | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/executor.h" | ||||||
| #include "mediapipe/framework/tool/name_util.h" | #include "mediapipe/framework/tool/name_util.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
| #include "mediapipe/tasks/cc/core/model_resources_cache.h" | #include "mediapipe/tasks/cc/core/model_resources_cache.h" | ||||||
|  | @ -89,17 +91,22 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap( | ||||||
| absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create( | absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create( | ||||||
|     CalculatorGraphConfig config, |     CalculatorGraphConfig config, | ||||||
|     std::unique_ptr<tflite::OpResolver> op_resolver, |     std::unique_ptr<tflite::OpResolver> op_resolver, | ||||||
|     PacketsCallback packets_callback) { |     PacketsCallback packets_callback, | ||||||
|  |     std::shared_ptr<Executor> default_executor, | ||||||
|  |     std::optional<PacketMap> input_side_packets) { | ||||||
|   auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); |   auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); | ||||||
|   MP_RETURN_IF_ERROR( |   MP_RETURN_IF_ERROR(task_runner->Initialize( | ||||||
|       task_runner->Initialize(std::move(config), std::move(op_resolver))); |       std::move(config), std::move(op_resolver), std::move(default_executor), | ||||||
|  |       std::move(input_side_packets))); | ||||||
|   MP_RETURN_IF_ERROR(task_runner->Start()); |   MP_RETURN_IF_ERROR(task_runner->Start()); | ||||||
|   return task_runner; |   return task_runner; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::Status TaskRunner::Initialize( | absl::Status TaskRunner::Initialize( | ||||||
|     CalculatorGraphConfig config, |     CalculatorGraphConfig config, | ||||||
|     std::unique_ptr<tflite::OpResolver> op_resolver) { |     std::unique_ptr<tflite::OpResolver> op_resolver, | ||||||
|  |     std::shared_ptr<Executor> default_executor, | ||||||
|  |     std::optional<PacketMap> input_side_packets) { | ||||||
|   if (initialized_) { |   if (initialized_) { | ||||||
|     return CreateStatusWithPayload( |     return CreateStatusWithPayload( | ||||||
|         absl::StatusCode::kInvalidArgument, |         absl::StatusCode::kInvalidArgument, | ||||||
|  | @ -123,7 +130,9 @@ absl::Status TaskRunner::Initialize( | ||||||
|         MediaPipeTasksStatus::kRunnerInitializationError); |         MediaPipeTasksStatus::kRunnerInitializationError); | ||||||
|   } |   } | ||||||
|   config.clear_output_stream(); |   config.clear_output_stream(); | ||||||
|   PacketMap input_side_packets; |   if (!input_side_packets) { | ||||||
|  |     input_side_packets.emplace(); | ||||||
|  |   } | ||||||
|   if (packets_callback_) { |   if (packets_callback_) { | ||||||
|     tool::AddMultiStreamCallback( |     tool::AddMultiStreamCallback( | ||||||
|         output_stream_names_, |         output_stream_names_, | ||||||
|  | @ -132,7 +141,7 @@ absl::Status TaskRunner::Initialize( | ||||||
|               GenerateOutputPacketMap(packets, output_stream_names_)); |               GenerateOutputPacketMap(packets, output_stream_names_)); | ||||||
|           return; |           return; | ||||||
|         }, |         }, | ||||||
|         &config, &input_side_packets, |         &config, &input_side_packets.value(), | ||||||
|         /*observe_timestamp_bounds=*/true); |         /*observe_timestamp_bounds=*/true); | ||||||
|   } else { |   } else { | ||||||
|     mediapipe::tool::AddMultiStreamCallback( |     mediapipe::tool::AddMultiStreamCallback( | ||||||
|  | @ -142,8 +151,14 @@ absl::Status TaskRunner::Initialize( | ||||||
|               GenerateOutputPacketMap(packets, output_stream_names_); |               GenerateOutputPacketMap(packets, output_stream_names_); | ||||||
|           return; |           return; | ||||||
|         }, |         }, | ||||||
|         &config, &input_side_packets, /*observe_timestamp_bounds=*/true); |         &config, &input_side_packets.value(), | ||||||
|  |         /*observe_timestamp_bounds=*/true); | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   if (default_executor) { | ||||||
|  |     MP_RETURN_IF_ERROR(graph_.SetExecutor("", std::move(default_executor))); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   auto model_resources_cache = |   auto model_resources_cache = | ||||||
|       std::make_shared<ModelResourcesCache>(std::move(op_resolver)); |       std::make_shared<ModelResourcesCache>(std::move(op_resolver)); | ||||||
|   MP_RETURN_IF_ERROR( |   MP_RETURN_IF_ERROR( | ||||||
|  | @ -152,7 +167,7 @@ absl::Status TaskRunner::Initialize( | ||||||
|                  "ModelResourcesCacheService is not set up successfully.", |                  "ModelResourcesCacheService is not set up successfully.", | ||||||
|                  MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError)); |                  MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError)); | ||||||
|   MP_RETURN_IF_ERROR( |   MP_RETURN_IF_ERROR( | ||||||
|       AddPayload(graph_.Initialize(std::move(config), input_side_packets), |       AddPayload(graph_.Initialize(std::move(config), *input_side_packets), | ||||||
|                  "MediaPipe CalculatorGraph is not successfully initialized.", |                  "MediaPipe CalculatorGraph is not successfully initialized.", | ||||||
|                  MediaPipeTasksStatus::kRunnerInitializationError)); |                  MediaPipeTasksStatus::kRunnerInitializationError)); | ||||||
|   initialized_ = true; |   initialized_ = true; | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ limitations under the License. | ||||||
| #include <functional> | #include <functional> | ||||||
| #include <map> | #include <map> | ||||||
| #include <memory> | #include <memory> | ||||||
|  | #include <optional> | ||||||
| #include <string> | #include <string> | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
|  | @ -34,6 +35,7 @@ limitations under the License. | ||||||
| #include "absl/synchronization/mutex.h" | #include "absl/synchronization/mutex.h" | ||||||
| #include "mediapipe/framework/calculator.pb.h" | #include "mediapipe/framework/calculator.pb.h" | ||||||
| #include "mediapipe/framework/calculator_framework.h" | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/executor.h" | ||||||
| #include "mediapipe/framework/port/status_macros.h" | #include "mediapipe/framework/port/status_macros.h" | ||||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | #include "mediapipe/tasks/cc/core/model_resources.h" | ||||||
| #include "mediapipe/tasks/cc/core/model_resources_cache.h" | #include "mediapipe/tasks/cc/core/model_resources_cache.h" | ||||||
|  | @ -73,7 +75,9 @@ class TaskRunner { | ||||||
|   static absl::StatusOr<std::unique_ptr<TaskRunner>> Create( |   static absl::StatusOr<std::unique_ptr<TaskRunner>> Create( | ||||||
|       CalculatorGraphConfig config, |       CalculatorGraphConfig config, | ||||||
|       std::unique_ptr<tflite::OpResolver> op_resolver = nullptr, |       std::unique_ptr<tflite::OpResolver> op_resolver = nullptr, | ||||||
|       PacketsCallback packets_callback = nullptr); |       PacketsCallback packets_callback = nullptr, | ||||||
|  |       std::shared_ptr<Executor> default_executor = nullptr, | ||||||
|  |       std::optional<PacketMap> input_side_packets = std::nullopt); | ||||||
| 
 | 
 | ||||||
|   // TaskRunner is neither copyable nor movable.
 |   // TaskRunner is neither copyable nor movable.
 | ||||||
|   TaskRunner(const TaskRunner&) = delete; |   TaskRunner(const TaskRunner&) = delete; | ||||||
|  | @ -125,7 +129,9 @@ class TaskRunner { | ||||||
|   // be only initialized once.
 |   // be only initialized once.
 | ||||||
|   absl::Status Initialize( |   absl::Status Initialize( | ||||||
|       CalculatorGraphConfig config, |       CalculatorGraphConfig config, | ||||||
|       std::unique_ptr<tflite::OpResolver> op_resolver = nullptr); |       std::unique_ptr<tflite::OpResolver> op_resolver = nullptr, | ||||||
|  |       std::shared_ptr<Executor> default_executor = nullptr, | ||||||
|  |       std::optional<PacketMap> input_side_packets = std::nullopt); | ||||||
| 
 | 
 | ||||||
|   // Starts the task runner. Returns an ok status to indicate that the
 |   // Starts the task runner. Returns an ok status to indicate that the
 | ||||||
|   // runner is ready to accept input data. Otherwise, returns an error status to
 |   // runner is ready to accept input data. Otherwise, returns an error status to
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user