Plumb an optional default Executor and set of input side packets

through TaskApiFactory::Create so that consumers of that API
can provide these inputs to their underlying graph.

PiperOrigin-RevId: 574503266
This commit is contained in:
MediaPipe Team 2023-10-18 09:41:28 -07:00 committed by Copybara-Service
parent e27bbf15dc
commit 2bd6726c89
4 changed files with 41 additions and 12 deletions

View File

@ -261,6 +261,7 @@ cc_library_with_tflite(
deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:executor",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:name_util",
"//mediapipe/tasks/cc:common",
@ -319,6 +320,7 @@ cc_library(
":task_runner",
":utils",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:executor",
"//mediapipe/framework/port:requires",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",

View File

@ -17,6 +17,7 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
@ -26,6 +27,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/port/requires.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
@ -56,7 +58,9 @@ class TaskApiFactory {
static absl::StatusOr<std::unique_ptr<T>> Create(
CalculatorGraphConfig graph_config,
std::unique_ptr<tflite::OpResolver> resolver,
PacketsCallback packets_callback = nullptr) {
PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt) {
bool found_task_subgraph = false;
// This for-loop ensures there's only one subgraph besides
// FlowLimiterCalculator.
@ -77,7 +81,9 @@ class TaskApiFactory {
MP_ASSIGN_OR_RETURN(
auto runner,
core::TaskRunner::Create(std::move(graph_config), std::move(resolver),
std::move(packets_callback)));
std::move(packets_callback),
std::move(default_executor),
std::move(input_side_packets)));
return std::make_unique<T>(std::move(runner));
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <iterator>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@ -33,6 +34,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
@ -89,17 +91,22 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver,
PacketsCallback packets_callback) {
PacketsCallback packets_callback,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) {
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
MP_RETURN_IF_ERROR(
task_runner->Initialize(std::move(config), std::move(op_resolver)));
MP_RETURN_IF_ERROR(task_runner->Initialize(
std::move(config), std::move(op_resolver), std::move(default_executor),
std::move(input_side_packets)));
MP_RETURN_IF_ERROR(task_runner->Start());
return task_runner;
}
absl::Status TaskRunner::Initialize(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver) {
std::unique_ptr<tflite::OpResolver> op_resolver,
std::shared_ptr<Executor> default_executor,
std::optional<PacketMap> input_side_packets) {
if (initialized_) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
@ -123,7 +130,9 @@ absl::Status TaskRunner::Initialize(
MediaPipeTasksStatus::kRunnerInitializationError);
}
config.clear_output_stream();
PacketMap input_side_packets;
if (!input_side_packets) {
input_side_packets.emplace();
}
if (packets_callback_) {
tool::AddMultiStreamCallback(
output_stream_names_,
@ -132,7 +141,7 @@ absl::Status TaskRunner::Initialize(
GenerateOutputPacketMap(packets, output_stream_names_));
return;
},
&config, &input_side_packets,
&config, &input_side_packets.value(),
/*observe_timestamp_bounds=*/true);
} else {
mediapipe::tool::AddMultiStreamCallback(
@ -142,8 +151,14 @@ absl::Status TaskRunner::Initialize(
GenerateOutputPacketMap(packets, output_stream_names_);
return;
},
&config, &input_side_packets, /*observe_timestamp_bounds=*/true);
&config, &input_side_packets.value(),
/*observe_timestamp_bounds=*/true);
}
if (default_executor) {
MP_RETURN_IF_ERROR(graph_.SetExecutor("", std::move(default_executor)));
}
auto model_resources_cache =
std::make_shared<ModelResourcesCache>(std::move(op_resolver));
MP_RETURN_IF_ERROR(
@ -152,7 +167,7 @@ absl::Status TaskRunner::Initialize(
"ModelResourcesCacheService is not set up successfully.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError));
MP_RETURN_IF_ERROR(
AddPayload(graph_.Initialize(std::move(config), input_side_packets),
AddPayload(graph_.Initialize(std::move(config), *input_side_packets),
"MediaPipe CalculatorGraph is not successfully initialized.",
MediaPipeTasksStatus::kRunnerInitializationError));
initialized_ = true;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>
@ -34,6 +35,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
@ -73,7 +75,9 @@ class TaskRunner {
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
PacketsCallback packets_callback = nullptr);
PacketsCallback packets_callback = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt);
// TaskRunner is neither copyable nor movable.
TaskRunner(const TaskRunner&) = delete;
@ -125,7 +129,9 @@ class TaskRunner {
// be only initialized once.
absl::Status Initialize(
CalculatorGraphConfig config,
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
std::shared_ptr<Executor> default_executor = nullptr,
std::optional<PacketMap> input_side_packets = std::nullopt);
// Starts the task runner. Returns an ok status to indicate that the
// runner is ready to accept input data. Otherwise, returns an error status to