Internal change

PiperOrigin-RevId: 488770794
This commit is contained in:
MediaPipe Team 2022-11-15 15:10:36 -08:00 committed by Copybara-Service
parent 7a87546c30
commit 38b636f7ee
2 changed files with 22 additions and 18 deletions

View File

@ -225,6 +225,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -26,10 +26,12 @@
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/base/thread_annotations.h" #include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/meta/type_traits.h" #include "absl/meta/type_traits.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/deps/registration_token.h" #include "mediapipe/framework/deps/registration_token.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
@ -159,7 +161,7 @@ class FunctionRegistry {
FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry(const FunctionRegistry&) = delete;
FunctionRegistry& operator=(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete;
RegistrationToken Register(const std::string& name, Function func) RegistrationToken Register(absl::string_view name, Function func)
ABSL_LOCKS_EXCLUDED(lock_) { ABSL_LOCKS_EXCLUDED(lock_) {
std::string normalized_name = GetNormalizedName(name); std::string normalized_name = GetNormalizedName(name);
absl::WriterMutexLock lock(&lock_); absl::WriterMutexLock lock(&lock_);
@ -189,14 +191,15 @@ class FunctionRegistry {
absl::enable_if_t<std::is_convertible<std::tuple<Args2...>, absl::enable_if_t<std::is_convertible<std::tuple<Args2...>,
std::tuple<Args...>>::value, std::tuple<Args...>>::value,
int> = 0> int> = 0>
ReturnType Invoke(const std::string& name, Args2&&... args) ReturnType Invoke(absl::string_view name, Args2&&... args)
ABSL_LOCKS_EXCLUDED(lock_) { ABSL_LOCKS_EXCLUDED(lock_) {
Function function; Function function;
{ {
absl::ReaderMutexLock lock(&lock_); absl::ReaderMutexLock lock(&lock_);
auto it = functions_.find(name); auto it = functions_.find(name);
if (it == functions_.end()) { if (it == functions_.end()) {
return absl::NotFoundError("No registered object with name: " + name); return absl::NotFoundError(
absl::StrCat("No registered object with name: ", name));
} }
function = it->second; function = it->second;
} }
@ -206,7 +209,7 @@ class FunctionRegistry {
// Invokes the specified factory function and returns the result. // Invokes the specified factory function and returns the result.
// Namespaces in |name| and |ns| are separated by kNameSep. // Namespaces in |name| and |ns| are separated by kNameSep.
template <typename... Args2> template <typename... Args2>
ReturnType Invoke(const std::string& ns, const std::string& name, ReturnType Invoke(absl::string_view ns, absl::string_view name,
Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) {
return Invoke(GetQualifiedName(ns, name), args...); return Invoke(GetQualifiedName(ns, name), args...);
} }
@ -214,14 +217,14 @@ class FunctionRegistry {
// Note that it's possible for registered implementations to be subsequently // Note that it's possible for registered implementations to be subsequently
// unregistered, though this will never happen with registrations made via // unregistered, though this will never happen with registrations made via
// MEDIAPIPE_REGISTER_FACTORY_FUNCTION. // MEDIAPIPE_REGISTER_FACTORY_FUNCTION.
bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) { bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) {
absl::ReaderMutexLock lock(&lock_); absl::ReaderMutexLock lock(&lock_);
return functions_.count(name) != 0; return functions_.count(name) != 0;
} }
// Returns true if the specified factory function is available. // Returns true if the specified factory function is available.
// Namespaces in |name| and |ns| are separated by kNameSep. // Namespaces in |name| and |ns| are separated by kNameSep.
bool IsRegistered(const std::string& ns, const std::string& name) const bool IsRegistered(absl::string_view ns, absl::string_view name) const
ABSL_LOCKS_EXCLUDED(lock_) { ABSL_LOCKS_EXCLUDED(lock_) {
return IsRegistered(GetQualifiedName(ns, name)); return IsRegistered(GetQualifiedName(ns, name));
} }
@ -244,7 +247,7 @@ class FunctionRegistry {
// Normalizes a C++ qualified name. Validates the name qualification. // Normalizes a C++ qualified name. Validates the name qualification.
// The name must be either unqualified or fully qualified with a leading "::". // The name must be either unqualified or fully qualified with a leading "::".
// The leading "::" in a fully qualified name is stripped. // The leading "::" in a fully qualified name is stripped.
std::string GetNormalizedName(const std::string& name) { std::string GetNormalizedName(absl::string_view name) {
using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kCxxSep;
std::vector<std::string> names = absl::StrSplit(name, kCxxSep); std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
if (names[0].empty()) { if (names[0].empty()) {
@ -259,8 +262,8 @@ class FunctionRegistry {
// Returns the registry key for a name specified within a namespace. // Returns the registry key for a name specified within a namespace.
// Namespaces are separated by kNameSep. // Namespaces are separated by kNameSep.
std::string GetQualifiedName(const std::string& ns, std::string GetQualifiedName(absl::string_view ns,
const std::string& name) const { absl::string_view name) const {
using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kCxxSep;
using ::mediapipe::registration_internal::kNameSep; using ::mediapipe::registration_internal::kNameSep;
std::vector<std::string> names = absl::StrSplit(name, kNameSep); std::vector<std::string> names = absl::StrSplit(name, kNameSep);
@ -287,10 +290,10 @@ class FunctionRegistry {
private: private:
mutable absl::Mutex lock_; mutable absl::Mutex lock_;
std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
// For names included in NamespaceAllowlist, strips the namespace. // For names included in NamespaceAllowlist, strips the namespace.
std::string GetAdjustedName(const std::string& name) { std::string GetAdjustedName(absl::string_view name) {
using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kCxxSep;
std::vector<std::string> names = absl::StrSplit(name, kCxxSep); std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
std::string base_name = names.back(); std::string base_name = names.back();
@ -299,10 +302,10 @@ class FunctionRegistry {
if (NamespaceAllowlist::TopNamespaces().count(ns)) { if (NamespaceAllowlist::TopNamespaces().count(ns)) {
return base_name; return base_name;
} }
return name; return std::string(name);
} }
void Unregister(const std::string& name) { void Unregister(absl::string_view name) {
absl::WriterMutexLock lock(&lock_); absl::WriterMutexLock lock(&lock_);
std::string adjusted_name = GetAdjustedName(name); std::string adjusted_name = GetAdjustedName(name);
if (adjusted_name != name) { if (adjusted_name != name) {
@ -317,7 +320,7 @@ class GlobalFactoryRegistry {
using Functions = FunctionRegistry<R, Args...>; using Functions = FunctionRegistry<R, Args...>;
public: public:
static RegistrationToken Register(const std::string& name, static RegistrationToken Register(absl::string_view name,
typename Functions::Function func) { typename Functions::Function func) {
return functions()->Register(name, std::move(func)); return functions()->Register(name, std::move(func));
} }
@ -326,7 +329,7 @@ class GlobalFactoryRegistry {
// If using namespaces with this registry, the variant with a namespace // If using namespaces with this registry, the variant with a namespace
// argument should be used. // argument should be used.
template <typename... Args2> template <typename... Args2>
static typename Functions::ReturnType CreateByName(const std::string& name, static typename Functions::ReturnType CreateByName(absl::string_view name,
Args2&&... args) { Args2&&... args) {
return functions()->Invoke(name, std::forward<Args2>(args)...); return functions()->Invoke(name, std::forward<Args2>(args)...);
} }
@ -334,7 +337,7 @@ class GlobalFactoryRegistry {
// Returns true if the specified factory function is available. // Returns true if the specified factory function is available.
// If using namespaces with this registry, the variant with a namespace // If using namespaces with this registry, the variant with a namespace
// argument should be used. // argument should be used.
static bool IsRegistered(const std::string& name) { static bool IsRegistered(absl::string_view name) {
return functions()->IsRegistered(name); return functions()->IsRegistered(name);
} }
@ -350,13 +353,13 @@ class GlobalFactoryRegistry {
std::tuple<Args...>>::value, std::tuple<Args...>>::value,
int> = 0> int> = 0>
static typename Functions::ReturnType CreateByNameInNamespace( static typename Functions::ReturnType CreateByNameInNamespace(
const std::string& ns, const std::string& name, Args2&&... args) { absl::string_view ns, absl::string_view name, Args2&&... args) {
return functions()->Invoke(ns, name, std::forward<Args2>(args)...); return functions()->Invoke(ns, name, std::forward<Args2>(args)...);
} }
// Returns true if the specified factory function is available. // Returns true if the specified factory function is available.
// Namespaces in |name| and |ns| are separated by kNameSep. // Namespaces in |name| and |ns| are separated by kNameSep.
static bool IsRegistered(const std::string& ns, const std::string& name) { static bool IsRegistered(absl::string_view ns, absl::string_view name) {
return functions()->IsRegistered(ns, name); return functions()->IsRegistered(ns, name);
} }