Internal change
PiperOrigin-RevId: 488770794
This commit is contained in:
		
							parent
							
								
									7a87546c30
								
							
						
					
					
						commit
						38b636f7ee
					
				|  | @ -225,6 +225,7 @@ cc_library( | |||
|         "//mediapipe/framework/port:status", | ||||
|         "//mediapipe/framework/port:statusor", | ||||
|         "@com_google_absl//absl/base:core_headers", | ||||
|         "@com_google_absl//absl/container:flat_hash_map", | ||||
|         "@com_google_absl//absl/container:flat_hash_set", | ||||
|         "@com_google_absl//absl/meta:type_traits", | ||||
|         "@com_google_absl//absl/strings", | ||||
|  |  | |||
|  | @ -26,10 +26,12 @@ | |||
| 
 | ||||
| #include "absl/base/macros.h" | ||||
| #include "absl/base/thread_annotations.h" | ||||
| #include "absl/container/flat_hash_map.h" | ||||
| #include "absl/container/flat_hash_set.h" | ||||
| #include "absl/meta/type_traits.h" | ||||
| #include "absl/strings/str_join.h" | ||||
| #include "absl/strings/str_split.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "absl/synchronization/mutex.h" | ||||
| #include "mediapipe/framework/deps/registration_token.h" | ||||
| #include "mediapipe/framework/port/canonical_errors.h" | ||||
|  | @ -159,7 +161,7 @@ class FunctionRegistry { | |||
|   FunctionRegistry(const FunctionRegistry&) = delete; | ||||
|   FunctionRegistry& operator=(const FunctionRegistry&) = delete; | ||||
| 
 | ||||
|   RegistrationToken Register(const std::string& name, Function func) | ||||
|   RegistrationToken Register(absl::string_view name, Function func) | ||||
|       ABSL_LOCKS_EXCLUDED(lock_) { | ||||
|     std::string normalized_name = GetNormalizedName(name); | ||||
|     absl::WriterMutexLock lock(&lock_); | ||||
|  | @ -189,14 +191,15 @@ class FunctionRegistry { | |||
|             absl::enable_if_t<std::is_convertible<std::tuple<Args2...>, | ||||
|                                                   std::tuple<Args...>>::value, | ||||
|                               int> = 0> | ||||
|   ReturnType Invoke(const std::string& name, Args2&&... args) | ||||
|   ReturnType Invoke(absl::string_view name, Args2&&... args) | ||||
|       ABSL_LOCKS_EXCLUDED(lock_) { | ||||
|     Function function; | ||||
|     { | ||||
|       absl::ReaderMutexLock lock(&lock_); | ||||
|       auto it = functions_.find(name); | ||||
|       if (it == functions_.end()) { | ||||
|         return absl::NotFoundError("No registered object with name: " + name); | ||||
|         return absl::NotFoundError( | ||||
|             absl::StrCat("No registered object with name: ", name)); | ||||
|       } | ||||
|       function = it->second; | ||||
|     } | ||||
|  | @ -206,7 +209,7 @@ class FunctionRegistry { | |||
|   // Invokes the specified factory function and returns the result.
 | ||||
|   // Namespaces in |name| and |ns| are separated by kNameSep.
 | ||||
|   template <typename... Args2> | ||||
|   ReturnType Invoke(const std::string& ns, const std::string& name, | ||||
|   ReturnType Invoke(absl::string_view ns, absl::string_view name, | ||||
|                     Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { | ||||
|     return Invoke(GetQualifiedName(ns, name), args...); | ||||
|   } | ||||
|  | @ -214,14 +217,14 @@ class FunctionRegistry { | |||
|   // Note that it's possible for registered implementations to be subsequently
 | ||||
|   // unregistered, though this will never happen with registrations made via
 | ||||
|   // MEDIAPIPE_REGISTER_FACTORY_FUNCTION.
 | ||||
|   bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) { | ||||
|   bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { | ||||
|     absl::ReaderMutexLock lock(&lock_); | ||||
|     return functions_.count(name) != 0; | ||||
|   } | ||||
| 
 | ||||
|   // Returns true if the specified factory function is available.
 | ||||
|   // Namespaces in |name| and |ns| are separated by kNameSep.
 | ||||
|   bool IsRegistered(const std::string& ns, const std::string& name) const | ||||
|   bool IsRegistered(absl::string_view ns, absl::string_view name) const | ||||
|       ABSL_LOCKS_EXCLUDED(lock_) { | ||||
|     return IsRegistered(GetQualifiedName(ns, name)); | ||||
|   } | ||||
|  | @ -244,7 +247,7 @@ class FunctionRegistry { | |||
|   // Normalizes a C++ qualified name.  Validates the name qualification.
 | ||||
|   // The name must be either unqualified or fully qualified with a leading "::".
 | ||||
|   // The leading "::" in a fully qualified name is stripped.
 | ||||
|   std::string GetNormalizedName(const std::string& name) { | ||||
|   std::string GetNormalizedName(absl::string_view name) { | ||||
|     using ::mediapipe::registration_internal::kCxxSep; | ||||
|     std::vector<std::string> names = absl::StrSplit(name, kCxxSep); | ||||
|     if (names[0].empty()) { | ||||
|  | @ -259,8 +262,8 @@ class FunctionRegistry { | |||
| 
 | ||||
|   // Returns the registry key for a name specified within a namespace.
 | ||||
|   // Namespaces are separated by kNameSep.
 | ||||
|   std::string GetQualifiedName(const std::string& ns, | ||||
|                                const std::string& name) const { | ||||
|   std::string GetQualifiedName(absl::string_view ns, | ||||
|                                absl::string_view name) const { | ||||
|     using ::mediapipe::registration_internal::kCxxSep; | ||||
|     using ::mediapipe::registration_internal::kNameSep; | ||||
|     std::vector<std::string> names = absl::StrSplit(name, kNameSep); | ||||
|  | @ -287,10 +290,10 @@ class FunctionRegistry { | |||
| 
 | ||||
|  private: | ||||
|   mutable absl::Mutex lock_; | ||||
|   std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); | ||||
|   absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); | ||||
| 
 | ||||
|   // For names included in NamespaceAllowlist, strips the namespace.
 | ||||
|   std::string GetAdjustedName(const std::string& name) { | ||||
|   std::string GetAdjustedName(absl::string_view name) { | ||||
|     using ::mediapipe::registration_internal::kCxxSep; | ||||
|     std::vector<std::string> names = absl::StrSplit(name, kCxxSep); | ||||
|     std::string base_name = names.back(); | ||||
|  | @ -299,10 +302,10 @@ class FunctionRegistry { | |||
|     if (NamespaceAllowlist::TopNamespaces().count(ns)) { | ||||
|       return base_name; | ||||
|     } | ||||
|     return name; | ||||
|     return std::string(name); | ||||
|   } | ||||
| 
 | ||||
|   void Unregister(const std::string& name) { | ||||
|   void Unregister(absl::string_view name) { | ||||
|     absl::WriterMutexLock lock(&lock_); | ||||
|     std::string adjusted_name = GetAdjustedName(name); | ||||
|     if (adjusted_name != name) { | ||||
|  | @ -317,7 +320,7 @@ class GlobalFactoryRegistry { | |||
|   using Functions = FunctionRegistry<R, Args...>; | ||||
| 
 | ||||
|  public: | ||||
|   static RegistrationToken Register(const std::string& name, | ||||
|   static RegistrationToken Register(absl::string_view name, | ||||
|                                     typename Functions::Function func) { | ||||
|     return functions()->Register(name, std::move(func)); | ||||
|   } | ||||
|  | @ -326,7 +329,7 @@ class GlobalFactoryRegistry { | |||
|   // If using namespaces with this registry, the variant with a namespace
 | ||||
|   // argument should be used.
 | ||||
|   template <typename... Args2> | ||||
|   static typename Functions::ReturnType CreateByName(const std::string& name, | ||||
|   static typename Functions::ReturnType CreateByName(absl::string_view name, | ||||
|                                                      Args2&&... args) { | ||||
|     return functions()->Invoke(name, std::forward<Args2>(args)...); | ||||
|   } | ||||
|  | @ -334,7 +337,7 @@ class GlobalFactoryRegistry { | |||
|   // Returns true if the specified factory function is available.
 | ||||
|   // If using namespaces with this registry, the variant with a namespace
 | ||||
|   // argument should be used.
 | ||||
|   static bool IsRegistered(const std::string& name) { | ||||
|   static bool IsRegistered(absl::string_view name) { | ||||
|     return functions()->IsRegistered(name); | ||||
|   } | ||||
| 
 | ||||
|  | @ -350,13 +353,13 @@ class GlobalFactoryRegistry { | |||
|                                                   std::tuple<Args...>>::value, | ||||
|                               int> = 0> | ||||
|   static typename Functions::ReturnType CreateByNameInNamespace( | ||||
|       const std::string& ns, const std::string& name, Args2&&... args) { | ||||
|       absl::string_view ns, absl::string_view name, Args2&&... args) { | ||||
|     return functions()->Invoke(ns, name, std::forward<Args2>(args)...); | ||||
|   } | ||||
| 
 | ||||
|   // Returns true if the specified factory function is available.
 | ||||
|   // Namespaces in |name| and |ns| are separated by kNameSep.
 | ||||
|   static bool IsRegistered(const std::string& ns, const std::string& name) { | ||||
|   static bool IsRegistered(absl::string_view ns, absl::string_view name) { | ||||
|     return functions()->IsRegistered(ns, name); | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user