Internal change

PiperOrigin-RevId: 488770794
This commit is contained in:
MediaPipe Team 2022-11-15 15:10:36 -08:00 committed by Copybara-Service
parent 7a87546c30
commit 38b636f7ee
2 changed files with 22 additions and 18 deletions

View File

@ -225,6 +225,7 @@ cc_library(
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/strings",

View File

@ -26,10 +26,12 @@
#include "absl/base/macros.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/deps/registration_token.h"
#include "mediapipe/framework/port/canonical_errors.h"
@ -159,7 +161,7 @@ class FunctionRegistry {
FunctionRegistry(const FunctionRegistry&) = delete;
FunctionRegistry& operator=(const FunctionRegistry&) = delete;
RegistrationToken Register(const std::string& name, Function func)
RegistrationToken Register(absl::string_view name, Function func)
ABSL_LOCKS_EXCLUDED(lock_) {
std::string normalized_name = GetNormalizedName(name);
absl::WriterMutexLock lock(&lock_);
@ -189,14 +191,15 @@ class FunctionRegistry {
absl::enable_if_t<std::is_convertible<std::tuple<Args2...>,
std::tuple<Args...>>::value,
int> = 0>
ReturnType Invoke(const std::string& name, Args2&&... args)
ReturnType Invoke(absl::string_view name, Args2&&... args)
ABSL_LOCKS_EXCLUDED(lock_) {
Function function;
{
absl::ReaderMutexLock lock(&lock_);
auto it = functions_.find(name);
if (it == functions_.end()) {
return absl::NotFoundError("No registered object with name: " + name);
return absl::NotFoundError(
absl::StrCat("No registered object with name: ", name));
}
function = it->second;
}
@ -206,7 +209,7 @@ class FunctionRegistry {
// Invokes the specified factory function and returns the result.
// Namespaces in |name| and |ns| are separated by kNameSep.
template <typename... Args2>
ReturnType Invoke(const std::string& ns, const std::string& name,
ReturnType Invoke(absl::string_view ns, absl::string_view name,
Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) {
return Invoke(GetQualifiedName(ns, name), args...);
}
@ -214,14 +217,14 @@ class FunctionRegistry {
// Note that it's possible for registered implementations to be subsequently
// unregistered, though this will never happen with registrations made via
// MEDIAPIPE_REGISTER_FACTORY_FUNCTION.
bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) {
bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) {
absl::ReaderMutexLock lock(&lock_);
return functions_.count(name) != 0;
}
// Returns true if the specified factory function is available.
// Namespaces in |name| and |ns| are separated by kNameSep.
bool IsRegistered(const std::string& ns, const std::string& name) const
bool IsRegistered(absl::string_view ns, absl::string_view name) const
ABSL_LOCKS_EXCLUDED(lock_) {
return IsRegistered(GetQualifiedName(ns, name));
}
@ -244,7 +247,7 @@ class FunctionRegistry {
// Normalizes a C++ qualified name. Validates the name qualification.
// The name must be either unqualified or fully qualified with a leading "::".
// The leading "::" in a fully qualified name is stripped.
std::string GetNormalizedName(const std::string& name) {
std::string GetNormalizedName(absl::string_view name) {
using ::mediapipe::registration_internal::kCxxSep;
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
if (names[0].empty()) {
@ -259,8 +262,8 @@ class FunctionRegistry {
// Returns the registry key for a name specified within a namespace.
// Namespaces are separated by kNameSep.
std::string GetQualifiedName(const std::string& ns,
const std::string& name) const {
std::string GetQualifiedName(absl::string_view ns,
absl::string_view name) const {
using ::mediapipe::registration_internal::kCxxSep;
using ::mediapipe::registration_internal::kNameSep;
std::vector<std::string> names = absl::StrSplit(name, kNameSep);
@ -287,10 +290,10 @@ class FunctionRegistry {
private:
mutable absl::Mutex lock_;
std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_);
// For names included in NamespaceAllowlist, strips the namespace.
std::string GetAdjustedName(const std::string& name) {
std::string GetAdjustedName(absl::string_view name) {
using ::mediapipe::registration_internal::kCxxSep;
std::vector<std::string> names = absl::StrSplit(name, kCxxSep);
std::string base_name = names.back();
@ -299,10 +302,10 @@ class FunctionRegistry {
if (NamespaceAllowlist::TopNamespaces().count(ns)) {
return base_name;
}
return name;
return std::string(name);
}
void Unregister(const std::string& name) {
void Unregister(absl::string_view name) {
absl::WriterMutexLock lock(&lock_);
std::string adjusted_name = GetAdjustedName(name);
if (adjusted_name != name) {
@ -317,7 +320,7 @@ class GlobalFactoryRegistry {
using Functions = FunctionRegistry<R, Args...>;
public:
static RegistrationToken Register(const std::string& name,
static RegistrationToken Register(absl::string_view name,
typename Functions::Function func) {
return functions()->Register(name, std::move(func));
}
@ -326,7 +329,7 @@ class GlobalFactoryRegistry {
// If using namespaces with this registry, the variant with a namespace
// argument should be used.
template <typename... Args2>
static typename Functions::ReturnType CreateByName(const std::string& name,
static typename Functions::ReturnType CreateByName(absl::string_view name,
Args2&&... args) {
return functions()->Invoke(name, std::forward<Args2>(args)...);
}
@ -334,7 +337,7 @@ class GlobalFactoryRegistry {
// Returns true if the specified factory function is available.
// If using namespaces with this registry, the variant with a namespace
// argument should be used.
static bool IsRegistered(const std::string& name) {
static bool IsRegistered(absl::string_view name) {
return functions()->IsRegistered(name);
}
@ -350,13 +353,13 @@ class GlobalFactoryRegistry {
std::tuple<Args...>>::value,
int> = 0>
static typename Functions::ReturnType CreateByNameInNamespace(
const std::string& ns, const std::string& name, Args2&&... args) {
absl::string_view ns, absl::string_view name, Args2&&... args) {
return functions()->Invoke(ns, name, std::forward<Args2>(args)...);
}
// Returns true if the specified factory function is available.
// Namespaces in |name| and |ns| are separated by kNameSep.
static bool IsRegistered(const std::string& ns, const std::string& name) {
static bool IsRegistered(absl::string_view ns, absl::string_view name) {
return functions()->IsRegistered(ns, name);
}