Internal change
PiperOrigin-RevId: 488770794
This commit is contained in:
parent
7a87546c30
commit
38b636f7ee
|
@ -225,6 +225,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/meta:type_traits",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -26,10 +26,12 @@
|
|||
|
||||
#include "absl/base/macros.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/deps/registration_token.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
|
@ -159,7 +161,7 @@ class FunctionRegistry {
|
|||
FunctionRegistry(const FunctionRegistry&) = delete;
|
||||
FunctionRegistry& operator=(const FunctionRegistry&) = delete;
|
||||
|
||||
RegistrationToken Register(const std::string& name, Function func)
|
||||
RegistrationToken Register(absl::string_view name, Function func)
|
||||
ABSL_LOCKS_EXCLUDED(lock_) {
|
||||
std::string normalized_name = GetNormalizedName(name);
|
||||
absl::WriterMutexLock lock(&lock_);
|
||||
|
@ -189,14 +191,15 @@ class FunctionRegistry {
|
|||
absl::enable_if_t<std::is_convertible<std::tuple<Args2...>,
|
||||
std::tuple<Args...>>::value,
|
||||
int> = 0>
|
||||
ReturnType Invoke(const std::string& name, Args2&&... args)
|
||||
ReturnType Invoke(absl::string_view name, Args2&&... args)
|
||||
ABSL_LOCKS_EXCLUDED(lock_) {
|
||||
Function function;
|
||||
{
|
||||
absl::ReaderMutexLock lock(&lock_);
|
||||
auto it = functions_.find(name);
|
||||
if (it == functions_.end()) {
|
||||
return absl::NotFoundError("No registered object with name: " + name);
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("No registered object with name: ", name));
|
||||
}
|
||||
function = it->second;
|
||||
}
|
||||
|
@ -206,7 +209,7 @@ class FunctionRegistry {
|
|||
// Invokes the specified factory function and returns the result.
|
||||
// Namespaces in |name| and |ns| are separated by kNameSep.
|
||||
template <typename... Args2>
|
||||
ReturnType Invoke(const std::string& ns, const std::string& name,
|
||||
ReturnType Invoke(absl::string_view ns, absl::string_view name,
|
||||
Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) {
|
||||
return Invoke(GetQualifiedName(ns, name), args...);
|
||||
}
|
||||
|
@ -214,14 +217,14 @@ class FunctionRegistry {
|
|||
// Note that it's possible for registered implementations to be subsequently
|
||||
// unregistered, though this will never happen with registrations made via
|
||||
// MEDIAPIPE_REGISTER_FACTORY_FUNCTION.
|
||||
bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) {
|
||||
bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) {
|
||||
absl::ReaderMutexLock lock(&lock_);
|
||||
return functions_.count(name) != 0;
|
||||
}
|
||||
|
||||
// Returns true if the specified factory function is available.
|
||||
// Namespaces in |name| and |ns| are separated by kNameSep.
|
||||
bool IsRegistered(const std::string& ns, const std::string& name) const
|
||||
bool IsRegistered(absl::string_view ns, absl::string_view name) const
|
||||
ABSL_LOCKS_EXCLUDED(lock_) {
|
||||
return IsRegistered(GetQualifiedName(ns, name));
|
||||
}
|
||||
|
@ -244,7 +247,7 @@ class FunctionRegistry {
|
|||
// Normalizes a C++ qualified name. Validates the name qualification.
|
||||
// The name must be either unqualified or fully qualified with a leading "::".
|
||||
// The leading "::" in a fully qualified name is stripped.
|
||||
std::string GetNormalizedName(const std::string& name) {
|
||||
std::string GetNormalizedName(absl::string_view name) {
|
||||
using ::mediapipe::registration_internal::kCxxSep;
|
||||
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
|
||||
if (names[0].empty()) {
|
||||
|
@ -259,8 +262,8 @@ class FunctionRegistry {
|
|||
|
||||
// Returns the registry key for a name specified within a namespace.
|
||||
// Namespaces are separated by kNameSep.
|
||||
std::string GetQualifiedName(const std::string& ns,
|
||||
const std::string& name) const {
|
||||
std::string GetQualifiedName(absl::string_view ns,
|
||||
absl::string_view name) const {
|
||||
using ::mediapipe::registration_internal::kCxxSep;
|
||||
using ::mediapipe::registration_internal::kNameSep;
|
||||
std::vector<std::string> names = absl::StrSplit(name, kNameSep);
|
||||
|
@ -287,10 +290,10 @@ class FunctionRegistry {
|
|||
|
||||
private:
|
||||
mutable absl::Mutex lock_;
|
||||
std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
|
||||
absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
|
||||
|
||||
// For names included in NamespaceAllowlist, strips the namespace.
|
||||
std::string GetAdjustedName(const std::string& name) {
|
||||
std::string GetAdjustedName(absl::string_view name) {
|
||||
using ::mediapipe::registration_internal::kCxxSep;
|
||||
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
|
||||
std::string base_name = names.back();
|
||||
|
@ -299,10 +302,10 @@ class FunctionRegistry {
|
|||
if (NamespaceAllowlist::TopNamespaces().count(ns)) {
|
||||
return base_name;
|
||||
}
|
||||
return name;
|
||||
return std::string(name);
|
||||
}
|
||||
|
||||
void Unregister(const std::string& name) {
|
||||
void Unregister(absl::string_view name) {
|
||||
absl::WriterMutexLock lock(&lock_);
|
||||
std::string adjusted_name = GetAdjustedName(name);
|
||||
if (adjusted_name != name) {
|
||||
|
@ -317,7 +320,7 @@ class GlobalFactoryRegistry {
|
|||
using Functions = FunctionRegistry<R, Args...>;
|
||||
|
||||
public:
|
||||
static RegistrationToken Register(const std::string& name,
|
||||
static RegistrationToken Register(absl::string_view name,
|
||||
typename Functions::Function func) {
|
||||
return functions()->Register(name, std::move(func));
|
||||
}
|
||||
|
@ -326,7 +329,7 @@ class GlobalFactoryRegistry {
|
|||
// If using namespaces with this registry, the variant with a namespace
|
||||
// argument should be used.
|
||||
template <typename... Args2>
|
||||
static typename Functions::ReturnType CreateByName(const std::string& name,
|
||||
static typename Functions::ReturnType CreateByName(absl::string_view name,
|
||||
Args2&&... args) {
|
||||
return functions()->Invoke(name, std::forward<Args2>(args)...);
|
||||
}
|
||||
|
@ -334,7 +337,7 @@ class GlobalFactoryRegistry {
|
|||
// Returns true if the specified factory function is available.
|
||||
// If using namespaces with this registry, the variant with a namespace
|
||||
// argument should be used.
|
||||
static bool IsRegistered(const std::string& name) {
|
||||
static bool IsRegistered(absl::string_view name) {
|
||||
return functions()->IsRegistered(name);
|
||||
}
|
||||
|
||||
|
@ -350,13 +353,13 @@ class GlobalFactoryRegistry {
|
|||
std::tuple<Args...>>::value,
|
||||
int> = 0>
|
||||
static typename Functions::ReturnType CreateByNameInNamespace(
|
||||
const std::string& ns, const std::string& name, Args2&&... args) {
|
||||
absl::string_view ns, absl::string_view name, Args2&&... args) {
|
||||
return functions()->Invoke(ns, name, std::forward<Args2>(args)...);
|
||||
}
|
||||
|
||||
// Returns true if the specified factory function is available.
|
||||
// Namespaces in |name| and |ns| are separated by kNameSep.
|
||||
static bool IsRegistered(const std::string& ns, const std::string& name) {
|
||||
static bool IsRegistered(absl::string_view ns, absl::string_view name) {
|
||||
return functions()->IsRegistered(ns, name);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user