Plumb an optional default Executor and set of input side packets
through TaskApiFactory::Create so that consumers of that API can provide these inputs to their underlying graph. PiperOrigin-RevId: 574503266
This commit is contained in:
parent
e27bbf15dc
commit
2bd6726c89
|
@ -261,6 +261,7 @@ cc_library_with_tflite(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:executor",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/tool:name_util",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
|
@ -319,6 +320,7 @@ cc_library(
|
|||
":task_runner",
|
||||
":utils",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:executor",
|
||||
"//mediapipe/framework/port:requires",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
#define MEDIAPIPE_TASKS_CC_CORE_TASK_API_FACTORY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
@ -26,6 +27,7 @@ limitations under the License.
|
|||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/executor.h"
|
||||
#include "mediapipe/framework/port/requires.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
|
@ -56,7 +58,9 @@ class TaskApiFactory {
|
|||
static absl::StatusOr<std::unique_ptr<T>> Create(
|
||||
CalculatorGraphConfig graph_config,
|
||||
std::unique_ptr<tflite::OpResolver> resolver,
|
||||
PacketsCallback packets_callback = nullptr) {
|
||||
PacketsCallback packets_callback = nullptr,
|
||||
std::shared_ptr<Executor> default_executor = nullptr,
|
||||
std::optional<PacketMap> input_side_packets = std::nullopt) {
|
||||
bool found_task_subgraph = false;
|
||||
// This for-loop ensures there's only one subgraph besides
|
||||
// FlowLimiterCalculator.
|
||||
|
@ -77,7 +81,9 @@ class TaskApiFactory {
|
|||
MP_ASSIGN_OR_RETURN(
|
||||
auto runner,
|
||||
core::TaskRunner::Create(std::move(graph_config), std::move(resolver),
|
||||
std::move(packets_callback)));
|
||||
std::move(packets_callback),
|
||||
std::move(default_executor),
|
||||
std::move(input_side_packets)));
|
||||
return std::make_unique<T>(std::move(runner));
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <iterator>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -33,6 +34,7 @@ limitations under the License.
|
|||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/executor.h"
|
||||
#include "mediapipe/framework/tool/name_util.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||
|
@ -89,17 +91,22 @@ absl::StatusOr<PacketMap> GenerateOutputPacketMap(
|
|||
absl::StatusOr<std::unique_ptr<TaskRunner>> TaskRunner::Create(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||
PacketsCallback packets_callback) {
|
||||
PacketsCallback packets_callback,
|
||||
std::shared_ptr<Executor> default_executor,
|
||||
std::optional<PacketMap> input_side_packets) {
|
||||
auto task_runner = absl::WrapUnique(new TaskRunner(packets_callback));
|
||||
MP_RETURN_IF_ERROR(
|
||||
task_runner->Initialize(std::move(config), std::move(op_resolver)));
|
||||
MP_RETURN_IF_ERROR(task_runner->Initialize(
|
||||
std::move(config), std::move(op_resolver), std::move(default_executor),
|
||||
std::move(input_side_packets)));
|
||||
MP_RETURN_IF_ERROR(task_runner->Start());
|
||||
return task_runner;
|
||||
}
|
||||
|
||||
absl::Status TaskRunner::Initialize(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver) {
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver,
|
||||
std::shared_ptr<Executor> default_executor,
|
||||
std::optional<PacketMap> input_side_packets) {
|
||||
if (initialized_) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
|
@ -123,7 +130,9 @@ absl::Status TaskRunner::Initialize(
|
|||
MediaPipeTasksStatus::kRunnerInitializationError);
|
||||
}
|
||||
config.clear_output_stream();
|
||||
PacketMap input_side_packets;
|
||||
if (!input_side_packets) {
|
||||
input_side_packets.emplace();
|
||||
}
|
||||
if (packets_callback_) {
|
||||
tool::AddMultiStreamCallback(
|
||||
output_stream_names_,
|
||||
|
@ -132,7 +141,7 @@ absl::Status TaskRunner::Initialize(
|
|||
GenerateOutputPacketMap(packets, output_stream_names_));
|
||||
return;
|
||||
},
|
||||
&config, &input_side_packets,
|
||||
&config, &input_side_packets.value(),
|
||||
/*observe_timestamp_bounds=*/true);
|
||||
} else {
|
||||
mediapipe::tool::AddMultiStreamCallback(
|
||||
|
@ -142,8 +151,14 @@ absl::Status TaskRunner::Initialize(
|
|||
GenerateOutputPacketMap(packets, output_stream_names_);
|
||||
return;
|
||||
},
|
||||
&config, &input_side_packets, /*observe_timestamp_bounds=*/true);
|
||||
&config, &input_side_packets.value(),
|
||||
/*observe_timestamp_bounds=*/true);
|
||||
}
|
||||
|
||||
if (default_executor) {
|
||||
MP_RETURN_IF_ERROR(graph_.SetExecutor("", std::move(default_executor)));
|
||||
}
|
||||
|
||||
auto model_resources_cache =
|
||||
std::make_shared<ModelResourcesCache>(std::move(op_resolver));
|
||||
MP_RETURN_IF_ERROR(
|
||||
|
@ -152,7 +167,7 @@ absl::Status TaskRunner::Initialize(
|
|||
"ModelResourcesCacheService is not set up successfully.",
|
||||
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError));
|
||||
MP_RETURN_IF_ERROR(
|
||||
AddPayload(graph_.Initialize(std::move(config), input_side_packets),
|
||||
AddPayload(graph_.Initialize(std::move(config), *input_side_packets),
|
||||
"MediaPipe CalculatorGraph is not successfully initialized.",
|
||||
MediaPipeTasksStatus::kRunnerInitializationError));
|
||||
initialized_ = true;
|
||||
|
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
@ -34,6 +35,7 @@ limitations under the License.
|
|||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/executor.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
|
||||
|
@ -73,7 +75,9 @@ class TaskRunner {
|
|||
static absl::StatusOr<std::unique_ptr<TaskRunner>> Create(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||
PacketsCallback packets_callback = nullptr);
|
||||
PacketsCallback packets_callback = nullptr,
|
||||
std::shared_ptr<Executor> default_executor = nullptr,
|
||||
std::optional<PacketMap> input_side_packets = std::nullopt);
|
||||
|
||||
// TaskRunner is neither copyable nor movable.
|
||||
TaskRunner(const TaskRunner&) = delete;
|
||||
|
@ -125,7 +129,9 @@ class TaskRunner {
|
|||
// be only initialized once.
|
||||
absl::Status Initialize(
|
||||
CalculatorGraphConfig config,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr);
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver = nullptr,
|
||||
std::shared_ptr<Executor> default_executor = nullptr,
|
||||
std::optional<PacketMap> input_side_packets = std::nullopt);
|
||||
|
||||
// Starts the task runner. Returns an ok status to indicate that the
|
||||
// runner is ready to accept input data. Otherwise, returns an error status to
|
||||
|
|
Loading…
Reference in New Issue
Block a user