From bae14a83b2fc6d26d9a0a9f524d869e057ff24c8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 6 Apr 2023 22:11:16 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 522524565 --- mediapipe/framework/api2/BUILD | 1 + mediapipe/framework/api2/builder.h | 4 ++-- mediapipe/framework/api2/port.h | 13 ++++++++++++- mediapipe/framework/api2/port_test.cc | 12 ++++++++++++ 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index 76aace6f5..44486cd91 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -160,6 +160,7 @@ cc_test( deps = [ ":port", "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index ee9796e49..3cae87a3b 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -713,12 +713,12 @@ class Graph { } } - std::string TaggedName(const TagIndexLocation& loc, const std::string& name) { + std::string TaggedName(const TagIndexLocation& loc, absl::string_view name) { if (loc.tag.empty()) { // ParseTagIndexName does not allow using explicit indices without tags, // while ParseTagIndex does. // TODO: decide whether we should just allow it. - return name; + return std::string(name); } else { if (loc.count <= 1) { return absl::StrCat(loc.tag, ":", name); diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index eee542640..f6abe75ed 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -21,6 +21,7 @@ #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "mediapipe/framework/api2/const_str.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/calculator_context.h" @@ -36,6 +37,13 @@ namespace api2 { // directly by node code. class PortBase { public: + constexpr PortBase(absl::string_view tag, TypeId type_id, bool optional, + bool multiple) + : tag_(tag.size(), tag.data()), + optional_(optional), + multiple_(multiple), + type_id_(type_id) {} + constexpr PortBase(std::size_t tag_size, const char* tag, TypeId type_id, bool optional, bool multiple) : tag_(tag_size, tag), @@ -123,7 +131,7 @@ auto GetCollection(CC* cc, const SideOutputBase& port) } template -auto GetOrNull(Collection& collection, const std::string& tag, int index) +auto GetOrNull(Collection& collection, const absl::string_view& tag, int index) -> decltype(&collection.Get(std::declval())) { CollectionItemId id = collection.GetId(tag, index); return id.IsValid() ? &collection.Get(id) : nullptr; @@ -332,6 +340,9 @@ class PortCommon : public Base { using Multiple = PortCommon; using SideFallback = SideFallbackT; + explicit constexpr PortCommon(absl::string_view tag) + : Base(tag, kTypeId, IsOptionalV, IsMultipleV) {} + template explicit constexpr PortCommon(const char (&tag)[N]) : Base(N, tag, kTypeId, IsOptionalV, IsMultipleV) {} diff --git a/mediapipe/framework/api2/port_test.cc b/mediapipe/framework/api2/port_test.cc index 6676e44f0..9b198a84b 100644 --- a/mediapipe/framework/api2/port_test.cc +++ b/mediapipe/framework/api2/port_test.cc @@ -1,11 +1,15 @@ #include "mediapipe/framework/api2/port.h" +#include "absl/strings/string_view.h" #include "mediapipe/framework/port/gtest.h" namespace mediapipe { namespace api2 { namespace { +constexpr absl::string_view kInputTag{"INPUT"}; +constexpr absl::string_view kOutputTag{"OUTPUT"}; + TEST(PortTest, IntInput) { static constexpr auto port = Input("FOO"); EXPECT_EQ(port.type_id(), kTypeId); @@ -40,6 +44,14 @@ TEST(PortTest, DeletedCopyConstructorInput) { EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT"); } +TEST(PortTest, DeletedCopyConstructorStringView) { + static constexpr Input kInputPort(kInputTag); + EXPECT_EQ(std::string(kInputPort.Tag()), kInputTag); + + static constexpr Output kOutputPort(kOutputTag); + EXPECT_EQ(std::string(kOutputPort.Tag()), kOutputTag); +} + class AbstractBase { public: virtual ~AbstractBase() = default;