mediapipe/mediapipe/framework/graph_service.h
MediaPipe Team 7fb37c80e8 Project import generated by Copybara.
GitOrigin-RevId: 19a829ffd755edb43e54d20c0e7b9348512d5108
2022-05-05 19:57:20 +00:00

122 lines
3.7 KiB
C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_
#define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_
#include <memory>
#include <type_traits>
#include <utility>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
// The GraphService API can be used to define extensions to a graph's execution
// environment. These are, essentially, graph-level singletons, and are
// available to all calculators in the graph (and in any subgraphs) without
// requiring a manual connection.
//
// IMPORTANT: this is an experimental API. Get in touch with the MediaPipe team
// if you want to use it. In most cases, you should use a side packet instead.
//
// GraphServiceBase is the non-template base shared by every GraphService<T>;
// it carries the service key and the (by default failing) hook for creating a
// default service object.
class GraphServiceBase {
 public:
  // TODO: fix services for which default init is broken, remove
  // this setting.
  enum DefaultInitSupport {
    kAllowDefaultInitialization,
    kDisallowDefaultInitialization
  };

  // Stores `service_key` as a raw pointer (the characters are not copied), so
  // the pointed-to string must outlive this object.
  constexpr GraphServiceBase(const char* service_key) : key(service_key) {}
  virtual ~GraphServiceBase() = default;

  // Hook for producing a packet that holds a default service object. The base
  // version never succeeds: it always returns an Unimplemented status.
  virtual absl::StatusOr<Packet> CreateDefaultObject() const {
    return DefaultInitializationUnsupported();
  }

  // Identifier of this service, as passed to the constructor.
  const char* key;

 protected:
  // Builds the Unimplemented status reported when this service cannot be
  // default-initialized; the message names the service key.
  absl::Status DefaultInitializationUnsupported() const {
    auto message = absl::StrCat("Graph service '", key,
                                "' does not support default initialization");
    return absl::UnimplementedError(message);
  }
};
// A typed graph service whose object is held as std::shared_ptr<T>.
//
// When constructed with kAllowDefaultInitialization, CreateDefaultObject()
// builds the object as follows:
//   - it calls a static `T::Create()` if one exists whose result converts to
//     absl::StatusOr<std::shared_ptr<T>>;
//   - otherwise it uses std::make_shared<T>() if T is default-constructible;
//   - otherwise it reports that default initialization is unsupported.
template <typename T>
class GraphService : public GraphServiceBase {
 public:
  using type = T;
  using packet_type = std::shared_ptr<T>;

  constexpr GraphService(const char* my_key, DefaultInitSupport default_init =
                                                 kDisallowDefaultInitialization)
      : GraphServiceBase(my_key), default_init_(default_init) {}

  // Returns a packet wrapping a newly created std::shared_ptr<T>. Returns an
  // error status in three cases: default initialization was not allowed at
  // construction, T offers no supported way to be created, or T::Create()
  // itself failed.
  absl::StatusOr<Packet> CreateDefaultObject() const override {
    if (default_init_ == kAllowDefaultInitialization) {
      absl::StatusOr<std::shared_ptr<T>> object = CreateDefaultObjectInternal();
      if (!object.ok()) {
        return object.status();
      }
      return MakePacket<std::shared_ptr<T>>(*std::move(object));
    }
    return DefaultInitializationUnsupported();
  }

 private:
  // Carries a type through a generic lambda's `auto` parameter, so that the
  // lambda's return type is only formed when it is actually probed.
  template <class U>
  struct type_tag {
    using type = U;
  };

  // Picks the creation strategy for T at compile time; see the class comment.
  absl::StatusOr<std::shared_ptr<T>> CreateDefaultObjectInternal() const {
    // SFINAE probe: when T has no usable static Create(), the trailing return
    // type is ill-formed, so the lambda is simply not invocable with
    // type_tag<T>. This avoids a hard compile error.
    auto invoke_create =
        [](auto tag) -> decltype(decltype(tag)::type::Create()) {
      return decltype(tag)::type::Create();
    };
    if constexpr (std::is_invocable_r_v<absl::StatusOr<std::shared_ptr<T>>,
                                        decltype(invoke_create),
                                        type_tag<T>>) {
      return T::Create();
    } else if constexpr (std::is_default_constructible_v<T>) {
      return std::make_shared<T>();
    } else {
      return DefaultInitializationUnsupported();
    }
  }

  DefaultInitSupport default_init_;
};
// Holds shared ownership of a service object of type T, or is empty.
// A default-constructed binding is empty. So is one constructed from a null
// std::shared_ptr.
template <typename T>
class ServiceBinding {
 public:
  // Returns true if this binding holds a (non-null) service object.
  bool IsAvailable() const { return service_ != nullptr; }

  // Returns the bound service object. The binding does not own T's constness,
  // so a mutable reference is returned even from a const binding.
  // Precondition: IsAvailable() is true. Calling this on an empty binding
  // dereferences a null pointer.
  T& GetObject() const { return *service_; }

  ServiceBinding() = default;

  // Takes shared ownership of `service`, which may be null. The by-value
  // parameter is moved into place, so no extra refcount increment occurs.
  explicit ServiceBinding(std::shared_ptr<T> service)
      : service_(std::move(service)) {}

 private:
  std::shared_ptr<T> service_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_