Plumb an optional default Executor and set of input side packets

through TaskApiFactory::Create so that consumers of that API
can provide these inputs to their underlying graph.

PiperOrigin-RevId: 574503266
This commit is contained in:
MediaPipe Team 2023-10-18 09:41:28 -07:00 committed by Copybara-Service
parent e27bbf15dc
commit 2bd6726c89
4 changed files with 41 additions and 12 deletions

View File

@ -261,6 +261,7 @@ cc_library_with_tflite(
deps = [ deps = [
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:executor",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util", "//mediapipe/framework/tool:name_util",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
@ -319,6 +320,7 @@ cc_library(
":task_runner", ":task_runner",
":utils", ":utils",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:executor",
"//mediapipe/framework/port:requires", "//mediapipe/framework/port:requires",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",

View File

@ -17,6 +17,7 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_ #define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
@ -26,6 +27,7 @@ limitations under the License.
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/port/requires.h" #include "mediapipe/framework/port/requires.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
@ -56,7 +58,9 @@ class TaskApiFactory {
static absl::StatusOr<std::unique_ptr<T>> Create( static absl::StatusOr<std::unique_ptr<T>> Create(
CalculatorGraphConfig graph_config, CalculatorGraphConfig graph_config,
std::unique_ptr<tflite::OpResolver> resolver, std::unique_ptr<tflite::OpResolver> resolver,
PacketsCallback packets_callback = nullptr) { PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt) {
bool found_task_subgraph = false; bool found_task_subgraph = false;
// This for-loop ensures there's only one subgraph besides // This for-loop ensures there's only one subgraph besides
// FlowLimiterCalculator. // FlowLimiterCalculator.
@ -77,7 +81,9 @@ class TaskApiFactory {
MP_ASSIGN_OR_RETURN( MP_ASSIGN_OR_RETURN(
auto runner, auto runner,
core::TaskRunner::Create(std::move(graph_config), std::move(resolver), core::TaskRunner::Create(std::move(graph_config), std::move(resolver),
std::move(packets_callback))); std::move(packets_callback),
std::move(default_executor),
std::move(input_side_packets)));
return std::make_unique<T>(std::move(runner)); return std::make_unique<T>(std::move(runner));
} }

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <iterator> #include <iterator>
#include <map> #include <map>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -33,6 +34,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/tool/name_util.h" #include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h"
@ -89,17 +91,22 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create( absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver, std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback) { PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) {
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback)); auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(task_runner->Initialize(
task_runner->Initialize(std::move(config), std::move(op_resolver))); std::move(config), std::move(op_resolver), std::move(default_executor),
std::move(input_side_packets)));
MP_RETURN_IF_ERROR(task_runner->Start()); MP_RETURN_IF_ERROR(task_runner->Start());
return task_runner; return task_runner;
} }
absl::Status TaskRunner::Initialize( absl::Status TaskRunner::Initialize(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver) { std::unique_ptr<tflite::OpResolver> op_resolver,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) {
if (initialized_) { if (initialized_) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -123,7 +130,9 @@ absl::Status TaskRunner::Initialize(
MediaPipeTasksStatus::kRunnerInitializationError); MediaPipeTasksStatus::kRunnerInitializationError);
} }
config.clear_output_stream(); config.clear_output_stream();
PacketMap input_side_packets; if (!input_side_packets) {
input_side_packets.emplace();
}
if (packets_callback_) { if (packets_callback_) {
tool::AddMultiStreamCallback( tool::AddMultiStreamCallback(
output_stream_names_, output_stream_names_,
@ -132,7 +141,7 @@ absl::Status TaskRunner::Initialize(
GenerateOutputPacketMap(packets, output_stream_names_)); GenerateOutputPacketMap(packets, output_stream_names_));
return; return;
}, },
&config, &input_side_packets, &config, &input_side_packets.value(),
/*observe_timestamp_bounds=*/true); /*observe_timestamp_bounds=*/true);
} else { } else {
mediapipe::tool::AddMultiStreamCallback( mediapipe::tool::AddMultiStreamCallback(
@ -142,8 +151,14 @@ absl::Status TaskRunner::Initialize(
GenerateOutputPacketMap(packets, output_stream_names_); GenerateOutputPacketMap(packets, output_stream_names_);
return; return;
}, },
&config, &input_side_packets, /*observe_timestamp_bounds=*/true); &config, &input_side_packets.value(),
/*observe_timestamp_bounds=*/true);
} }
if (default_executor) {
MP_RETURN_IF_ERROR(graph_.SetExecutor("", std::move(default_executor)));
}
auto model_resources_cache = auto model_resources_cache =
std::make_shared<ModelResourcesCache>(std::move(op_resolver)); std::make_shared<ModelResourcesCache>(std::move(op_resolver));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
@ -152,7 +167,7 @@ absl::Status TaskRunner::Initialize(
"ModelResourcesCacheService is not set up successfully.", "ModelResourcesCacheService is not set up successfully.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError)); MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
AddPayload(graph_.Initialize(std::move(config), input_side_packets), AddPayload(graph_.Initialize(std::move(config), *input_side_packets),
"MediaPipe CalculatorGraph is not successfully initialized.", "MediaPipe CalculatorGraph is not successfully initialized.",
MediaPipeTasksStatus::kRunnerInitializationError)); MediaPipeTasksStatus::kRunnerInitializationError));
initialized_ = true; initialized_ = true;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <functional> #include <functional>
#include <map> #include <map>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <vector> #include <vector>
@ -34,6 +35,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h"
@ -73,7 +75,9 @@ class TaskRunner {
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create( static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr, std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr); PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt);
// TaskRunner is neither copyable nor movable. // TaskRunner is neither copyable nor movable.
TaskRunner(const TaskRunner&) = delete; TaskRunner(const TaskRunner&) = delete;
@ -125,7 +129,9 @@ class TaskRunner {
// be only initialized once. // be only initialized once.
absl::Status Initialize( absl::Status Initialize(
CalculatorGraphConfig config, CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr); std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt);
// Starts the task runner. Returns an ok status to indicate that the // Starts the task runner. Returns an ok status to indicate that the
// runner is ready to accept input data. Otherwise, returns an error status to // runner is ready to accept input data. Otherwise, returns an error status to