Internal change
PiperOrigin-RevId: 488770794
This commit is contained in:
parent
7a87546c30
commit
38b636f7ee
|
@ -225,6 +225,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:statusor",
|
"//mediapipe/framework/port:statusor",
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/meta:type_traits",
|
"@com_google_absl//absl/meta:type_traits",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -26,10 +26,12 @@
|
||||||
|
|
||||||
#include "absl/base/macros.h"
|
#include "absl/base/macros.h"
|
||||||
#include "absl/base/thread_annotations.h"
|
#include "absl/base/thread_annotations.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/meta/type_traits.h"
|
#include "absl/meta/type_traits.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "mediapipe/framework/deps/registration_token.h"
|
#include "mediapipe/framework/deps/registration_token.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
|
@ -159,7 +161,7 @@ class FunctionRegistry {
|
||||||
FunctionRegistry(const FunctionRegistry&) = delete;
|
FunctionRegistry(const FunctionRegistry&) = delete;
|
||||||
FunctionRegistry& operator=(const FunctionRegistry&) = delete;
|
FunctionRegistry& operator=(const FunctionRegistry&) = delete;
|
||||||
|
|
||||||
RegistrationToken Register(const std::string& name, Function func)
|
RegistrationToken Register(absl::string_view name, Function func)
|
||||||
ABSL_LOCKS_EXCLUDED(lock_) {
|
ABSL_LOCKS_EXCLUDED(lock_) {
|
||||||
std::string normalized_name = GetNormalizedName(name);
|
std::string normalized_name = GetNormalizedName(name);
|
||||||
absl::WriterMutexLock lock(&lock_);
|
absl::WriterMutexLock lock(&lock_);
|
||||||
|
@ -189,14 +191,15 @@ class FunctionRegistry {
|
||||||
absl::enable_if_t<std::is_convertible<std::tuple<Args2...>,
|
absl::enable_if_t<std::is_convertible<std::tuple<Args2...>,
|
||||||
std::tuple<Args...>>::value,
|
std::tuple<Args...>>::value,
|
||||||
int> = 0>
|
int> = 0>
|
||||||
ReturnType Invoke(const std::string& name, Args2&&... args)
|
ReturnType Invoke(absl::string_view name, Args2&&... args)
|
||||||
ABSL_LOCKS_EXCLUDED(lock_) {
|
ABSL_LOCKS_EXCLUDED(lock_) {
|
||||||
Function function;
|
Function function;
|
||||||
{
|
{
|
||||||
absl::ReaderMutexLock lock(&lock_);
|
absl::ReaderMutexLock lock(&lock_);
|
||||||
auto it = functions_.find(name);
|
auto it = functions_.find(name);
|
||||||
if (it == functions_.end()) {
|
if (it == functions_.end()) {
|
||||||
return absl::NotFoundError("No registered object with name: " + name);
|
return absl::NotFoundError(
|
||||||
|
absl::StrCat("No registered object with name: ", name));
|
||||||
}
|
}
|
||||||
function = it->second;
|
function = it->second;
|
||||||
}
|
}
|
||||||
|
@ -206,7 +209,7 @@ class FunctionRegistry {
|
||||||
// Invokes the specified factory function and returns the result.
|
// Invokes the specified factory function and returns the result.
|
||||||
// Namespaces in |name| and |ns| are separated by kNameSep.
|
// Namespaces in |name| and |ns| are separated by kNameSep.
|
||||||
template <typename... Args2>
|
template <typename... Args2>
|
||||||
ReturnType Invoke(const std::string& ns, const std::string& name,
|
ReturnType Invoke(absl::string_view ns, absl::string_view name,
|
||||||
Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) {
|
Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) {
|
||||||
return Invoke(GetQualifiedName(ns, name), args...);
|
return Invoke(GetQualifiedName(ns, name), args...);
|
||||||
}
|
}
|
||||||
|
@ -214,14 +217,14 @@ class FunctionRegistry {
|
||||||
// Note that it's possible for registered implementations to be subsequently
|
// Note that it's possible for registered implementations to be subsequently
|
||||||
// unregistered, though this will never happen with registrations made via
|
// unregistered, though this will never happen with registrations made via
|
||||||
// MEDIAPIPE_REGISTER_FACTORY_FUNCTION.
|
// MEDIAPIPE_REGISTER_FACTORY_FUNCTION.
|
||||||
bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) {
|
bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) {
|
||||||
absl::ReaderMutexLock lock(&lock_);
|
absl::ReaderMutexLock lock(&lock_);
|
||||||
return functions_.count(name) != 0;
|
return functions_.count(name) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the specified factory function is available.
|
// Returns true if the specified factory function is available.
|
||||||
// Namespaces in |name| and |ns| are separated by kNameSep.
|
// Namespaces in |name| and |ns| are separated by kNameSep.
|
||||||
bool IsRegistered(const std::string& ns, const std::string& name) const
|
bool IsRegistered(absl::string_view ns, absl::string_view name) const
|
||||||
ABSL_LOCKS_EXCLUDED(lock_) {
|
ABSL_LOCKS_EXCLUDED(lock_) {
|
||||||
return IsRegistered(GetQualifiedName(ns, name));
|
return IsRegistered(GetQualifiedName(ns, name));
|
||||||
}
|
}
|
||||||
|
@ -244,7 +247,7 @@ class FunctionRegistry {
|
||||||
// Normalizes a C++ qualified name. Validates the name qualification.
|
// Normalizes a C++ qualified name. Validates the name qualification.
|
||||||
// The name must be either unqualified or fully qualified with a leading "::".
|
// The name must be either unqualified or fully qualified with a leading "::".
|
||||||
// The leading "::" in a fully qualified name is stripped.
|
// The leading "::" in a fully qualified name is stripped.
|
||||||
std::string GetNormalizedName(const std::string& name) {
|
std::string GetNormalizedName(absl::string_view name) {
|
||||||
using ::mediapipe::registration_internal::kCxxSep;
|
using ::mediapipe::registration_internal::kCxxSep;
|
||||||
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
|
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
|
||||||
if (names[0].empty()) {
|
if (names[0].empty()) {
|
||||||
|
@ -259,8 +262,8 @@ class FunctionRegistry {
|
||||||
|
|
||||||
// Returns the registry key for a name specified within a namespace.
|
// Returns the registry key for a name specified within a namespace.
|
||||||
// Namespaces are separated by kNameSep.
|
// Namespaces are separated by kNameSep.
|
||||||
std::string GetQualifiedName(const std::string& ns,
|
std::string GetQualifiedName(absl::string_view ns,
|
||||||
const std::string& name) const {
|
absl::string_view name) const {
|
||||||
using ::mediapipe::registration_internal::kCxxSep;
|
using ::mediapipe::registration_internal::kCxxSep;
|
||||||
using ::mediapipe::registration_internal::kNameSep;
|
using ::mediapipe::registration_internal::kNameSep;
|
||||||
std::vector<std::string> names = absl::StrSplit(name, kNameSep);
|
std::vector<std::string> names = absl::StrSplit(name, kNameSep);
|
||||||
|
@ -287,10 +290,10 @@ class FunctionRegistry {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mutable absl::Mutex lock_;
|
mutable absl::Mutex lock_;
|
||||||
std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
|
absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
|
||||||
|
|
||||||
// For names included in NamespaceAllowlist, strips the namespace.
|
// For names included in NamespaceAllowlist, strips the namespace.
|
||||||
std::string GetAdjustedName(const std::string& name) {
|
std::string GetAdjustedName(absl::string_view name) {
|
||||||
using ::mediapipe::registration_internal::kCxxSep;
|
using ::mediapipe::registration_internal::kCxxSep;
|
||||||
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
|
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
|
||||||
std::string base_name = names.back();
|
std::string base_name = names.back();
|
||||||
|
@ -299,10 +302,10 @@ class FunctionRegistry {
|
||||||
if (NamespaceAllowlist::TopNamespaces().count(ns)) {
|
if (NamespaceAllowlist::TopNamespaces().count(ns)) {
|
||||||
return base_name;
|
return base_name;
|
||||||
}
|
}
|
||||||
return name;
|
return std::string(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unregister(const std::string& name) {
|
void Unregister(absl::string_view name) {
|
||||||
absl::WriterMutexLock lock(&lock_);
|
absl::WriterMutexLock lock(&lock_);
|
||||||
std::string adjusted_name = GetAdjustedName(name);
|
std::string adjusted_name = GetAdjustedName(name);
|
||||||
if (adjusted_name != name) {
|
if (adjusted_name != name) {
|
||||||
|
@ -317,7 +320,7 @@ class GlobalFactoryRegistry {
|
||||||
using Functions = FunctionRegistry<R, Args...>;
|
using Functions = FunctionRegistry<R, Args...>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static RegistrationToken Register(const std::string& name,
|
static RegistrationToken Register(absl::string_view name,
|
||||||
typename Functions::Function func) {
|
typename Functions::Function func) {
|
||||||
return functions()->Register(name, std::move(func));
|
return functions()->Register(name, std::move(func));
|
||||||
}
|
}
|
||||||
|
@ -326,7 +329,7 @@ class GlobalFactoryRegistry {
|
||||||
// If using namespaces with this registry, the variant with a namespace
|
// If using namespaces with this registry, the variant with a namespace
|
||||||
// argument should be used.
|
// argument should be used.
|
||||||
template <typename... Args2>
|
template <typename... Args2>
|
||||||
static typename Functions::ReturnType CreateByName(const std::string& name,
|
static typename Functions::ReturnType CreateByName(absl::string_view name,
|
||||||
Args2&&... args) {
|
Args2&&... args) {
|
||||||
return functions()->Invoke(name, std::forward<Args2>(args)...);
|
return functions()->Invoke(name, std::forward<Args2>(args)...);
|
||||||
}
|
}
|
||||||
|
@ -334,7 +337,7 @@ class GlobalFactoryRegistry {
|
||||||
// Returns true if the specified factory function is available.
|
// Returns true if the specified factory function is available.
|
||||||
// If using namespaces with this registry, the variant with a namespace
|
// If using namespaces with this registry, the variant with a namespace
|
||||||
// argument should be used.
|
// argument should be used.
|
||||||
static bool IsRegistered(const std::string& name) {
|
static bool IsRegistered(absl::string_view name) {
|
||||||
return functions()->IsRegistered(name);
|
return functions()->IsRegistered(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,13 +353,13 @@ class GlobalFactoryRegistry {
|
||||||
std::tuple<Args...>>::value,
|
std::tuple<Args...>>::value,
|
||||||
int> = 0>
|
int> = 0>
|
||||||
static typename Functions::ReturnType CreateByNameInNamespace(
|
static typename Functions::ReturnType CreateByNameInNamespace(
|
||||||
const std::string& ns, const std::string& name, Args2&&... args) {
|
absl::string_view ns, absl::string_view name, Args2&&... args) {
|
||||||
return functions()->Invoke(ns, name, std::forward<Args2>(args)...);
|
return functions()->Invoke(ns, name, std::forward<Args2>(args)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the specified factory function is available.
|
// Returns true if the specified factory function is available.
|
||||||
// Namespaces in |name| and |ns| are separated by kNameSep.
|
// Namespaces in |name| and |ns| are separated by kNameSep.
|
||||||
static bool IsRegistered(const std::string& ns, const std::string& name) {
|
static bool IsRegistered(absl::string_view ns, absl::string_view name) {
|
||||||
return functions()->IsRegistered(ns, name);
|
return functions()->IsRegistered(ns, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user