Plumb an optional default Executor and set of input side packets
through TaskApiFactory::Create so that consumers of that API can provide these inputs to their underlying graph. PiperOrigin-RevId: 574503266
This commit is contained in:
parent
e27bbf15dc
commit
2bd6726c89
|
@ -261,6 +261,7 @@ cc_library_with_tflite(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:executor",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/tool:name_util",
|
"//mediapipe/framework/tool:name_util",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
|
@ -319,6 +320,7 @@ cc_library(
|
||||||
":task_runner",
|
":task_runner",
|
||||||
":utils",
|
":utils",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:executor",
|
||||||
"//mediapipe/framework/port:requires",
|
"//mediapipe/framework/port:requires",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
#define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_
|
#define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -26,6 +27,7 @@ limitations under the License.
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/executor.h"
|
||||||
#include "mediapipe/framework/port/requires.h"
|
#include "mediapipe/framework/port/requires.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
@ -56,7 +58,9 @@ class TaskApiFactory {
|
||||||
static absl::StatusOr<std::unique_ptr<T>> Create(
|
static absl::StatusOr<std::unique_ptr<T>> Create(
|
||||||
CalculatorGraphConfig graph_config,
|
CalculatorGraphConfig graph_config,
|
||||||
std::unique_ptr<tflite::OpResolver> resolver,
|
std::unique_ptr<tflite::OpResolver> resolver,
|
||||||
PacketsCallback packets_callback = nullptr) {
|
PacketsCallback packets_callback = nullptr,
|
||||||
|
std::shared_ptr<Executor> default_executor = nullptr,
|
||||||
|
std::optional<PacketMap> input_side_packets = std::nullopt) {
|
||||||
bool found_task_subgraph = false;
|
bool found_task_subgraph = false;
|
||||||
// This for-loop ensures there's only one subgraph besides
|
// This for-loop ensures there's only one subgraph besides
|
||||||
// FlowLimiterCalculator.
|
// FlowLimiterCalculator.
|
||||||
|
@ -77,7 +81,9 @@ class TaskApiFactory {
|
||||||
MP_ASSIGN_OR_RETURN(
|
MP_ASSIGN_OR_RETURN(
|
||||||
auto runner,
|
auto runner,
|
||||||
core::TaskRunner::Create(std::move(graph_config), std::move(resolver),
|
core::TaskRunner::Create(std::move(graph_config), std::move(resolver),
|
||||||
std::move(packets_callback)));
|
std::move(packets_callback),
|
||||||
|
std::move(default_executor),
|
||||||
|
std::move(input_side_packets)));
|
||||||
return std::make_unique<T>(std::move(runner));
|
return std::make_unique<T>(std::move(runner));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -33,6 +34,7 @@ limitations under the License.
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/executor.h"
|
||||||
#include "mediapipe/framework/tool/name_util.h"
|
#include "mediapipe/framework/tool/name_util.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
|
@ -89,17 +91,22 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
|
||||||
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
||||||
CalculatorGraphConfig config,
|
CalculatorGraphConfig config,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver,
|
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||||
PacketsCallback packets_callback) {
|
PacketsCallback packets_callback,
|
||||||
|
std::shared_ptr<Executor> default_executor,
|
||||||
|
std::optional<PacketMap> input_side_packets) {
|
||||||
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
|
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(task_runner->Initialize(
|
||||||
task_runner->Initialize(std::move(config), std::move(op_resolver)));
|
std::move(config), std::move(op_resolver), std::move(default_executor),
|
||||||
|
std::move(input_side_packets)));
|
||||||
MP_RETURN_IF_ERROR(task_runner->Start());
|
MP_RETURN_IF_ERROR(task_runner->Start());
|
||||||
return task_runner;
|
return task_runner;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TaskRunner::Initialize(
|
absl::Status TaskRunner::Initialize(
|
||||||
CalculatorGraphConfig config,
|
CalculatorGraphConfig config,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver) {
|
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||||
|
std::shared_ptr<Executor> default_executor,
|
||||||
|
std::optional<PacketMap> input_side_packets) {
|
||||||
if (initialized_) {
|
if (initialized_) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
@ -123,7 +130,9 @@ absl::Status TaskRunner::Initialize(
|
||||||
MediaPipeTasksStatus::kRunnerInitializationError);
|
MediaPipeTasksStatus::kRunnerInitializationError);
|
||||||
}
|
}
|
||||||
config.clear_output_stream();
|
config.clear_output_stream();
|
||||||
PacketMap input_side_packets;
|
if (!input_side_packets) {
|
||||||
|
input_side_packets.emplace();
|
||||||
|
}
|
||||||
if (packets_callback_) {
|
if (packets_callback_) {
|
||||||
tool::AddMultiStreamCallback(
|
tool::AddMultiStreamCallback(
|
||||||
output_stream_names_,
|
output_stream_names_,
|
||||||
|
@ -132,7 +141,7 @@ absl::Status TaskRunner::Initialize(
|
||||||
GenerateOutputPacketMap(packets, output_stream_names_));
|
GenerateOutputPacketMap(packets, output_stream_names_));
|
||||||
return;
|
return;
|
||||||
},
|
},
|
||||||
&config, &input_side_packets,
|
&config, &input_side_packets.value(),
|
||||||
/*observe_timestamp_bounds=*/true);
|
/*observe_timestamp_bounds=*/true);
|
||||||
} else {
|
} else {
|
||||||
mediapipe::tool::AddMultiStreamCallback(
|
mediapipe::tool::AddMultiStreamCallback(
|
||||||
|
@ -142,8 +151,14 @@ absl::Status TaskRunner::Initialize(
|
||||||
GenerateOutputPacketMap(packets, output_stream_names_);
|
GenerateOutputPacketMap(packets, output_stream_names_);
|
||||||
return;
|
return;
|
||||||
},
|
},
|
||||||
&config, &input_side_packets, /*observe_timestamp_bounds=*/true);
|
&config, &input_side_packets.value(),
|
||||||
|
/*observe_timestamp_bounds=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (default_executor) {
|
||||||
|
MP_RETURN_IF_ERROR(graph_.SetExecutor("", std::move(default_executor)));
|
||||||
|
}
|
||||||
|
|
||||||
auto model_resources_cache =
|
auto model_resources_cache =
|
||||||
std::make_shared<ModelResourcesCache>(std::move(op_resolver));
|
std::make_shared<ModelResourcesCache>(std::move(op_resolver));
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
|
@ -152,7 +167,7 @@ absl::Status TaskRunner::Initialize(
|
||||||
"ModelResourcesCacheService is not set up successfully.",
|
"ModelResourcesCacheService is not set up successfully.",
|
||||||
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError));
|
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError));
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
AddPayload(graph_.Initialize(std::move(config), input_side_packets),
|
AddPayload(graph_.Initialize(std::move(config), *input_side_packets),
|
||||||
"MediaPipe CalculatorGraph is not successfully initialized.",
|
"MediaPipe CalculatorGraph is not successfully initialized.",
|
||||||
MediaPipeTasksStatus::kRunnerInitializationError));
|
MediaPipeTasksStatus::kRunnerInitializationError));
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
|
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -34,6 +35,7 @@ limitations under the License.
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/executor.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||||
|
@ -73,7 +75,9 @@ class TaskRunner {
|
||||||
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
||||||
CalculatorGraphConfig config,
|
CalculatorGraphConfig config,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||||
PacketsCallback packets_callback = nullptr);
|
PacketsCallback packets_callback = nullptr,
|
||||||
|
std::shared_ptr<Executor> default_executor = nullptr,
|
||||||
|
std::optional<PacketMap> input_side_packets = std::nullopt);
|
||||||
|
|
||||||
// TaskRunner is neither copyable nor movable.
|
// TaskRunner is neither copyable nor movable.
|
||||||
TaskRunner(const TaskRunner&) = delete;
|
TaskRunner(const TaskRunner&) = delete;
|
||||||
|
@ -125,7 +129,9 @@ class TaskRunner {
|
||||||
// be only initialized once.
|
// be only initialized once.
|
||||||
absl::Status Initialize(
|
absl::Status Initialize(
|
||||||
CalculatorGraphConfig config,
|
CalculatorGraphConfig config,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
|
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||||
|
std::shared_ptr<Executor> default_executor = nullptr,
|
||||||
|
std::optional<PacketMap> input_side_packets = std::nullopt);
|
||||||
|
|
||||||
// Starts the task runner. Returns an ok status to indicate that the
|
// Starts the task runner. Returns an ok status to indicate that the
|
||||||
// runner is ready to accept input data. Otherwise, returns an error status to
|
// runner is ready to accept input data. Otherwise, returns an error status to
|
||||||
|
|
Loading…
Reference in New Issue
Block a user