diff --git a/WORKSPACE b/WORKSPACE index 702d1899e..d43394883 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -26,7 +26,7 @@ versions.check(minimum_bazel_version = "3.7.2") http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ @@ -35,8 +35,8 @@ http_archive( patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20210324.2", - sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" + strip_prefix = "abseil-cpp-20220623.1", + sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8" ) http_archive( @@ -212,14 +212,14 @@ http_archive( sha256 = "75922da3a1bdb417d820398eb03d4e9bd067c4905a4246d35a44c01d62154d91", ) -# Point to the commit that deprecates the usage of Eigen::MappedSparseMatrix. +# 2022-10-20 http_archive( name = "pybind11", urls = [ - "https://github.com/pybind/pybind11/archive/70a58c577eaf067748c2ec31bfd0b0a614cffba6.zip", + "https://github.com/pybind/pybind11/archive/v2.10.1.zip", ], - sha256 = "b971842fab1b5b8f3815a2302331782b7d137fef0e06502422bc4bc360f4956c", - strip_prefix = "pybind11-70a58c577eaf067748c2ec31bfd0b0a614cffba6", + sha256 = "fcf94065efcfd0a7a828bacf118fa11c43f6390d0c805e3e6342ac119f2e9976", + strip_prefix = "pybind11-2.10.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) diff --git a/docs/BUILD b/docs/BUILD index ad08df66a..8e85dbf86 100644 --- a/docs/BUILD +++ b/docs/BUILD @@ -17,6 +17,7 @@ py_binary( name = "build_java_api_docs", srcs = ["build_java_api_docs.py"], data = [ + "//third_party/android/sdk:api/26.txt", "//third_party/java/doclava/current:doclava.jar", "//third_party/java/jsilver:jsilver_jar", ], diff --git a/docs/build_java_api_docs.py b/docs/build_java_api_docs.py index e96e1fd83..b13e8d1df 100644 --- a/docs/build_java_api_docs.py +++ b/docs/build_java_api_docs.py @@ -20,10 +20,6 @@ from absl import flags from tensorflow_docs.api_generator import gen_java -_JAVA_ROOT = flags.DEFINE_string('java_src', None, - 'Override the Java source path.', - required=False) - _OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/', 'Write docs here.') @@ -37,27 +33,30 @@ _ = flags.DEFINE_bool( 'search_hints', True, '[UNUSED] Include metadata search hints in the generated files') +_ANDROID_SDK = pathlib.Path('android/sdk/api/26.txt') + def main(_) -> None: - if not (java_root := _JAVA_ROOT.value): - # Default to using a relative path to find the Java source. - mp_root = pathlib.Path(__file__) - while (mp_root := mp_root.parent).name != 'mediapipe': - # Find the nearest `mediapipe` dir. - pass + # Default to using a relative path to find the Java source. + mp_root = pathlib.Path(__file__) + while (mp_root := mp_root.parent).name != 'mediapipe': + # Find the nearest `mediapipe` dir. + pass - # Externally, parts of the repo are nested inside a mediapipe/ directory - # that does not exist internally. Support both. - if (mp_root / 'mediapipe').exists(): - mp_root = mp_root / 'mediapipe' + # Find the root from which all packages are relative. + root = mp_root.parent - java_root = mp_root / 'tasks/java' + # Externally, parts of the repo are nested inside a mediapipe/ directory + # that does not exist internally. Support both. 
+ if (mp_root / 'mediapipe').exists(): + mp_root = mp_root / 'mediapipe' gen_java.gen_java_docs( package='com.google.mediapipe', - source_path=pathlib.Path(java_root), + source_path=mp_root / 'tasks/java', output_dir=pathlib.Path(_OUT_DIR.value), - site_path=pathlib.Path(_SITE_PATH.value)) + site_path=pathlib.Path(_SITE_PATH.value), + federated_docs={'https://developer.android.com': root / _ANDROID_SDK}) if __name__ == '__main__': diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py index fa1e4314f..fe706acd3 100644 --- a/docs/build_py_api_docs.py +++ b/docs/build_py_api_docs.py @@ -30,7 +30,7 @@ from tensorflow_docs.api_generator import public_api try: # mediapipe has not been set up to work with bazel yet, so catch & report. - import mediapipe # pytype: disable=import-error + import mediapipe as mp # pytype: disable=import-error except ImportError as e: raise ImportError('Please `pip install mediapipe`.') from e @@ -58,11 +58,13 @@ _SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', def gen_api_docs(): """Generates API docs for the mediapipe package.""" + if hasattr(mp, 'solutions'): + del mp.solutions doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[(PROJECT_SHORT_NAME, mediapipe)], - base_dir=os.path.dirname(mediapipe.__file__), + py_modules=[(PROJECT_SHORT_NAME, mp)], + base_dir=os.path.dirname(mp.__file__), code_url_prefix=_URL_PREFIX.value, search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index ba461e4a7..555f7543f 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -197,7 +197,6 @@ cc_library( ":spectrogram_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ecd878115..3b658eb5b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -341,7 +341,6 @@ cc_test( srcs = ["concatenate_proto_list_calculator_test.cc"], deps = [ ":concatenate_proto_list_calculator", - ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -403,7 +402,6 @@ cc_test( srcs = ["clip_vector_size_calculator_test.cc"], deps = [ ":clip_vector_size_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -956,10 +954,10 @@ cc_library( deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", @@ -1301,6 +1299,7 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", 
"//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc index efc0612ec..4e007329b 100644 --- a/mediapipe/calculators/core/bypass_calculator.cc +++ b/mediapipe/calculators/core/bypass_calculator.cc @@ -111,6 +111,10 @@ class BypassCalculator : public Node { cc->Outputs().Get(id).SetAny(); } } + for (auto id = cc->InputSidePackets().BeginId(); + id != cc->InputSidePackets().EndId(); ++id) { + cc->InputSidePackets().Get(id).SetAny(); + } return absl::OkStatus(); } diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 45bace271..5d0594de9 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -85,75 +85,6 @@ std::string SourceString(Timestamp t) { : absl::StrCat("Timestamp(", t.DebugString(), ")"); } -template -std::string SourceString(Packet packet) { - std::ostringstream oss; - if (packet.IsEmpty()) { - oss << "Packet()"; - } else { - oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" - << packet.Get() << ")"; - } - oss << ".At(" << SourceString(packet.Timestamp()) << ")"; - return oss.str(); -} - -template -class PacketsEqMatcher - : public ::testing::MatcherInterface { - public: - PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} - void DescribeTo(::std::ostream* os) const override { - *os << "The expected packet contents: \n"; - Print(packets_, os); - } - bool MatchAndExplain( - const PacketContainer& value, - ::testing::MatchResultListener* listener) const override { - if (!Equals(packets_, value)) { - if (listener->IsInterested()) { - *listener << "The actual packet contents: \n"; - Print(value, listener->stream()); - } - return false; - } - return true; - } - - private: - bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { - if (c1.size() != c2.size()) { - return false; - } - for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { - Packet p1 = *i1, p2 = *i2; - if (p1.Timestamp() != p2.Timestamp() || p1.IsEmpty() != p2.IsEmpty() || - (!p1.IsEmpty() && - p1.Get() != p2.Get())) { - return false; - } - } - return true; - } - void Print(const PacketContainer& packets, ::std::ostream* os) const { - for (auto it = packets.begin(); it != packets.end(); ++it) { - const Packet& packet = *it; - *os << (it == packets.begin() ? "{" : ""); - *os << SourceString(packet); - *os << (std::next(it) == packets.end() ? "}" : ", "); - } - } - - const PacketContainer packets_; -}; - -template -::testing::Matcher PacketsEq( - const PacketContainer& packets) { - return MakeMatcher( - new PacketsEqMatcher(packets)); -} - // A Calculator::Process callback function. typedef std::function @@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { - auto BoolPacketsEq = PacketsEq, bool>; - auto IntPacketsEq = PacketsEq, int>; - // Configure the test. 
SetUpInputData(); SetUpSimulationClock(); @@ -839,13 +767,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers<int>(expected_output))); + // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. std::vector<Packet> expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2)); + EXPECT_THAT(out_2_packets, + ElementsAreArray(PacketMatchers<int>(expected_output_2))); // Validate the ALLOW stream output. std::vector<Packet> expected_allow = { @@ -871,7 +802,8 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { MakePacket<bool>(true).At(Timestamp(190000)), MakePacket<bool>(false).At(Timestamp(200000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers<bool>(expected_allow))); } std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets, @@ -891,9 +823,6 @@ std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets, // Shows how FlowLimiterCalculator releases auxiliary input packets. // In this test, auxiliary input packets arrive at twice the primary rate. TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { - auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>; - auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>; - // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -1011,7 +940,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket<int>(6).At(Timestamp(60000)), Packet().At(Timestamp(80000)), }; - EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output)); + EXPECT_THAT(out_1_packets_, + ElementsAreArray(PacketMatchers<int>(expected_output))); // Packets following input packets 2 and 6, and not input packets 4 and 8. std::vector<Packet> expected_auxiliary_output = { @@ -1031,12 +961,13 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { }; std::vector<Packet> actual_2 = StripBoundsUpdates(out_2_packets, Timestamp(90000)); - EXPECT_THAT(actual_2, IntPacketsEq(expected_auxiliary_output)); + EXPECT_THAT(actual_2, + ElementsAreArray(PacketMatchers<int>(expected_auxiliary_output))); std::vector<Packet> expected_3 = StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999)); std::vector<Packet> actual_3 = StripBoundsUpdates(out_3_packets, Timestamp(39999)); - EXPECT_THAT(actual_3, IntPacketsEq(expected_3)); + EXPECT_THAT(actual_3, ElementsAreArray(PacketMatchers<int>(expected_3))); // Validate the ALLOW stream output.
std::vector<Packet> expected_allow = { @@ -1045,7 +976,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { MakePacket<bool>(true).At(Timestamp(60000)), MakePacket<bool>(false).At(Timestamp(80000)), }; - EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow)); + EXPECT_THAT(allow_packets_, + ElementsAreArray(PacketMatchers<bool>(expected_allow))); } } // anonymous namespace diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc index 51fb46b98..3306e4ff3 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator.cc @@ -15,6 +15,7 @@ #include "mediapipe/calculators/core/get_vector_item_calculator.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" namespace mediapipe { @@ -32,5 +33,9 @@ using GetClassificationListVectorItemCalculator = GetVectorItemCalculator<mediapipe::ClassificationList>; REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); +using GetDetectionVectorItemCalculator = + GetVectorItemCalculator<mediapipe::Detection>; +REGISTER_CALCULATOR(GetDetectionVectorItemCalculator); + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index a0ce2ae34..88b04a32b 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -41,6 +41,10 @@ class MuxCalculator : public Node { StreamHandler("MuxInputStreamHandler")); absl::Status Process(CalculatorContext* cc) final { + if (kSelect(cc).IsStream() && kSelect(cc).IsEmpty()) { + return absl::OkStatus(); + } + int select = *kSelect(cc); RET_CHECK(0 <= select && select < kIn(cc).Count()); if (!kIn(cc)[select].IsEmpty()) { diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index 86d2fab42..6b9434be9 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -398,6 +398,95 @@ TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) { MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST(MuxCalculatorTest, HandlesCloseGracefully) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( + R"pb( + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector<Packet> output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add a single packet, wait for completion, and close.
+ MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket<int>(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_TRUE(output_packets.empty()); +} + +TEST(MuxCalculatorTest, HandlesCloseGracefullyWithDefaultInputStreamHandler) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( + R"pb( + # This is required in order for EXPECT_DEATH to work everywhere + executor { name: "" type: "ApplicationThreadExecutor" } + + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + } + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector<Packet> output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add a single packet, wait for completion, and close. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket<int>(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_EQ(output_packets.size(), 1); + EXPECT_TRUE(output_packets[0].IsEmpty()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 89e2d371c..530dd3d4a 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -16,12 +16,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,7 +30,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "scale_image_calculator_proto", srcs = ["scale_image_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "set_alpha_calculator_proto", srcs = ["set_alpha_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "image_cropping_calculator_proto", srcs = ["image_cropping_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7
+67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "segmentation_smoothing_calculator_proto", srcs = ["segmentation_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -112,7 +104,6 @@ cc_library( cc_library( name = "opencv_encoded_image_to_image_frame_calculator", srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_encoded_image_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -127,7 +118,6 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -142,7 +132,6 @@ cc_library( cc_library( name = "opencv_put_text_calculator", srcs = ["opencv_put_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame_opencv", @@ -156,11 +145,10 @@ cc_library( cc_library( name = "set_alpha_calculator", srcs = ["set_alpha_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -183,11 +171,10 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -212,13 +199,11 @@ cc_library( mediapipe_proto_library( name = "rotation_mode_proto", srcs = ["rotation_mode.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", @@ -243,7 +228,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", @@ -287,13 +271,12 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", @@ -330,7 +313,6 @@ cc_test( cc_library( name = "luminance_calculator", srcs = ["luminance_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,7 +326,6 @@ cc_library( cc_library( name = "sobel_edges_calculator", srcs = ["sobel_edges_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -358,15 +339,14 @@ cc_library( cc_library( name = "recolor_calculator", srcs = ["recolor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", + "//mediapipe/util:color_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - "//mediapipe/util:color_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", ] + select({ @@ -385,9 +365,6 @@ cc_library( name = "scale_image_utils", srcs = ["scale_image_utils.cc"], hdrs = ["scale_image_utils.h"], - visibility = [ - "//mediapipe:__subpackages__", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -400,9 +377,6 @@ cc_library( cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":scale_image_utils", "//mediapipe/calculators/image:scale_image_calculator_cc_proto", @@ -429,7 +403,6 @@ cc_library( mediapipe_proto_library( name = "image_clone_calculator_proto", srcs = ["image_clone_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -439,7 +412,6 @@ mediapipe_proto_library( cc_library( name = "image_clone_calculator", srcs = ["image_clone_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_clone_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -459,7 +431,6 @@ cc_library( cc_library( name = "image_properties_calculator", srcs = ["image_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", @@ -524,7 +495,6 @@ cc_test( mediapipe_proto_library( name = "mask_overlay_calculator_proto", srcs = ["mask_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -534,7 +504,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "opencv_encoded_image_to_image_frame_calculator_proto", srcs = ["opencv_encoded_image_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -544,7 +513,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs = ["feature_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -554,7 +522,6 @@ mediapipe_proto_library( cc_library( name = "mask_overlay_calculator", srcs = ["mask_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":mask_overlay_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -570,7 +537,6 @@ cc_library( cc_library( name = "feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -597,7 +563,6 @@ cc_library( cc_library( name = "image_file_properties_calculator", srcs = ["image_file_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_file_properties_cc_proto", @@ -627,11 +592,10 @@ cc_test( cc_library( name = "segmentation_smoothing_calculator", srcs = ["segmentation_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", @@ -724,7 +688,6 @@ cc_library( mediapipe_proto_library( name = "warp_affine_calculator_proto", srcs = ["warp_affine_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -736,7 +699,6 @@ cc_library( name = "warp_affine_calculator", srcs = ["warp_affine_calculator.cc"], hdrs = ["warp_affine_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":affine_transformation", ":warp_affine_calculator_cc_proto", @@ -817,7 +779,6 @@ cc_test( cc_library( name = "yuv_to_image_calculator", srcs = ["yuv_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 3f1278397..645189a07 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -30,6 +30,7 @@ exports_files( glob(["testdata/image_to_tensor/*"]), visibility = [ "//mediapipe/calculators/image:__subpackages__", + "//mediapipe/util:__subpackages__", ], ) @@ -433,6 +434,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":inference_calculator_cc_proto", ":inference_calculator_options_lib", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -463,6 +465,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", ], alwayslink = 1, @@ -512,6 +515,7 @@ cc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/util/tflite:config", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", @@ -794,12 +798,12 @@ cc_library( 
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", ] + selects.with_or({ ":compute_shader_unavailable": [], @@ -1130,6 +1134,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/util:image_test_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1279,7 +1284,6 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1378,9 +1382,9 @@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", "@org_tensorflow//tensorflow/lite:framework", - "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/framework/port:statusor", ] + selects.with_or({ "//mediapipe/gpu:disable_gpu": [], diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index d0513518a..9cb23a393 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -43,6 +43,7 @@ namespace api2 { namespace { using Options = ::mediapipe::AudioToTensorCalculatorOptions; +using DftTensorFormat = Options::DftTensorFormat; using FlushMode = Options::FlushMode; std::vector HannWindow(int window_size, bool sqrt_hann) { @@ -188,6 +189,8 @@ class AudioToTensorCalculator : public Node { int padding_samples_before_; int padding_samples_after_; FlushMode flush_mode_; + DftTensorFormat dft_tensor_format_; + Timestamp initial_timestamp_ = Timestamp::Unstarted(); int64 cumulative_input_samples_ = 0; Timestamp next_output_timestamp_ = Timestamp::Unstarted(); @@ -273,6 +276,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { } padding_samples_before_ = options.padding_samples_before(); padding_samples_after_ = options.padding_samples_after(); + dft_tensor_format_ = options.dft_tensor_format(); flush_mode_ = options.flush_mode(); RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ @@ -492,14 +496,43 @@ absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block, kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]), timestamp); } - Matrix fft_output_matrix = - Eigen::Map(fft_output_.data() + 2, 1, fft_size_ - 2); - fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); - // The last two elements are the DFT Nyquist values. 
- fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part - fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imaginary part - ASSIGN_OR_RETURN(output_tensor, - ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2})); + switch (dft_tensor_format_) { + case Options::WITH_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map<Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); + // The last two elements are the Nyquist component. + fft_output_matrix(fft_size_ - 2) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ - 1) = 0.0f; // Nyquist imaginary part + ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(fft_output_matrix, + {2, fft_size_ / 2})); + break; + } + case Options::WITH_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map<Matrix>(fft_output_.data(), 1, fft_size_); + fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_ + 2); + fft_output_matrix(1) = 0.0f; // DC imaginary part. + // The last two elements are the Nyquist component. + fft_output_matrix(fft_size_) = fft_output_[1]; // Nyquist real part + fft_output_matrix(fft_size_ + 1) = 0.0f; // Nyquist imaginary part + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ + 2) / 2})); + break; + } + case Options::WITHOUT_DC_AND_NYQUIST: { + Matrix fft_output_matrix = + Eigen::Map<Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2); + ASSIGN_OR_RETURN( + output_tensor, + ConvertToTensor(fft_output_matrix, {2, (fft_size_ - 2) / 2})); + break; + } + default: + return absl::InvalidArgumentError("Unsupported dft tensor format."); + } + } else { ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(block, {num_channels_, num_samples_})); diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index cff6b2878..aa3c1229c 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -68,4 +68,17 @@ message AudioToTensorCalculatorOptions { } optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX]; + + enum DftTensorFormat { + DFT_TENSOR_FORMAT_UNKNOWN = 0; + // The output DFT tensor contains neither the DC nor the Nyquist component. + WITHOUT_DC_AND_NYQUIST = 1; + // The output DFT tensor contains the Nyquist component as the last + // two values. + WITH_NYQUIST = 2; + // The output DFT tensor contains the DC component as the first two values + // and the Nyquist component as the last two values.
+ WITH_DC_AND_NYQUIST = 3; + } + optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST]; } diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 07a5f9fe1..ceb1fc502 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -36,22 +36,17 @@ #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/image_test_utils.h" namespace mediapipe { namespace { -cv::Mat GetRgb(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); - return rgb; -} +constexpr char kTestDataDir[] = + "/mediapipe/calculators/tensor/testdata/" + "image_to_tensor/"; -cv::Mat GetRgba(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA); - return rgb; +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDir, filename); } // Image to tensor test template. @@ -147,29 +142,34 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, ASSERT_THAT(tensor_vec, testing::SizeIs(1)); const Tensor& tensor = tensor_vec[0]; + const int channels = tensor.shape().dims[3]; + ASSERT_TRUE(channels == 1 || channels == 3); auto view = tensor.GetCpuReadView(); cv::Mat tensor_mat; if (output_int_tensor) { if (range_min < 0) { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_8SC1 : CV_8SC3, const_cast<int8*>(view.buffer<int8>())); } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_8UC1 : CV_8UC3, const_cast<uint8*>(view.buffer<uint8>())); } } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_32FC1 : CV_32FC3, const_cast<float*>(view.buffer<float>())); } cv::Mat result_rgb; auto transformation = GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value(); - tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale, - transformation.offset); + tensor_mat.convertTo(result_rgb, channels == 1 ? CV_8UC1 : CV_8UC3, + transformation.scale, transformation.offset); cv::Mat diff; cv::absdiff(result_rgb, expected_result, diff); @@ -185,17 +185,27 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, MP_ASSERT_OK(graph.WaitUntilDone()); } +mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { + if (image_channels == 4) { + return ImageFormat::SRGBA; + } else if (image_channels == 3) { + return ImageFormat::SRGB; + } else if (image_channels == 1) { + return ImageFormat::GRAY8; + } + CHECK(false) << "Unsupported input image channels: " << image_channels; +} + Packet MakeImageFramePacket(cv::Mat input) { - ImageFrame input_image( - input.channels() == 4 ?
ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {}); + ImageFrame input_image(GetImageFormat(input.channels()), input.cols, + input.rows, input.step, input.data, [](uint8*) {}); return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0)); } Packet MakeImagePacket(cv::Mat input) { mediapipe::Image input_image(std::make_shared<ImageFrame>( - input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {})); + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); return MakePacket<Image>(std::move(input_image)).At(Timestamp(0)); } @@ -237,15 +247,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, - /*border mode*/ {}, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + /*border mode*/ {}, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { @@ -255,11 +262,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -273,11 +277,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -292,11 +293,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -310,16 +309,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { roi.set_width(0.5f);
roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb( - "/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*float_ranges=*/{{-1.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation.png")), + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { @@ -329,11 +324,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation_border_zero.png")), /*float_ranges=*/{{-1.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, @@ -347,10 +339,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, @@ -364,15 +354,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, - BorderMode::kZero, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kZero, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { @@ -382,15 +369,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + 
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { @@ -400,11 +384,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -418,11 +399,23 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/{}, roi); +} + +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -437,11 +430,26 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/BorderMode::kZero, roi); +} + +TEST(ImageToTensorCalculatorTest, + LargeSubRectKeepAspectWithRotationBorderZeroGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -455,10 +463,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + 
RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -472,10 +478,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index f910b59f3..95e38f89c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -48,15 +48,19 @@ class OpenCvProcessor : public ImageToTensorConverter { switch (tensor_type_) { case Tensor::ElementType::kInt8: mat_type_ = CV_8SC3; + mat_gray_type_ = CV_8SC1; break; case Tensor::ElementType::kFloat32: mat_type_ = CV_32FC3; + mat_gray_type_ = CV_32FC1; break; case Tensor::ElementType::kUInt8: mat_type_ = CV_8UC3; + mat_gray_type_ = CV_8UC1; break; default: mat_type_ = -1; + mat_gray_type_ = -1; } } @@ -64,36 +68,57 @@ class OpenCvProcessor : public ImageToTensorConverter { float range_min, float range_max, int tensor_buffer_offset, Tensor& output_tensor) override { - if (input.image_format() != mediapipe::ImageFormat::SRGB && - input.image_format() != mediapipe::ImageFormat::SRGBA) { - return InvalidArgumentError( - absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", - static_cast<int>(input.image_format()))); + const bool is_supported_format = + input.image_format() == mediapipe::ImageFormat::SRGB || + input.image_format() == mediapipe::ImageFormat::SRGBA || + input.image_format() == mediapipe::ImageFormat::GRAY8; + if (!is_supported_format) { + return InvalidArgumentError(absl::StrCat( + "Unsupported format: ", static_cast<int>(input.image_format()))); } - // TODO: Remove the check once tensor_buffer_offset > 0 is - // supported. - RET_CHECK_EQ(tensor_buffer_offset, 0) - << "The non-zero tensor_buffer_offset input is not supported yet."; + + RET_CHECK_GE(tensor_buffer_offset, 0) << "The input tensor_buffer_offset needs to be non-negative."; const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); const int output_height = output_shape.dims[1]; const int output_width = output_shape.dims[2]; const int output_channels = output_shape.dims[3]; + const int num_elements_per_img = + output_height * output_width * output_channels; auto buffer_view = output_tensor.GetCpuWriteView(); cv::Mat dst; + const int dst_data_type = output_channels == 1 ?
mat_gray_type_ : mat_type_; switch (tensor_type_) { case Tensor::ElementType::kInt8: - dst = cv::Mat(output_height, output_width, mat_type_, - buffer_view.buffer<int8>()); + RET_CHECK_GE(output_shape.num_elements(), + tensor_buffer_offset / sizeof(int8) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " + "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8)); break; case Tensor::ElementType::kFloat32: - dst = cv::Mat(output_height, output_width, mat_type_, - buffer_view.buffer<float>()); + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(float) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " + "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float)); break; case Tensor::ElementType::kUInt8: - dst = cv::Mat(output_height, output_width, mat_type_, - buffer_view.buffer<uint8>()); + RET_CHECK_GE( + output_shape.num_elements(), + tensor_buffer_offset / sizeof(uint8) + num_elements_per_img) + << "The buffer offset + the input image size is larger than the " + "allocated tensor buffer."; + dst = cv::Mat( + output_height, output_width, dst_data_type, + buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8)); break; default: return InvalidArgumentError( @@ -137,7 +162,8 @@ class OpenCvProcessor : public ImageToTensorConverter { auto transform, GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); - transformed.convertTo(dst, mat_type_, transform.scale, transform.offset); + transformed.convertTo(dst, dst_data_type, transform.scale, + transform.offset); return absl::OkStatus(); } @@ -145,10 +171,9 @@ class OpenCvProcessor : public ImageToTensorConverter { absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { RET_CHECK_EQ(output_shape.dims.size(), 4) << "Wrong output dims size: " << output_shape.dims.size(); - RET_CHECK_EQ(output_shape.dims[0], 1) - << "Handling batch dimension not equal to 1 is not implemented in this " - "converter."; - RET_CHECK_EQ(output_shape.dims[3], 3) + RET_CHECK_GE(output_shape.dims[0], 1) + << "The batch dimension needs to be equal to or larger than 1."; + RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1) << "Wrong output channel: " << output_shape.dims[3]; return absl::OkStatus(); } @@ -156,6 +181,7 @@ class OpenCvProcessor : public ImageToTensorConverter { enum cv::BorderTypes border_mode_; Tensor::ElementType tensor_type_; int mat_type_; + int mat_gray_type_; }; } // namespace diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index 3f4c05d4e..3f91f3dc2 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -253,7 +253,14 @@ int GetNumOutputChannels(const mediapipe::Image& image) { } #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU - // All of the processors except for Metal expect 3 channels. + // TODO: Add a unit test here to cover the GPU behavior, i.e. + // failure. + // Only output a single channel when running on CPU and the input image + // has one channel. Ideally we would also support single-channel output + // on GPU, but we stay on the safe side here to prevent unintentional + // failures.
+ if (!image.UsesGpu() && image.channels() == 1) { + return 1; + } return 3; } diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index bd8eb3eed..27b8bc23a 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -20,6 +20,7 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/calculator_context.h" @@ -154,6 +155,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate( const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Input tensor at index [%d] doesn't specify dimensions.", + input_indices[i]); + gpu_buffers_in_.emplace_back(absl::make_unique<Tensor>( Tensor::ElementType::kFloat32, Tensor::Shape{std::vector<int>{ @@ -171,6 +176,9 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate( // Create and bind output buffers. for (int i = 0; i < output_size_; ++i) { const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Output tensor at index [%d] doesn't specify dimensions.", + output_indices[i]); gpu_buffers_out_.emplace_back(absl::make_unique<Tensor>( Tensor::ElementType::kFloat32, Tensor::Shape{std::vector<int>{ diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index ad5df849f..8fd55efa7 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -236,14 +236,21 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( const mediapipe::InferenceCalculatorOptions& options, const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& gpu_delegate_options) { - use_kernel_caching_ = gpu_delegate_options.has_cached_kernel_path(); + // The kernel cache needs a unique filename based on either model_path or the + // model token, to prevent the cache from being overwritten if the graph has + // more than one model. + use_kernel_caching_ = + gpu_delegate_options.has_cached_kernel_path() && + (options.has_model_path() || gpu_delegate_options.has_model_token()); use_serialized_model_ = gpu_delegate_options.has_serialized_model_dir() && gpu_delegate_options.has_model_token(); if (use_kernel_caching_) { - cached_kernel_filename_ = gpu_delegate_options.cached_kernel_path() + - mediapipe::File::Basename(options.model_path()) + - ".ker"; + std::string basename = options.has_model_path() + ? mediapipe::File::Basename(options.model_path()) + : gpu_delegate_options.model_token(); + cached_kernel_filename_ = mediapipe::file::JoinPath( + gpu_delegate_options.cached_kernel_path(), basename + ".ker"); } if (use_serialized_model_) { serialized_model_path_ = @@ -258,9 +265,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { if (use_kernel_caching_) { // Save kernel file.
- auto kernel_cache = absl::make_unique<std::vector<uint8_t>>( - gpu_runner->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache, + gpu_runner->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index a85071f3e..750f0456e 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -22,6 +22,7 @@ #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" @@ -245,6 +246,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Input tensor at index [%d] doesn't specify dimensions.", + input_indices[i]); // Create and bind input buffer. std::vector<int> dims{tensor->dims->data, tensor->dims->data + tensor->dims->size}; @@ -266,6 +270,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( output_shapes_.resize(output_indices.size()); for (int i = 0; i < output_shapes_.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Output tensor at index [%d] doesn't specify dimensions.", + output_indices[i]); RET_CHECK(tensor->dims->size <= 4); // Create and bind output buffers. // Channels are always padded to multiple of 4.
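The ASSIGN_OR_RETURN rewrites above track GetSerializedBinaryCache() now returning a StatusOr rather than a bare vector, and the OnDiskCacheHelper::Init hunk ties kernel caching to having a unique basename per model. A minimal standalone sketch of that naming rule, with plain string operations standing in for mediapipe::File::Basename and mediapipe::file::JoinPath (the function and parameter names here are illustrative, not MediaPipe's):

    // Caching is enabled only when a cache directory plus a unique basename
    // (model path or model token) are available, so two models in the same
    // graph never overwrite each other's .ker file.
    #include <optional>
    #include <string>

    std::optional<std::string> KernelCacheFilename(
        const std::string& cached_kernel_path,  // cache directory
        const std::string& model_path,          // may be empty
        const std::string& model_token) {       // may be empty
      if (cached_kernel_path.empty() ||
          (model_path.empty() && model_token.empty())) {
        return std::nullopt;  // kernel caching stays disabled
      }
      const std::string basename =
          !model_path.empty()
              ? model_path.substr(model_path.find_last_of('/') + 1)
              : model_token;
      return cached_kernel_path + "/" + basename + ".ker";
    }

For example, KernelCacheFilename("/tmp/cache", "/models/face_detector.tflite", "") yields "/tmp/cache/face_detector.tflite.ker", while a second model configured with only a token gets its own distinct file.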
diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index d0dfc12ab..0f8f8706a 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "graph_tensors_packet_generator_proto", srcs = ["graph_tensors_packet_generator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework:packet_generator_proto", @@ -32,49 +31,42 @@ proto_library( proto_library( name = "matrix_to_tensor_calculator_options_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "lapped_tensor_buffer_calculator_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "object_detection_tensors_to_detections_calculator_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensorflow_inference_calculator_proto", srcs = ["tensorflow_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_squeeze_dimensions_calculator_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_image_frame_calculator_proto", srcs = ["tensor_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_matrix_calculator_proto", srcs = ["tensor_to_matrix_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:time_series_header_proto", @@ -84,30 +76,24 @@ proto_library( proto_library( name = "tensor_to_vector_float_calculator_options_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_int_calculator_options_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_string_calculator_options_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) mediapipe_proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_proto", "//mediapipe/framework:calculator_proto", @@ -118,14 +104,12 @@ mediapipe_proto_library( proto_library( name = "vector_float_to_tensor_calculator_options_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) 
proto_library( name = "vector_string_to_tensor_calculator_options_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -136,7 +120,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":graph_tensors_packet_generator_proto"], ) @@ -147,7 +130,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":image_frame_to_tensor_calculator_proto"], ) @@ -155,7 +137,6 @@ mediapipe_cc_proto_library( name = "matrix_to_tensor_calculator_options_cc_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":matrix_to_tensor_calculator_options_proto"], ) @@ -163,7 +144,6 @@ mediapipe_cc_proto_library( name = "lapped_tensor_buffer_calculator_cc_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":lapped_tensor_buffer_calculator_proto"], ) @@ -171,7 +151,6 @@ mediapipe_cc_proto_library( name = "object_detection_tensors_to_detections_calculator_cc_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":object_detection_tensors_to_detections_calculator_proto"], ) @@ -179,7 +158,6 @@ mediapipe_cc_proto_library( name = "tensorflow_inference_calculator_cc_proto", srcs = ["tensorflow_inference_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensorflow_inference_calculator_proto"], ) @@ -190,7 +168,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_generator_proto"], ) @@ -201,7 +178,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_calculator_proto"], ) @@ -212,7 +188,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_generator_proto"], ) @@ -223,7 +198,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_calculator_proto"], ) @@ -231,7 +205,6 @@ mediapipe_cc_proto_library( name = "tensor_squeeze_dimensions_calculator_cc_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_squeeze_dimensions_calculator_proto"], ) @@ -239,7 +212,6 @@ mediapipe_cc_proto_library( name = "tensor_to_image_frame_calculator_cc_proto", srcs = ["tensor_to_image_frame_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = 
["//visibility:public"], deps = [":tensor_to_image_frame_calculator_proto"], ) @@ -250,7 +222,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tensor_to_matrix_calculator_proto"], ) @@ -258,7 +229,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_float_calculator_options_cc_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_float_calculator_options_proto"], ) @@ -266,7 +236,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_int_calculator_options_cc_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_int_calculator_options_proto"], ) @@ -274,7 +243,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_string_calculator_options_cc_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_string_calculator_options_proto"], ) @@ -285,7 +253,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":vector_int_to_tensor_calculator_options_proto"], ) @@ -293,7 +260,6 @@ mediapipe_cc_proto_library( name = "vector_float_to_tensor_calculator_options_cc_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_float_to_tensor_calculator_options_proto"], ) @@ -301,14 +267,12 @@ mediapipe_cc_proto_library( name = "vector_string_to_tensor_calculator_options_cc_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_string_to_tensor_calculator_options_proto"], ) cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_tensors_packet_generator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -323,7 +287,6 @@ cc_library( cc_library( name = "image_frame_to_tensor_calculator", srcs = ["image_frame_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_frame_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -344,10 +307,9 @@ cc_library( cc_library( name = "matrix_to_tensor_calculator", srcs = ["matrix_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -366,7 +328,6 @@ cc_library( cc_library( name = "lapped_tensor_buffer_calculator", srcs = ["lapped_tensor_buffer_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,9 +349,6 @@ cc_library( # Layering check doesn't play nicely with 
portable proto wrappers. "no_layering_check", ], - visibility = [ - "//visibility:public", - ], deps = [ ":object_detection_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,14 +365,11 @@ cc_library( cc_library( name = "pack_media_sequence_calculator", srcs = ["pack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", # build_cleaner: keep + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:opencv_imgcodecs", @@ -432,9 +387,6 @@ cc_library( cc_library( name = "string_to_sequence_example_calculator", srcs = ["string_to_sequence_example_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -449,10 +401,9 @@ cc_library( cc_library( name = "tensorflow_inference_calculator", srcs = ["tensorflow_inference_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - ":tensorflow_session", ":tensorflow_inference_calculator_cc_proto", + ":tensorflow_session", "@com_google_absl//absl/log:check", "//mediapipe/framework:timestamp", "@com_google_absl//absl/base:core_headers", @@ -487,7 +438,6 @@ cc_library( "tensorflow_session.h", ], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:core", @@ -505,7 +455,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_calculator", srcs = ["tensorflow_session_from_frozen_graph_calculator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", @@ -515,6 +464,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -536,7 +486,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_generator", srcs = ["tensorflow_session_from_frozen_graph_generator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_frozen_graph_generator_cc_proto", @@ -546,6 +495,7 @@ cc_library( "//mediapipe/framework/deps:clock", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -570,7 +520,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_calculator_cc_proto", @@ -609,7 +558,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_generator_cc_proto", @@ -635,7 +583,6 @@ cc_library( cc_library( name = 
"tensor_squeeze_dimensions_calculator", srcs = ["tensor_squeeze_dimensions_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_squeeze_dimensions_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -649,7 +596,6 @@ cc_library( cc_library( name = "tensor_to_image_frame_calculator", srcs = ["tensor_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -664,10 +610,9 @@ cc_library( cc_library( name = "tensor_to_matrix_calculator", srcs = ["tensor_to_matrix_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":tensor_to_matrix_calculator_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -686,7 +631,6 @@ cc_library( cc_library( name = "tfrecord_reader_calculator", srcs = ["tfrecord_reader_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -702,12 +646,11 @@ cc_library( cc_library( name = "tensor_to_vector_float_calculator", srcs = ["tensor_to_vector_float_calculator.cc"], - visibility = ["//visibility:public"], deps = [ + ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - ":tensor_to_vector_float_calculator_options_cc_proto", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", @@ -722,7 +665,6 @@ cc_library( cc_library( name = "tensor_to_vector_int_calculator", srcs = ["tensor_to_vector_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_int_calculator_options_cc_proto", "@com_google_absl//absl/base:core_headers", @@ -744,7 +686,6 @@ cc_library( cc_library( name = "tensor_to_vector_string_calculator", srcs = ["tensor_to_vector_string_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -764,9 +705,6 @@ cc_library( cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", @@ -784,7 +722,6 @@ cc_library( cc_library( name = "vector_int_to_tensor_calculator", srcs = ["vector_int_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_int_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -798,7 +735,6 @@ cc_library( cc_library( name = "vector_float_to_tensor_calculator", srcs = ["vector_float_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_float_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -812,7 +748,6 @@ cc_library( cc_library( name = "vector_string_to_tensor_calculator", srcs = ["vector_string_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_string_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -826,7 +761,6 @@ cc_library( cc_library( name = 
"unpack_yt8m_sequence_example_calculator", srcs = ["unpack_yt8m_sequence_example_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1083,7 +1017,6 @@ cc_test( linkstatic = 1, deps = [ ":tensor_to_image_frame_calculator", - ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:image_frame", @@ -1236,6 +1169,7 @@ cc_test( data = [":test_frozen_graph"], linkstatic = 1, deps = [ + ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", ":tensorflow_inference_calculator", ":tensorflow_session_from_frozen_graph_generator", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2007a4fe1..db2a27630 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -18,12 +18,11 @@ load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "ssd_anchors_calculator_proto", srcs = ["ssd_anchors_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -33,7 +32,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_custom_op_resolver_calculator_proto", srcs = ["tflite_custom_op_resolver_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -43,7 +41,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_inference_calculator_proto", srcs = ["tflite_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -53,7 +50,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_converter_calculator_proto", srcs = ["tflite_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -63,7 +59,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_segmentation_calculator_proto", srcs = ["tflite_tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -73,7 +68,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_detections_calculator_proto", srcs = ["tflite_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_classification_calculator_proto", srcs = ["tflite_tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_landmarks_calculator_proto", srcs = ["tflite_tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -103,7 +95,6 @@ mediapipe_proto_library( cc_library( name = "ssd_anchors_calculator", srcs = ["ssd_anchors_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -117,7 +108,6 @@ cc_library( cc_library( name = "tflite_custom_op_resolver_calculator", srcs = ["tflite_custom_op_resolver_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_custom_op_resolver_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -208,7 +198,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -287,10 +276,9 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_converter_calculator_cc_proto", + "//mediapipe/util/tflite:config", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -326,7 +314,6 @@ cc_library( cc_library( name = "tflite_model_calculator", srcs = ["tflite_model_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -340,7 +327,6 @@ cc_library( cc_library( name = "tflite_tensors_to_segmentation_calculator", srcs = ["tflite_tensors_to_segmentation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -408,17 +394,16 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/util/tflite:config", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/deps:file_path", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:location", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", ] + selects.with_or({ @@ -444,7 +429,6 @@ cc_library( cc_library( name = "tflite_tensors_to_classification_calculator", srcs = ["tflite_tensors_to_classification_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -476,7 +460,6 @@ cc_library( cc_library( name = "tflite_tensors_to_landmarks_calculator", srcs = ["tflite_tensors_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -490,7 +473,6 @@ cc_library( cc_library( name = "tflite_tensors_to_floats_calculator", srcs = ["tflite_tensors_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index afdc9ed6f..0f7fa933e 100644 --- 
a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. - auto kernel_cache = absl::make_unique<std::vector<uint8_t>>( - tflite_gpu_runner_->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache, + tflite_gpu_runner_->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 3a9ddc36f..43eadd53b 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -21,10 +21,9 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ + ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "annotation_overlay_calculator_proto", srcs = ["annotation_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -50,7 +48,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detection_label_id_to_text_calculator_proto", srcs = ["detection_label_id_to_text_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -61,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "filter_detections_calculator_proto", srcs = ["filter_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -71,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -81,13 +76,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "latency_proto", srcs = ["latency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "non_max_suppression_calculator_proto", srcs = ["non_max_suppression_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -97,13 +90,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_frequency_proto", srcs = ["packet_frequency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "packet_frequency_calculator_proto", srcs = ["packet_frequency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", @@ -113,7 +104,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_latency_calculator_proto", srcs = ["packet_latency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -123,7 +113,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "collection_has_min_size_calculator_proto", srcs = ["collection_has_min_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -133,7 +122,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "association_calculator_proto", srcs = ["association_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -143,7 +131,6 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -188,7 +175,6 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", @@ -228,9 +214,6 @@ cc_test( cc_library( name = "clock_timestamp_calculator", srcs = ["clock_timestamp_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -246,9 +229,6 @@ cc_library( cc_library( name = "clock_latency_calculator", srcs = ["clock_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -263,11 +243,10 @@ cc_library( cc_library( name = "annotation_overlay_calculator", srcs = ["annotation_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", @@ -296,7 +275,6 @@ cc_library( cc_library( name = "detection_label_id_to_text_calculator", srcs = ["detection_label_id_to_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -328,7 +306,6 @@ cc_library( cc_library( name = "timed_box_list_id_to_label_calculator", srcs = ["timed_box_list_id_to_label_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_id_to_label_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -357,7 +334,6 @@ cc_library( cc_library( name = "detection_transformation_calculator", srcs = ["detection_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -391,7 +367,6 @@ cc_test( cc_library( name = "non_max_suppression_calculator", srcs = ["non_max_suppression_calculator.cc"], - 
visibility = ["//visibility:public"], deps = [ ":non_max_suppression_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -408,7 +383,6 @@ cc_library( cc_library( name = "thresholding_calculator", srcs = ["thresholding_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":thresholding_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -421,7 +395,6 @@ cc_library( cc_library( name = "detection_to_landmarks_calculator", srcs = ["detection_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -436,7 +409,6 @@ cc_library( cc_library( name = "filter_detections_calculator", srcs = ["filter_detections_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":filter_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -450,7 +422,6 @@ cc_library( cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_detection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -471,7 +442,6 @@ cc_library( hdrs = [ "detections_to_rects_calculator.h", ], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -489,7 +459,6 @@ cc_library( cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_transformation_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -504,7 +473,6 @@ cc_library( cc_library( name = "rect_projection_calculator", srcs = ["rect_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:rect_cc_proto", @@ -535,7 +503,6 @@ cc_test( mediapipe_proto_library( name = "rect_to_render_data_calculator_proto", srcs = ["rect_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -547,7 +514,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_to_render_scale_calculator_proto", srcs = ["rect_to_render_scale_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -557,7 +523,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -569,7 +534,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_render_data_calculator_proto", srcs = ["landmarks_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -581,7 +545,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_to_render_data_calculator_proto", srcs = ["timed_box_list_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -593,7 +556,6 @@ 
mediapipe_proto_library( mediapipe_proto_library( name = "labels_to_render_data_calculator_proto", srcs = ["labels_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -605,7 +567,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "thresholding_calculator_proto", srcs = ["thresholding_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -617,7 +578,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_rects_calculator_proto", srcs = ["detections_to_rects_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -627,7 +587,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmark_projection_calculator_proto", srcs = ["landmark_projection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -637,7 +596,6 @@ mediapipe_proto_library( cc_library( name = "landmark_visibility_calculator", srcs = ["landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -649,7 +607,6 @@ cc_library( cc_library( name = "set_landmark_visibility_calculator", srcs = ["set_landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -661,7 +618,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_floats_calculator_proto", srcs = ["landmarks_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -671,7 +627,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_transformation_calculator_proto", srcs = ["rect_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -681,7 +636,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_detection_calculator_proto", srcs = ["landmarks_to_detection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -693,7 +647,6 @@ mediapipe_proto_library( cc_library( name = "detections_to_render_data_calculator", srcs = ["detections_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -713,7 +666,6 @@ cc_library( name = "landmarks_to_render_data_calculator", srcs = ["landmarks_to_render_data_calculator.cc"], hdrs = ["landmarks_to_render_data_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -732,7 +684,6 @@ cc_library( cc_library( name = "timed_box_list_to_render_data_calculator", srcs = ["timed_box_list_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
":timed_box_list_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -751,11 +702,9 @@ cc_library( cc_library( name = "labels_to_render_data_calculator", srcs = ["labels_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:ret_check", @@ -771,7 +720,6 @@ cc_library( cc_library( name = "rect_to_render_data_calculator", srcs = ["rect_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -786,7 +734,6 @@ cc_library( cc_library( name = "rect_to_render_scale_calculator", srcs = ["rect_to_render_scale_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_scale_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -821,7 +768,6 @@ cc_test( cc_library( name = "detection_letterbox_removal_calculator", srcs = ["detection_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -835,7 +781,6 @@ cc_library( cc_library( name = "detection_projection_calculator", srcs = ["detection_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -868,7 +813,6 @@ cc_test( cc_library( name = "landmark_letterbox_removal_calculator", srcs = ["landmark_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -882,7 +826,6 @@ cc_library( cc_library( name = "landmark_projection_calculator", srcs = ["landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmark_projection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -915,7 +858,6 @@ cc_test( cc_library( name = "world_landmark_projection_calculator", srcs = ["world_landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -929,7 +871,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -939,7 +880,6 @@ mediapipe_proto_library( cc_library( name = "landmarks_smoothing_calculator", srcs = ["landmarks_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -957,7 +897,6 @@ cc_library( mediapipe_proto_library( name = "visibility_smoothing_calculator_proto", srcs = ["visibility_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -967,7 +906,6 @@ mediapipe_proto_library( cc_library( name = "visibility_smoothing_calculator", srcs = 
["visibility_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -983,7 +921,6 @@ cc_library( mediapipe_proto_library( name = "visibility_copy_calculator_proto", srcs = ["visibility_copy_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -993,7 +930,6 @@ mediapipe_proto_library( cc_library( name = "visibility_copy_calculator", srcs = ["visibility_copy_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_copy_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1008,7 +944,6 @@ cc_library( cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1055,7 +990,6 @@ cc_test( mediapipe_proto_library( name = "top_k_scores_calculator_proto", srcs = ["top_k_scores_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1065,7 +999,6 @@ mediapipe_proto_library( cc_library( name = "top_k_scores_calculator", srcs = ["top_k_scores_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":top_k_scores_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1109,7 +1042,6 @@ cc_test( mediapipe_proto_library( name = "local_file_contents_calculator_proto", srcs = ["local_file_contents_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1119,7 +1051,6 @@ mediapipe_proto_library( cc_library( name = "local_file_contents_calculator", srcs = ["local_file_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1133,7 +1064,6 @@ cc_library( cc_library( name = "local_file_pattern_contents_calculator", srcs = ["local_file_pattern_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:file_helpers", @@ -1147,7 +1077,6 @@ cc_library( name = "filter_collection_calculator", srcs = ["filter_collection_calculator.cc"], hdrs = ["filter_collection_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:classification_cc_proto", @@ -1165,7 +1094,6 @@ cc_library( name = "collection_has_min_size_calculator", srcs = ["collection_has_min_size_calculator.cc"], hdrs = ["collection_has_min_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":collection_has_min_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1193,7 +1121,6 @@ cc_test( cc_library( name = "association_calculator", hdrs = ["association_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":association_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1210,7 +1137,6 @@ cc_library( cc_library( name = "association_norm_rect_calculator", srcs = ["association_norm_rect_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ 
-1225,7 +1151,6 @@ cc_library( cc_library( name = "association_detection_calculator", srcs = ["association_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1260,7 +1185,6 @@ cc_test( cc_library( name = "detections_to_timed_box_list_calculator", srcs = ["detections_to_timed_box_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1275,7 +1199,6 @@ cc_library( cc_library( name = "detection_unique_id_calculator", srcs = ["detection_unique_id_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1288,7 +1211,6 @@ cc_library( mediapipe_proto_library( name = "logic_calculator_proto", srcs = ["logic_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1298,7 +1220,6 @@ mediapipe_proto_library( cc_library( name = "logic_calculator", srcs = ["logic_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":logic_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1311,10 +1232,9 @@ cc_library( cc_library( name = "to_image_calculator", srcs = ["to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image_frame", @@ -1334,10 +1254,9 @@ cc_library( cc_library( name = "from_image_calculator", srcs = ["from_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", @@ -1386,7 +1305,6 @@ cc_test( mediapipe_proto_library( name = "refine_landmarks_from_heatmap_calculator_proto", srcs = ["refine_landmarks_from_heatmap_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1404,7 +1322,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":refine_landmarks_from_heatmap_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1455,7 +1372,6 @@ cc_library( name = "inverse_matrix_calculator", srcs = ["inverse_matrix_calculator.cc"], hdrs = ["inverse_matrix_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 53d968151..f2b8135f2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -21,19 +21,17 @@ load( licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - visibility = 
["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -58,7 +56,6 @@ proto_library( proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", @@ -68,7 +65,6 @@ proto_library( proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", @@ -78,7 +74,6 @@ proto_library( proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_detector_proto", @@ -88,7 +83,6 @@ proto_library( proto_library( name = "video_pre_stream_calculator_proto", srcs = ["video_pre_stream_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", ], @@ -101,7 +95,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:motion_analysis_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_calculator_proto"], ) @@ -112,7 +105,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:flow_packager_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_calculator_proto"], ) @@ -123,7 +115,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_tracker_calculator_proto"], ) @@ -134,7 +125,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_calculator_proto"], ) @@ -145,7 +135,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_detector_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_calculator_proto"], ) @@ -155,7 +144,6 @@ mediapipe_cc_proto_library( cc_deps = [ "//mediapipe/framework:calculator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":video_pre_stream_calculator_proto"], ) @@ -163,7 +151,6 @@ mediapipe_cc_proto_library( name = "flow_to_image_calculator_cc_proto", srcs = ["flow_to_image_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":flow_to_image_calculator_proto"], ) @@ -171,14 +158,12 @@ mediapipe_cc_proto_library( name = "opencv_video_encoder_calculator_cc_proto", srcs = ["opencv_video_encoder_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":opencv_video_encoder_calculator_proto"], ) cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_to_image_calculator_cc_proto", 
"//mediapipe/calculators/video/tool:flow_quantizer_model", @@ -198,7 +183,6 @@ cc_library( cc_library( name = "opencv_video_decoder_calculator", srcs = ["opencv_video_decoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", @@ -217,7 +201,6 @@ cc_library( cc_library( name = "opencv_video_encoder_calculator", srcs = ["opencv_video_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_video_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -240,7 +223,6 @@ cc_library( cc_library( name = "tvl1_optical_flow_calculator", srcs = ["tvl1_optical_flow_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -256,7 +238,6 @@ cc_library( cc_library( name = "motion_analysis_calculator", srcs = ["motion_analysis_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":motion_analysis_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -282,7 +263,6 @@ cc_library( cc_library( name = "flow_packager_calculator", srcs = ["flow_packager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_packager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -300,7 +280,6 @@ cc_library( cc_library( name = "box_tracker_calculator", srcs = ["box_tracker_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -327,7 +306,6 @@ cc_library( cc_library( name = "box_detector_calculator", srcs = ["box_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_detector_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -342,12 +320,12 @@ cc_library( "//mediapipe/framework/port:opencv_features2d", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/util/tracking:box_tracker_cc_proto", + "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/util/tracking", "//mediapipe/util/tracking:box_detector", "//mediapipe/util/tracking:box_tracker", - "//mediapipe/util/tracking:box_tracker_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util/tracking:tracking_visualization_utilities", ] + select({ "//mediapipe:android": [ @@ -369,7 +347,6 @@ cc_library( cc_library( name = "tracked_detection_manager_calculator", srcs = ["tracked_detection_manager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tracked_detection_manager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -390,7 +367,6 @@ cc_library( cc_library( name = "video_pre_stream_calculator", srcs = ["video_pre_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":video_pre_stream_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,7 +383,6 @@ filegroup( "testdata/format_MKV_VP8_VORBIS.video", "testdata/format_MP4_AVC720P_AAC.video", ], - visibility = ["//visibility:public"], ) cc_test( @@ -480,7 +455,6 @@ mediapipe_binary_graph( name = "parallel_tracker_binarypb", graph = "testdata/parallel_tracker_graph.pbtxt", output_name = "testdata/parallel_tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", @@ -494,7 +468,6 @@ mediapipe_binary_graph( name = 
"tracker_binarypb", graph = "testdata/tracker_graph.pbtxt", output_name = "testdata/tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD index edf98bf13..27aa088e7 100644 --- a/mediapipe/examples/desktop/hello_world/BUILD +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -14,12 +14,11 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) cc_binary( name = "hello_world", srcs = ["hello_world.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 19c51853c..3cc72b4f1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -139,7 +139,7 @@ mediapipe_proto_library( name = "test_calculators_proto", testonly = 1, srcs = ["test_calculators.proto"], - visibility = ["//visibility:public"], + visibility = [":mediapipe_internal"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1039,7 +1039,6 @@ cc_library( ":graph_service_manager", ":port", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1469,6 +1468,7 @@ cc_test( "//mediapipe/framework/stream_handler:mux_input_stream_handler", "//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/framework/tool:sink", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/strings", ], ) @@ -1659,9 +1659,6 @@ cc_test( "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:default_side_packet_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:template_parser", diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 6d3323b97..19273bf44 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -412,11 +412,11 @@ using GenericNode = Node; template class Node : public NodeBase { public: - Node() : NodeBase(Calc::kCalculatorName) {} + Node() : NodeBase(std::string(Calc::kCalculatorName)) {} // Overrides the built-in calculator type string with the provided argument. // Can be used to create nodes from pure interfaces. // TODO: only use this for pure interfaces - Node(const std::string& type_override) : NodeBase(type_override) {} + Node(std::string type_override) : NodeBase(std::move(type_override)) {} // These methods only allow access to ports declared in the contract. // The argument must be a tag object created with the MPP_TAG macro. 
diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index c17a2e1e2..526a74835 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -98,14 +98,13 @@ void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) { manager_->LockIntroData(); } +void CalculatorGraph::GraphInputStream::SetNextTimestampBound( + Timestamp timestamp) { + shard_.SetNextTimestampBound(timestamp); +} + void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() { - // Since GraphInputStream doesn't allow SetOffset() and - // SetNextTimestampBound(), the timestamp bound to propagate is only - // determined by the timestamp of the output packets. - CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name() - << "\" failed"; - manager_->PropagateUpdatesToMirrors( - shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_); + manager_->PropagateUpdatesToMirrors(shard_.NextTimestampBound(), &shard_); } void CalculatorGraph::GraphInputStream::Close() { @@ -868,6 +867,19 @@ absl::Status CalculatorGraph::AddPacketToInputStream( return AddPacketToInputStreamInternal(stream_name, std::move(packet)); } +absl::Status CalculatorGraph::SetInputStreamTimestampBound( + const std::string& stream_name, Timestamp timestamp) { + std::unique_ptr<GraphInputStream>* stream = + mediapipe::FindOrNull(graph_input_streams_, stream_name); + RET_CHECK(stream).SetNoLogging() << absl::Substitute( + "SetInputStreamTimestampBound called on input stream \"$0\" which is not " + "a graph input stream.", + stream_name); + (*stream)->SetNextTimestampBound(timestamp); + (*stream)->PropagateUpdatesToMirrors(); + return absl::OkStatus(); +} + // We avoid having two copies of this code for AddPacketToInputStream( // const Packet&) and AddPacketToInputStream(Packet &&) by having this // internal-only templated version. T&& is a forwarding reference here, so diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index c51476102..04f9de45f 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -257,6 +257,10 @@ class CalculatorGraph { absl::Status AddPacketToInputStream(const std::string& stream_name, Packet&& packet); + // Indicates that input will arrive no earlier than a certain timestamp. + absl::Status SetInputStreamTimestampBound(const std::string& stream_name, + Timestamp timestamp); + // Sets the queue size of a graph input stream, overriding the graph default. absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name, int max_queue_size); @@ -425,6 +429,8 @@ class CalculatorGraph { void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); } + void SetNextTimestampBound(Timestamp timestamp); + void PropagateUpdatesToMirrors(); void Close(); diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index b55f9459d..d149337cc 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.
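Editor's aside: the hunks above add `CalculatorGraph::SetInputStreamTimestampBound`, which advances a graph input stream's timestamp bound without feeding a packet. A minimal usage sketch (illustrative, not part of the patch; assumes a running graph with a graph input stream named "input_0", as in the tests that follow):

// Promise the graph that no packet earlier than `ts` will arrive on
// "input_0". Calculators waiting only on timestamp bounds can then run
// for timestamps up to ts - 1 even though no packet was sent.
absl::Status AdvanceInputBound(mediapipe::CalculatorGraph& graph, int64_t ts) {
  MP_RETURN_IF_ERROR(graph.SetInputStreamTimestampBound(
      "input_0", mediapipe::Timestamp(ts)));
  return graph.WaitUntilIdle();  // let the new bound propagate downstream
}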
+#include + #include "absl/strings/str_replace.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" @@ -24,6 +26,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/thread_pool_executor.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { namespace { @@ -1536,7 +1539,7 @@ class EmptyPacketCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(EmptyPacketCalculator); -// This test shows that an output timestamp bound can be specified by outputing +// This test shows that an output timestamp bound can be specified by outputting // an empty packet with a settled timestamp. TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { // OffsetAndBoundCalculator runs on parallel threads and sends ts @@ -1580,6 +1583,195 @@ TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) { EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); } + // Shut down the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows that input timestamp bounds can be specified using +// CalculatorGraph::SetInputStreamTimestampBound. +TEST(CalculatorGraphBoundsTest, SetInputStreamTimestampBound) { + std::string config_str = R"( + input_stream: "input_0" + node { + calculator: "ProcessBoundToPacketCalculator" + input_stream: "input_0" + output_stream: "output_0" + } + )"; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(config_str); + CalculatorGraph graph; + std::vector output_0_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { + output_0_packets.push_back(p); + return absl::OkStatus(); + })); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send in timestamp bounds. + for (int i = 0; i < 9; ++i) { + const int ts = 10 + i * 10; + MP_ASSERT_OK(graph.SetInputStreamTimestampBound( + "input_0", Timestamp(ts).NextAllowedInStream())); + MP_ASSERT_OK(graph.WaitUntilIdle()); + } + + // 9 timestamp bounds are converted to packets. + EXPECT_EQ(output_0_packets.size(), 9); + for (int i = 0; i < 9; ++i) { + EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10)); + } + + // Shutdown the graph. + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// This test shows how an input stream with infrequent packets, such as +// configuration protobufs, can be consumed while processing more frequent +// packets, such as video frames. +TEST(CalculatorGraphBoundsTest, TimestampBoundsForInfrequentInput) { + // PassThroughCalculator consuming two input streams, with default ISH. 
+ std::string config_str = R"pb( + input_stream: "INFREQUENT:config" + input_stream: "FREQUENT:frame" + node { + calculator: "PassThroughCalculator" + input_stream: "CONFIG:config" + input_stream: "VIDEO:frame" + output_stream: "VIDEO:output_frame" + output_stream: "CONFIG:output_config" + } + )pb"; + + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str); + CalculatorGraph graph; + std::vector<Packet> frame_packets; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_frame", + [&](const Packet& p) { + frame_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + std::vector<Packet> config_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output_config", + [&](const Packet& p) { + config_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_bound_updates=*/true)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Utility functions to send packets or timestamp bounds. + auto send_fn = [&](std::string stream, std::string value, int ts) { + MP_ASSERT_OK(graph.AddPacketToInputStream( + stream, + MakePacket<std::string>(absl::StrCat(value)).At(Timestamp(ts)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + auto bound_fn = [&](std::string stream, int ts) { + MP_ASSERT_OK(graph.SetInputStreamTimestampBound(stream, Timestamp(ts))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + }; + + // Send in a frame packet. + send_fn("frame", "frame_0", 0); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, ElementsAreArray(PacketMatchers<std::string>({}))); + bound_fn("config", 10000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_1", 20000); + // The frame is not processed yet. + // The PassThroughCalculator with TimestampOffset 0 now propagates + // Timestamp bound 10000 to both "output_frame" and "output_config", + // which appears here as Packet().At(Timestamp(9999)). The timestamp + // bounds at 29999 and 50000 are propagated similarly. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + }))); + bound_fn("config", 30000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket<std::string>("frame_1").At(Timestamp(20000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_2", 40000); + // The frame is not processed yet. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket<std::string>("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + }))); + send_fn("config", "config_1", 50000); + // The frame is processed after a fresh config arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket<std::string>("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket<std::string>("frame_2").At(Timestamp(40000)), + }))); + + // Send in a frame packet. + send_fn("frame", "frame_3", 60000); + // The frame is not processed yet.
+ EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket<std::string>("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket<std::string>("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + }))); + bound_fn("config", 70000); + // The frame is processed after a fresh config timestamp bound arrives. + EXPECT_THAT(frame_packets, + ElementsAreArray(PacketMatchers<std::string>({ + MakePacket<std::string>("frame_0").At(Timestamp(0)), + Packet().At(Timestamp(9999)), + MakePacket<std::string>("frame_1").At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + MakePacket<std::string>("frame_2").At(Timestamp(40000)), + Packet().At(Timestamp(50000)), + MakePacket<std::string>("frame_3").At(Timestamp(60000)), + }))); + + // One config packet is delivered. + EXPECT_THAT(config_packets, + ElementsAreArray(PacketMatchers<std::string>({ + Packet().At(Timestamp(0)), + Packet().At(Timestamp(9999)), + Packet().At(Timestamp(20000)), + Packet().At(Timestamp(29999)), + Packet().At(Timestamp(40000)), + MakePacket<std::string>("config_1").At(Timestamp(50000)), + Packet().At(Timestamp(60000)), + }))); + // Shutdown the graph. MP_ASSERT_OK(graph.CloseAllPacketSources()); MP_ASSERT_OK(graph.WaitUntilDone()); diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index a39d7476e..95ab21707 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -225,6 +225,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index b39a1e293..1a33b2b24 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -26,10 +26,12 @@ #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/deps/registration_token.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -159,7 +161,7 @@ class FunctionRegistry { FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete; - RegistrationToken Register(const std::string& name, Function func) + RegistrationToken Register(absl::string_view name, Function func) ABSL_LOCKS_EXCLUDED(lock_) { std::string normalized_name = GetNormalizedName(name); absl::WriterMutexLock lock(&lock_); @@ -189,14 +191,15 @@ absl::enable_if_t<std::is_same<std::tuple<Args...>, std::tuple<Args2...>>::value, int> = 0> - ReturnType Invoke(const std::string& name, Args2&&... args) + ReturnType Invoke(absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { Function function; { absl::ReaderMutexLock lock(&lock_); auto it = functions_.find(name); if (it == functions_.end()) { - return absl::NotFoundError("No registered object with name: " + name); + return absl::NotFoundError( + absl::StrCat("No registered object with name: ", name)); } function = it->second; } @@ -206,7 +209,7 @@ // Invokes the specified factory function and returns the result.
// Namespaces in |name| and |ns| are separated by kNameSep. template <typename... Args2> - ReturnType Invoke(const std::string& ns, const std::string& name, + ReturnType Invoke(absl::string_view ns, absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { return Invoke(GetQualifiedName(ns, name), args...); } @@ -214,14 +217,14 @@ // Note that it's possible for registered implementations to be subsequently // unregistered, though this will never happen with registrations made via // MEDIAPIPE_REGISTER_FACTORY_FUNCTION. - bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) { + bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { absl::ReaderMutexLock lock(&lock_); return functions_.count(name) != 0; } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - bool IsRegistered(const std::string& ns, const std::string& name) const + bool IsRegistered(absl::string_view ns, absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { return IsRegistered(GetQualifiedName(ns, name)); } @@ -244,7 +247,7 @@ // Normalizes a C++ qualified name. Validates the name qualification. // The name must be either unqualified or fully qualified with a leading "::". // The leading "::" in a fully qualified name is stripped. - std::string GetNormalizedName(const std::string& name) { + std::string GetNormalizedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector<std::string> names = absl::StrSplit(name, kCxxSep); if (names[0].empty()) { @@ -259,8 +262,8 @@ // Returns the registry key for a name specified within a namespace. // Namespaces are separated by kNameSep. - std::string GetQualifiedName(const std::string& ns, - const std::string& name) const { + std::string GetQualifiedName(absl::string_view ns, + absl::string_view name) const { using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kNameSep; std::vector<std::string> names = absl::StrSplit(name, kNameSep); @@ -287,10 +290,10 @@ private: mutable absl::Mutex lock_; - std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); + absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); // For names included in NamespaceAllowlist, strips the namespace. - std::string GetAdjustedName(const std::string& name) { + std::string GetAdjustedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector<std::string> names = absl::StrSplit(name, kCxxSep); std::string base_name = names.back(); @@ -299,10 +302,10 @@ if (NamespaceAllowlist::TopNamespaces().count(ns)) { return base_name; } - return name; + return std::string(name); } - void Unregister(const std::string& name) { + void Unregister(absl::string_view name) { absl::WriterMutexLock lock(&lock_); std::string adjusted_name = GetAdjustedName(name); if (adjusted_name != name) { @@ -317,7 +320,7 @@ class GlobalFactoryRegistry { using Functions = FunctionRegistry; public: - static RegistrationToken Register(const std::string& name, + static RegistrationToken Register(absl::string_view name, typename Functions::Function func) { return functions()->Register(name, std::move(func)); } @@ -326,7 +329,7 @@ // Invokes the specified factory function and returns the result. // If using namespaces with this registry, the variant with a namespace // argument should be used.
template <typename... Args2> - static typename Functions::ReturnType CreateByName(const std::string& name, + static typename Functions::ReturnType CreateByName(absl::string_view name, Args2&&... args) { return functions()->Invoke(name, std::forward<Args2>(args)...); } @@ -334,7 +337,7 @@ // Returns true if the specified factory function is available. // If using namespaces with this registry, the variant with a namespace // argument should be used. - static bool IsRegistered(const std::string& name) { + static bool IsRegistered(absl::string_view name) { return functions()->IsRegistered(name); } @@ -350,13 +353,13 @@ std::tuple<Args2...>>::value, int> = 0> static typename Functions::ReturnType CreateByNameInNamespace( - const std::string& ns, const std::string& name, Args2&&... args) { + absl::string_view ns, absl::string_view name, Args2&&... args) { return functions()->Invoke(ns, name, std::forward<Args2>(args)...); } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - static bool IsRegistered(const std::string& ns, const std::string& name) { + static bool IsRegistered(absl::string_view ns, absl::string_view name) { return functions()->IsRegistered(ns, name); } diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index 70775949d..0202b8689 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -97,39 +97,24 @@ absl::Status StatusBuilder::Impl::JoinMessageToStatus() { }()); } -StatusBuilder::Impl::Impl(const absl::Status& status, const char* file, - int line) - : status(status), line(line), file(file), stream() {} - -StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line) - : status(std::move(status)), line(line), file(file), stream() {} - StatusBuilder::Impl::Impl(const absl::Status& status, mediapipe::source_location location) - : status(status), - line(location.line()), - file(location.file_name()), - stream() {} + : status(status), location(location), stream() {} StatusBuilder::Impl::Impl(absl::Status&& status, mediapipe::source_location location) - : status(std::move(status)), - line(location.line()), - file(location.file_name()), - stream() {} + : status(std::move(status)), location(location), stream() {} StatusBuilder::Impl::Impl(const Impl& other) : status(other.status), - line(other.line), - file(other.file), + location(other.location), no_logging(other.no_logging), stream(other.stream.str()), join_style(other.join_style) {} StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) { status = other.status; - line = other.line; - file = other.file; + location = other.location; no_logging = other.no_logging; stream = std::ostringstream(other.stream.str()); join_style = other.join_style; diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index d2e40d575..ae11699d2 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -60,17 +60,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { ? nullptr : std::make_unique<Impl>(absl::Status(code, ""), location)) {} - StatusBuilder(const absl::Status& original_status, const char* file, int line) - : impl_(original_status.ok() - ? nullptr - : std::make_unique<Impl>(original_status, file, line)) {} - - StatusBuilder(absl::Status&& original_status, const char* file, int line) - : impl_(original_status.ok() - ?
nullptr - : std::make_unique<Impl>(std::move(original_status), file, - line)) {} - bool ok() const { return !impl_; } StatusBuilder& SetAppend() &; @@ -109,8 +98,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { kPrepend, }; - Impl(const absl::Status& status, const char* file, int line); - Impl(absl::Status&& status, const char* file, int line); Impl(const absl::Status& status, mediapipe::source_location location); Impl(absl::Status&& status, mediapipe::source_location location); Impl(const Impl&); @@ -120,10 +107,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // The status that the result will be based on. absl::Status status; - // The line to record if this file is logged. - int line; - // Not-owned: The file to record if this status is logged. - const char* file; + // The source location to record if this status is logged. + mediapipe::source_location location; // Logging disabled if true. bool no_logging = false; // The additional messages added with `<<`. This is nullptr when status_ is diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc index 560acd3c6..f517bb909 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -33,21 +33,6 @@ TEST(StatusBuilder, OkStatusRvalue) { ASSERT_EQ(status, absl::OkStatus()); } -TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - -TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) { - const auto original_status = absl::OkStatus(); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - TEST(StatusBuilder, AnnotateMode) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, "original message"), @@ -60,30 +45,6 @@ TEST(StatusBuilder, AnnotateMode) { "original message; annotated message1 annotated message2"); } -TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, - "original message"), - "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - -TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) { - const auto original_status = - absl::Status(absl::StatusCode::kNotFound, "original message"); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - TEST(StatusBuilder, PrependModeLvalue) { StatusBuilder builder( absl::Status(absl::StatusCode::kInvalidArgument, "original message"), diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index 757d99392..92bbf0b84 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -81,11 +81,11 @@ // MP_RETURN_IF_ERROR(foo.Method(args...)); // return absl::OkStatus(); // } -#define MP_RETURN_IF_ERROR(expr) \ - STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if
(mediapipe::status_macro_internal::StatusAdaptorForMacros \ - status_macro_internal_adaptor = {(expr), __FILE__, __LINE__}) { \ - } else /* NOLINT */ \ +#define MP_RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (mediapipe::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr), MEDIAPIPE_LOC}) { \ + } else /* NOLINT */ \ return status_macro_internal_adaptor.Consume() // Executes an expression `rexpr` that returns a `absl::StatusOr`. On @@ -156,14 +156,14 @@ return mediapipe::StatusBuilder( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__)) + MEDIAPIPE_LOC)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ mediapipe::StatusBuilder _( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__); \ + MEDIAPIPE_LOC); \ (void)_; /* error_expression is allowed to not use this variable */ \ return (error_expression)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ @@ -201,18 +201,17 @@ namespace status_macro_internal { // that declares a variable. class StatusAdaptorForMacros { public: - StatusAdaptorForMacros(const absl::Status& status, const char* file, int line) - : builder_(status, file, line) {} + StatusAdaptorForMacros(const absl::Status& status, source_location location) + : builder_(status, location) {} - StatusAdaptorForMacros(absl::Status&& status, const char* file, int line) - : builder_(std::move(status), file, line) {} + StatusAdaptorForMacros(absl::Status&& status, source_location location) + : builder_(std::move(status), location) {} - StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(const StatusBuilder& builder, + source_location /*location*/) : builder_(builder) {} - StatusAdaptorForMacros(StatusBuilder&& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(StatusBuilder&& builder, source_location /*location*/) : builder_(std::move(builder)) {} StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index c3241d911..fdb698c48 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -17,7 +17,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) @@ -26,7 +26,6 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats:location_data_proto"], ) @@ -45,7 +44,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "classification_proto", srcs = ["classification.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -64,46 +62,39 @@ mediapipe_register_type( mediapipe_proto_library( name = "image_format_proto", srcs = ["image_format.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "matrix_data_proto", srcs = ["matrix_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "location_data_proto", srcs 
= ["location_data.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "affine_transform_data_proto", srcs = ["affine_transform_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "time_series_header_proto", srcs = ["time_series_header.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_file_properties_proto", srcs = ["image_file_properties.proto"], - visibility = ["//visibility:public"], ) cc_library( name = "deleting_file", srcs = ["deleting_file.cc"], hdrs = ["deleting_file.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", ], @@ -113,7 +104,6 @@ cc_library( name = "matrix", srcs = ["matrix.cc"], hdrs = ["matrix.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/formats:matrix_data_cc_proto", @@ -129,13 +119,10 @@ cc_library( name = "affine_transform", srcs = ["affine_transform.cc"], hdrs = ["affine_transform.h"], - visibility = [ - "//visibility:public", - ], deps = [ + ":affine_transform_data_cc_proto", "//mediapipe/framework:port", "//mediapipe/framework:type_map", - "//mediapipe/framework/formats:affine_transform_data_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:point", @@ -154,7 +141,6 @@ cc_library( name = "image_frame", srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/base", @@ -179,7 +165,6 @@ cc_library( name = "image_frame_opencv", srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "//mediapipe/framework/formats:image_format_cc_proto", @@ -206,11 +191,10 @@ cc_library( name = "location", srcs = ["location.cc"], hdrs = ["location.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats/annotation:locus_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -238,9 +222,9 @@ cc_library( name = "location_opencv", srcs = ["location_opencv.cc"], hdrs = ["location_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":location", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:opencv_imgproc", ], alwayslink = 1, @@ -251,6 +235,7 @@ cc_test( srcs = ["location_opencv_test.cc"], deps = [ ":location_opencv", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:rectangle", ], @@ -259,7 +244,6 @@ cc_test( cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", ], @@ -268,7 +252,6 @@ cc_library( cc_library( name = "yuv_image", hdrs = ["yuv_image.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:integral_types", "@libyuv", @@ -292,7 +275,6 @@ cc_test( mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], 
- visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -310,9 +292,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "landmark_proto", srcs = ["landmark.proto"], - visibility = [ - "//visibility:public", - ], ) mediapipe_register_type( @@ -344,10 +323,9 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", @@ -374,7 +352,6 @@ cc_library( name = "image_multi_pool", srcs = ["image_multi_pool.cc"], hdrs = ["image_multi_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_frame_pool", @@ -411,7 +388,6 @@ cc_library( hdrs = [ "image_opencv.h", ], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_format_cc_proto", @@ -425,7 +401,6 @@ cc_library( name = "image_frame_pool", srcs = ["image_frame_pool.cc"], hdrs = ["image_frame_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "@com_google_absl//absl/memory", @@ -476,7 +451,6 @@ cc_library( "-landroid", ], }), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 328001e85..9bcb7bccd 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -16,7 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -24,12 +24,10 @@ mediapipe_proto_library( name = "locus_proto", srcs = ["locus.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "rasterization_proto", srcs = ["rasterization.proto"], - visibility = ["//visibility:public"], ) diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 28e0bfc6a..f1bbc0289 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,22 +16,20 @@ # Description: # Working with dense optical flow in mediapipe. 
-licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") -package(default_visibility = ["//visibility:private"]) +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "optical_flow_field_data_cc_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], deps = [":optical_flow_field_data_proto"], ) @@ -39,9 +37,6 @@ cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", diff --git a/mediapipe/framework/formats/object_detection/BUILD b/mediapipe/framework/formats/object_detection/BUILD index 39940acdc..35292e1cc 100644 --- a/mediapipe/framework/formats/object_detection/BUILD +++ b/mediapipe/framework/formats/object_detection/BUILD @@ -19,17 +19,15 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "anchor_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "anchor_cc_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], deps = [":anchor_proto"], ) diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index d1dffa414..a7bd9ef43 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -354,7 +354,9 @@ NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) { } } *min_stream_timestamp = std::min(min_packet, min_bound); - if (*min_stream_timestamp == Timestamp::Done()) { + if (*min_stream_timestamp >= Timestamp::OneOverPostStream()) { + // Either OneOverPostStream or Done indicates no more packets. + *min_stream_timestamp = Timestamp::Done(); last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream(); return NodeReadiness::kReadyForClose; } diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 87944d80f..e499ca3a6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -311,6 +311,17 @@ cc_library( ], ) +cc_library( + name = "opencv_videoio", + hdrs = ["opencv_videoio_inc.h"], + visibility = ["//visibility:public"], + deps = [ + ":opencv_core", + "//mediapipe/framework:port", + "//third_party:opencv", + ], +) + cc_library( name = "parse_text_proto", hdrs = [ diff --git a/mediapipe/framework/port/opencv_videoio_inc.h b/mediapipe/framework/port/opencv_videoio_inc.h new file mode 100644 index 000000000..63029b69f --- /dev/null +++ b/mediapipe/framework/port/opencv_videoio_inc.h @@ -0,0 +1,21 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ +#define MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ + +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "third_party/OpenCV/videoio.hpp" + +#endif // MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_ diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 237aa825f..b53a1ac39 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -334,6 +334,10 @@ cc_library( "graph_profiler_stub.h", ], visibility = ["//mediapipe/framework:__pkg__"], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + ], ) cc_test( diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 8771a8773..01ef6ee86 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,40 +13,36 @@ # limitations under the License. # +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -54,7 +50,6 @@ mediapipe_cc_proto_library( name = "default_input_stream_handler_cc_proto", srcs = ["default_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":default_input_stream_handler_proto"], ) @@ -62,7 +57,6 @@ mediapipe_cc_proto_library( name = "fixed_size_input_stream_handler_cc_proto", srcs = ["fixed_size_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":fixed_size_input_stream_handler_proto"], ) @@ -70,7 +64,6 @@ mediapipe_cc_proto_library( name = "sync_set_input_stream_handler_cc_proto", srcs = ["sync_set_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":sync_set_input_stream_handler_proto"], ) @@ -78,14 +71,12 @@ mediapipe_cc_proto_library( name = "timestamp_align_input_stream_handler_cc_proto", srcs = ["timestamp_align_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":timestamp_align_input_stream_handler_proto"], ) cc_library( name = "barrier_input_stream_handler", srcs = 
["barrier_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -96,7 +87,6 @@ cc_library( name = "default_input_stream_handler", srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", @@ -108,7 +98,6 @@ cc_library( cc_library( name = "early_close_input_stream_handler", srcs = ["early_close_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "@com_google_absl//absl/strings", @@ -119,7 +108,6 @@ cc_library( cc_library( name = "fixed_size_input_stream_handler", srcs = ["fixed_size_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ ":default_input_stream_handler", "//mediapipe/framework:input_stream_handler", @@ -131,7 +119,6 @@ cc_library( cc_library( name = "immediate_input_stream_handler", srcs = ["immediate_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -142,7 +129,6 @@ cc_library( name = "in_order_output_stream_handler", srcs = ["in_order_output_stream_handler.cc"], hdrs = ["in_order_output_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -160,7 +146,6 @@ cc_library( cc_library( name = "mux_input_stream_handler", srcs = ["mux_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "//mediapipe/framework/port:logging", @@ -173,7 +158,6 @@ cc_library( cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -192,7 +176,6 @@ cc_library( cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index e721afb02..e5de7f0c9 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc @@ -230,6 +230,43 @@ TEST_F(ImmediateInputStreamHandlerTest, StreamDoneReady) { input_stream_handler_->ClearCurrentInputs(cc_); } +// This test checks that the state is ReadyForClose after all streams reach +// Timestamp::Max. +TEST_F(ImmediateInputStreamHandlerTest, ReadyForCloseAfterTimestampMax) { + Timestamp min_stream_timestamp; + std::list packets; + + // One packet arrives, ready for process. + packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(10))); + input_stream_handler_->AddPackets(name_to_id_["input_a"], packets); + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp(10), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + // No packets arrive, not ready. 
+ EXPECT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Unset(), cc_->InputTimestamp()); + + // Timestamp::Max arrives, ready for close. + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_a"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_b"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_c"], Timestamp::Max().NextAllowedInStream()); + + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Done(), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); +} + // This test checks that when any stream is done, the state is ready to close. TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) { Timestamp min_stream_timestamp; diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index e54fb2177..52d04b4b1 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,12 +299,12 @@ mediapipe_cc_test( data = [":node_chain_subgraph.proto"], requires_full_emulation = False, deps = [ + ":node_chain_subgraph_cc_proto", ":options_field_util", ":options_registry", ":options_syntax_util", ":options_util", "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:basic_types_registration", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -312,6 +312,7 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", @@ -486,7 +487,6 @@ cc_library( deps = [ ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", @@ -738,9 +738,7 @@ cc_test( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:graph_service_manager", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework:status_handler", @@ -923,7 +921,6 @@ cc_test( "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:subgraph", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 6433c93d2..c7ed063e0 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -258,11 +258,8 @@ std::string GetTestFilePath(absl::string_view relative_path) { return file::JoinPath(GetTestRootDir(), relative_path); } -absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( - absl::string_view path, ImageFormat::Format format) { - std::string encoded; - MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); - +absl::StatusOr<std::unique_ptr<ImageFrame>> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format) { // stbi_load determines the output pixel format based on the desired channels. // 0 means "use whatever's in the file". int desired_channels = format == ImageFormat::UNKNOWN ? 0 @@ -274,10 +271,10 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( << "unsupported output format requested: " << format; int width, height, channels_in_file; - auto data = stbi_load_from_memory(reinterpret_cast<const stbi_uc*>(encoded.data()), - encoded.size(), &width, &height, - &channels_in_file, desired_channels); - RET_CHECK(data) << "failed to decode image data from: " << path; + auto data = stbi_load_from_memory( + reinterpret_cast<const stbi_uc*>(encoded.data()), encoded.size(), &width, + &height, &channels_in_file, desired_channels); + RET_CHECK(data) << "failed to decode image data"; // If we didn't specify a desired format, it will be determined by what the // file contains. @@ -295,6 +292,13 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( format, width, height, width * output_channels, data, stbi_image_free); } +absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( + absl::string_view path, ImageFormat::Format format) { + std::string encoded; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); + return DecodeTestImage(encoded, format); +} + std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path, ImageFormat::Format format) { return nullptr; diff --git a/mediapipe/framework/tool/test_util.h b/mediapipe/framework/tool/test_util.h index 71c096db7..80b768e3d 100644 --- a/mediapipe/framework/tool/test_util.h +++ b/mediapipe/framework/tool/test_util.h @@ -81,6 +81,10 @@ std::string GetTestDataDir(absl::string_view package_base_path); // Loads a binary graph from path. Returns true iff successful. bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path); +// Loads an image from memory. +absl::StatusOr<std::unique_ptr<ImageFrame>> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format = ImageFormat::SRGBA); + // Loads an image from path. absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA); diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index 16aad6e9b..01e3da83e 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -1048,6 +1048,14 @@ absl::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( for (const auto& required_item : required_side_packets_) { auto iter = side_packet_types.find(required_item.first); if (iter == side_packet_types.end()) { + bool is_optional = true; + for (int index : required_item.second) { + is_optional &= input_side_packets_[index].packet_type->IsOptional(); + } + if (is_optional) { + // Side packets that are optional and not provided are ignored.
+ continue; + } statuses.push_back(mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Side packet \"" << required_item.first << "\" is required but was not provided."); diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9c2f47469..7a8aa6557 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -176,6 +176,16 @@ cc_library( "-fobjc-arc", # enable reference-counting ], }), + linkopts = select({ + "//conditions:default": [], + "//mediapipe:ios": [ + "-framework OpenGLES", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework AppKit", + ], + }), visibility = ["//visibility:public"], deps = [ ":attachments", @@ -204,8 +214,10 @@ cc_library( }) + select({ "//conditions:default": [ ], - "//mediapipe:ios": [], - "//mediapipe:macos": [], + "//mediapipe:ios": [ + ], + "//mediapipe:macos": [ + ], }), ) @@ -221,12 +233,18 @@ cc_library( ":gpu_buffer_format", ":gpu_buffer_storage", ":gpu_buffer_storage_image_frame", + "@com_google_absl//absl/memory", # TODO: remove this dependency. Some other teams' tests # depend on having an indirect image_frame dependency, need to be # fixed first. "//mediapipe/framework/formats:image_frame", - "@com_google_absl//absl/memory", - ], + ] + select({ + "//conditions:default": [], + ":platform_ios_with_gpu": [ + ":gl_texture_util", + ":gpu_buffer_storage_cv_pixel_buffer", + ], + }), ) cc_library( @@ -344,6 +362,60 @@ cc_library( ], ) +mediapipe_cc_test( + name = "gpu_buffer_storage_cv_pixel_buffer_test", + size = "small", + timeout = "moderate", + srcs = ["gpu_buffer_storage_cv_pixel_buffer_test.cc"], + platforms = ["ios"], + deps = [ + ":gl_texture_buffer", + ":gl_texture_util", + ":gpu_buffer", + ":gpu_buffer_storage_cv_pixel_buffer", + ":gpu_test_base", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:test_util", + "//mediapipe/objc:util", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cv_texture_cache_manager", + srcs = ["cv_texture_cache_manager.cc"], + hdrs = ["cv_texture_cache_manager.h"], + deps = [ + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "cv_pixel_buffer_pool_wrapper", + srcs = ["cv_pixel_buffer_pool_wrapper.cc"], + hdrs = ["cv_pixel_buffer_pool_wrapper.h"], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", + ], + }), + deps = [ + ":cv_texture_cache_manager", + ":gpu_buffer_format", + ":multi_pool", + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "//mediapipe/objc:util", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_buffer_storage_image_frame", hdrs = ["gpu_buffer_storage_image_frame.h"], @@ -410,12 +482,9 @@ objc_library( ) objc_library( - name = "MPPGraphGPUData", - srcs = [ - "MPPGraphGPUData.mm", - "gpu_shared_data_internal.cc", - ], - hdrs = ["MPPGraphGPUData.h"], + name = "metal_shared_resources", + srcs = ["metal_shared_resources.mm"], + hdrs = ["metal_shared_resources.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", @@ -424,24 +493,9 @@ objc_library( sdk_frameworks = [ "CoreVideo", "Metal", - ] + select({ - "//conditions:default": [ - "OpenGLES", - ], - "//mediapipe:macos": [ - "OpenGL", - "AppKit", - ], - }), + ], visibility = ["//visibility:public"], deps = [ - ":gl_base", - ":gl_context", - ":gpu_buffer_multi_pool", - ":gpu_shared_data_header", - ":graph_support", - 
"//mediapipe/gpu:gl_context_options_cc_proto", - "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", "@google_toolbox_for_mac//:GTM_Defines", ] + [ @@ -489,12 +543,7 @@ cc_library( name = "gpu_shared_data_header", textual_hdrs = [ "gpu_shared_data_internal.h", - ] + select({ - "//conditions:default": [], - "//mediapipe:apple": [ - "MPPGraphGPUData.h", - ], - }), + ], visibility = ["//visibility:private"], deps = [ ":gl_base", @@ -528,16 +577,19 @@ cc_library( cc_library( name = "gpu_shared_data_internal_actual", - srcs = select({ - "//conditions:default": [ - "gpu_shared_data_internal.cc", - ], - # iOS uses an Objective-C++ version of this, built in MPPGraphGPUData. - "//mediapipe:apple": [], - }), + srcs = [ + "gpu_shared_data_internal.cc", + ], hdrs = [ "gpu_shared_data_internal.h", ], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + }), visibility = ["//visibility:private"], deps = [ "//mediapipe/gpu:gl_context_options_cc_proto", @@ -554,7 +606,8 @@ cc_library( ] + select({ "//conditions:default": [], "//mediapipe:apple": [ - ":MPPGraphGPUData", + ":metal_shared_resources", + ":cv_texture_cache_manager", ], }), ) @@ -569,6 +622,8 @@ cc_library( ":gl_texture_buffer", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", + ":reusable_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -577,6 +632,22 @@ cc_library( ], ) +cc_library( + name = "reusable_pool", + hdrs = ["reusable_pool.h"], + deps = [ + ":multi_pool", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "multi_pool", + hdrs = ["multi_pool.h"], + deps = ["//mediapipe/util:resource_cache"], +) + cc_library( name = "gpu_buffer_multi_pool", srcs = ["gpu_buffer_multi_pool.cc"], @@ -604,6 +675,7 @@ cc_library( ":gl_base", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -617,11 +689,15 @@ cc_library( ":gl_texture_buffer_pool", ], "//mediapipe:ios": [ + ":cv_pixel_buffer_pool_wrapper", + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], "//mediapipe:macos": [ + ":cv_pixel_buffer_pool_wrapper", + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", ":gl_texture_buffer", ":gl_texture_buffer_pool", @@ -629,6 +705,17 @@ cc_library( }), ) +cc_library( + name = "gl_texture_util", + srcs = ["gl_texture_util.cc"], + hdrs = ["gl_texture_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":gl_base", + ":gl_texture_view", + ], +) + cc_library( name = "shader_util", srcs = ["shader_util.cc"], @@ -653,11 +740,9 @@ cc_library( name = "gl_calculator_helper", srcs = [ "gl_calculator_helper.cc", - "gl_calculator_helper_impl_common.cc", ], hdrs = [ "gl_calculator_helper.h", - "gl_calculator_helper_impl.h", ], linkopts = select({ "//conditions:default": [], @@ -689,7 +774,7 @@ cc_library( ":image_frame_view", ":shader_util", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_cc_proto", + "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_contract", @@ -715,20 +800,6 @@ cc_library( }), ) -# TODO: remove 
-objc_library( - name = "gl_calculator_helper_ios", - copts = [ - "-Wno-shorten-64-to-32", - ], - visibility = ["//visibility:public"], - deps = [ - ":gl_calculator_helper", - "//mediapipe/objc:mediapipe_framework_ios", - "//mediapipe/objc:util", - ], -) - objc_library( name = "MPPMetalHelper", srcs = ["MPPMetalHelper.mm"], @@ -821,6 +892,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gl_calculator_helper", + ":gpu_buffer_storage_image_frame", + "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:status", @@ -1062,8 +1135,8 @@ objc_library( name = "gl_ios_test_lib", testonly = 1, srcs = [ - "MPPGraphGPUDataTests.mm", "gl_ios_test.mm", + "metal_shared_resources_test.mm", ], copts = [ "-Wno-shorten-64-to-32", @@ -1073,7 +1146,7 @@ objc_library( ], features = ["-layering_check"], deps = [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":gl_scaler_calculator", ":gpu_buffer_to_image_frame_calculator", ":gpu_shared_data_internal", diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h deleted file mode 100644 index 3d8fc0c94..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ -#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ - -#import -#import -#import - -#import "mediapipe/gpu/gl_base.h" -#import "mediapipe/gpu/gl_context.h" - -namespace mediapipe { -class GlContext; -class GpuBufferMultiPool; -} // namespace mediapipe - -@interface MPPGraphGPUData : NSObject { - // Shared buffer pool for GPU calculators. - mediapipe::GpuBufferMultiPool* _gpuBufferPool; - mediapipe::GlContext* _glContext; -} - -- (instancetype)init NS_UNAVAILABLE; - -/// Initialize. The provided multipool pointer must remain valid throughout -/// this object's lifetime. -- (instancetype)initWithContext:(mediapipe::GlContext*)context - multiPool:(mediapipe::GpuBufferMultiPool*)pool NS_DESIGNATED_INITIALIZER; - -/// Shared texture pool for GPU calculators. -/// For internal use by GlCalculatorHelper. -@property(readonly) mediapipe::GpuBufferMultiPool* gpuBufferPool; - -/// Shared OpenGL context. -#if TARGET_OS_OSX -@property(readonly) NSOpenGLContext* glContext; -@property(readonly) NSOpenGLPixelFormat* glPixelFormat; -#else -@property(readonly) EAGLContext* glContext; -#endif // TARGET_OS_OSX - -/// Shared texture cache. -#if TARGET_OS_OSX -@property(readonly) CVOpenGLTextureCacheRef textureCache; -#else -@property(readonly) CVOpenGLESTextureCacheRef textureCache; -#endif // TARGET_OS_OSX - -/// Shared Metal resources. 
-@property(readonly) id<MTLDevice> mtlDevice; -@property(readonly) id<MTLCommandQueue> mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@property(readonly) CVMetalTextureCacheRef mtlTextureCache; -#endif - -@end - -#endif // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm deleted file mode 100644 index 8ac1eefa5..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/gpu/MPPGraphGPUData.h" - -#import "GTMDefines.h" - -#include "mediapipe/gpu/gl_context.h" -#include "mediapipe/gpu/gpu_buffer_multi_pool.h" - -#if TARGET_OS_OSX -#import -#else -#import -#endif // TARGET_OS_OSX - -@implementation MPPGraphGPUData - -@synthesize textureCache = _textureCache; -@synthesize mtlDevice = _mtlDevice; -@synthesize mtlCommandQueue = _mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@synthesize mtlTextureCache = _mtlTextureCache; -#endif - -#if TARGET_OS_OSX -typedef CVOpenGLTextureCacheRef CVTextureCacheType; -#else -typedef CVOpenGLESTextureCacheRef CVTextureCacheType; -#endif // TARGET_OS_OSX - -- (instancetype)initWithContext:(mediapipe::GlContext *)context - multiPool:(mediapipe::GpuBufferMultiPool *)pool { - self = [super init]; - if (self) { - _gpuBufferPool = pool; - _glContext = context; - } - return self; -} - -- (void)dealloc { - if (_textureCache) { - _textureCache = NULL; - } -#if COREVIDEO_SUPPORTS_METAL - if (_mtlTextureCache) { - CFRelease(_mtlTextureCache); - _mtlTextureCache = NULL; - } -#endif -} - -#if TARGET_OS_OSX -- (NSOpenGLContext *)glContext { - return _glContext->nsgl_context(); -} - -- (NSOpenGLPixelFormat *) glPixelFormat { - return _glContext->nsgl_pixel_format(); -} -#else -- (EAGLContext *)glContext { - return _glContext->eagl_context(); -} -#endif // TARGET_OS_OSX - -- (CVTextureCacheType)textureCache { - @synchronized(self) { - if (!_textureCache) { - _textureCache = _glContext->cv_texture_cache(); - } - } - return _textureCache; -} - -- (mediapipe::GpuBufferMultiPool *)gpuBufferPool { - return _gpuBufferPool; -} - -- (id<MTLDevice>)mtlDevice { - @synchronized(self) { - if (!_mtlDevice) { - _mtlDevice = MTLCreateSystemDefaultDevice(); - } - } - return _mtlDevice; -} - -- (id<MTLCommandQueue>)mtlCommandQueue { - @synchronized(self) { - if (!_mtlCommandQueue) { - _mtlCommandQueue = [self.mtlDevice newCommandQueue]; - } - } - return _mtlCommandQueue; -} - -#if COREVIDEO_SUPPORTS_METAL -- (CVMetalTextureCacheRef)mtlTextureCache { - @synchronized(self) { - if (!_mtlTextureCache) { - CVReturn __unused err = - CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); - NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); - // TODO: register and flush metal caches too.
- } - } - return _mtlTextureCache; -} -#endif - -@end diff --git a/mediapipe/gpu/MPPGraphGPUDataTests.mm b/mediapipe/gpu/MPPGraphGPUDataTests.mm deleted file mode 100644 index e8b50845b..000000000 --- a/mediapipe/gpu/MPPGraphGPUDataTests.mm +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import <Foundation/Foundation.h> -#import <XCTest/XCTest.h> - -#include <memory> - -#include "absl/memory/memory.h" -#include "mediapipe/framework/port/threadpool.h" - -#import "mediapipe/gpu/MPPGraphGPUData.h" -#import "mediapipe/gpu/gpu_shared_data_internal.h" - -@interface MPPGraphGPUDataTests : XCTestCase { -} -@end - -@implementation MPPGraphGPUDataTests - -// This test verifies that the internal Objective-C object is correctly -// released when the C++ wrapper is released. -- (void)testCorrectlyReleased { - __weak id gpuData = nil; - std::weak_ptr<mediapipe::GpuResources> gpuRes; - @autoreleasepool { - mediapipe::GpuSharedData gpu_shared; - gpuRes = gpu_shared.gpu_resources; - gpuData = gpu_shared.gpu_resources->ios_gpu_data(); - XCTAssertNotEqual(gpuRes.lock(), nullptr); - XCTAssertNotNil(gpuData); - } - XCTAssertEqual(gpuRes.lock(), nullptr); - XCTAssertNil(gpuData); -} - -// This test verifies that the lazy initialization of the glContext instance -// variable is thread-safe. All threads should read the same value. -- (void)testGlContextThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - EAGLContext* ogl_context[kNumThreads]; - auto pool = absl::make_unique<mediapipe::ThreadPool>(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &ogl_context, i] { - ogl_context[i] = gpu_shared.gpu_resources->ios_gpu_data().glContext; - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(ogl_context[i], ogl_context[i + 1]); - } -} - -// This test verifies that the lazy initialization of the textureCache instance -// variable is thread-safe. All threads should read the same value.
-- (void)testTextureCacheThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - CFHolder<CVTextureCacheType> texture_cache[kNumThreads]; - auto pool = absl::make_unique<mediapipe::ThreadPool>(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &texture_cache, i] { - texture_cache[i].reset(gpu_shared.gpu_resources->ios_gpu_data().textureCache); - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(*texture_cache[i], *texture_cache[i + 1]); - } -} - -@end diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h index f3662422e..6ae0f3cf9 100644 --- a/mediapipe/gpu/MPPMetalHelper.h +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -21,37 +21,35 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" NS_ASSUME_NONNULL_BEGIN @interface MPPMetalHelper : NSObject { - MPPGraphGPUData* _gpuShared; } - (instancetype)init NS_UNAVAILABLE; /// Initialize. This initializer is recommended for calculators. -- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc; +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext *)cc; /// Initialize. -- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources +- (instancetype)initWithGpuResources:(mediapipe::GpuResources *)gpuResources NS_DESIGNATED_INITIALIZER; /// Configures a calculator's contract for accessing GPU resources. /// Calculators should use this in GetContract. -+ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc; ++ (absl::Status)updateContract:(mediapipe::CalculatorContract *)cc; /// Deprecated initializer. -- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet &)inputSidePackets; /// Deprecated initializer. -- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared; +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData *)gpuShared; /// Configures a calculator's side packets for accessing GPU resources. /// Calculators should use this in FillExpectations. -+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet *)inputSidePackets; /// Get a metal command buffer. /// Calculators should use this method instead of getting a buffer from the @@ -63,23 +61,23 @@ NS_ASSUME_NONNULL_BEGIN /// Creates a CVMetalTextureRef linked to the provided GpuBuffer. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Creates a CVMetalTextureRef linked to the provided GpuBuffer given a specific plane. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Returns a MTLTexture linked to the provided GpuBuffer. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input.
-- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Returns a MTLTexture linked to the provided GpuBuffer given a specific plane. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (id<MTLTexture>)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Obtains a new GpuBuffer to be used as an output destination. @@ -91,7 +89,7 @@ NS_ASSUME_NONNULL_BEGIN format:(mediapipe::GpuBufferFormat)format; /// Convenience method to load a Metal library stored as a bundle resource. -- (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name error:(NSError* _Nullable*)error; +- (id<MTLLibrary>)newLibraryWithResourceName:(NSString *)name error:(NSError *_Nullable *)error; /// Shared Metal resources. @property(readonly) id<MTLDevice> mtlDevice; diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index ce6620972..1acf7cbfb 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,11 +14,18 @@ #import "mediapipe/gpu/MPPMetalHelper.h" +#import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/graph_support.h" +#import "mediapipe/gpu/metal_shared_resources.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" +@interface MPPMetalHelper () { + mediapipe::GpuResources* _gpuResources; +} +@end + namespace mediapipe { // Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. @@ -40,7 +47,7 @@ class MetalHelperLegacySupport { - (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources { self = [super init]; if (self) { - _gpuShared = gpuResources->ios_gpu_data(); + _gpuResources = gpuResources; } return self; } @@ -105,19 +112,19 @@ class MetalHelperLegacySupport { } - (id<MTLDevice>)mtlDevice { - return _gpuShared.mtlDevice; + return _gpuResources->metal_shared().resources().mtlDevice; } - (id<MTLCommandQueue>)mtlCommandQueue { - return _gpuShared.mtlCommandQueue; + return _gpuResources->metal_shared().resources().mtlCommandQueue; } - (CVMetalTextureCacheRef)mtlTextureCache { - return _gpuShared.mtlTextureCache; + return _gpuResources->metal_shared().resources().mtlTextureCache; } - (id<MTLCommandBuffer>)commandBuffer { - return [_gpuShared.mtlCommandQueue commandBuffer]; + return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; } - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer @@ -169,8 +176,9 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( - NULL, _gpuShared.mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, - metalPixelFormat, width, height, plane, &texture); + NULL, _gpuResources->metal_shared().resources().mtlTextureCache, + mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, + &texture); CHECK_EQ(err, kCVReturnSuccess); return texture; } @@ -191,19 +199,20 @@ class MetalHelperLegacySupport { } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { - return _gpuShared.gpuBufferPool->GetBuffer(width, height); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height format:(mediapipe::GpuBufferFormat)format { - return _gpuShared.gpuBufferPool->GetBuffer(width,
height, format); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); } - (id<MTLLibrary>)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { - return [_gpuShared.mtlDevice newLibraryWithFile:[[NSBundle bundleForClass:[self class]] - pathForResource:name ofType:@"metallib"] - error:error]; + return [_gpuResources->metal_shared().resources().mtlDevice + newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name + ofType:@"metallib"] + error:error]; } @end diff --git a/mediapipe/gpu/attachments.h b/mediapipe/gpu/attachments.h index ca9f074c4..3a73e4676 100644 --- a/mediapipe/gpu/attachments.h +++ b/mediapipe/gpu/attachments.h @@ -31,8 +31,8 @@ class AttachmentBase {}; template <class Context, class T> class Attachment : public AttachmentBase<Context> { public: - using FactoryT = std::function<AttachmentPtr<T>(Context&)>; - Attachment(FactoryT factory) : factory_(factory) {} + using FactoryT = AttachmentPtr<T> (*)(Context&); + explicit constexpr Attachment(FactoryT factory) : factory_(factory) {} Attachment(const Attachment&) = delete; Attachment(Attachment&&) = delete; diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc new file mode 100644 index 000000000..6e077ae6e --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -0,0 +1,84 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" + +#include + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/objc/CFHolder.h" +#include "mediapipe/objc/util.h" + +namespace mediapipe { + +CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( + int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, + CvTextureCacheManager* texture_caches) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + pool_ = MakeCFHolderAdopting( + /* keep count is 0 because the age param keeps buffers around anyway */ + CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); + texture_caches_ = texture_caches; +} + +CFHolder<CVPixelBufferRef> CvPixelBufferPoolWrapper::GetBuffer() { + CVPixelBufferRef buffer; + int threshold = 1; + NSMutableDictionary* auxAttributes = + [NSMutableDictionary dictionaryWithCapacity:1]; + CVReturn err; + bool tried_flushing = false; + while (1) { + auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, + &buffer); + if (err != kCVReturnWouldExceedAllocationThreshold) break; + if (texture_caches_ && !tried_flushing) { + // Call the flush function to potentially release old holds on buffers + // and try again to create a pixel buffer. + // This is used to flush CV texture caches, which may retain buffers until + // flushed.
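+ // Background on this retry loop (CoreVideo behavior, not added by this patch): kCVPixelBufferPoolAllocationThresholdKey makes CVPixelBufferPoolCreatePixelBufferWithAuxAttributes return kCVReturnWouldExceedAllocationThreshold instead of allocating once `threshold` buffers from the pool are still alive, so the code first flushes the texture caches (which may be the only remaining holders of old buffers) and only raises the threshold if that fails.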
+ texture_caches_->FlushTextureCaches(); + tried_flushing = true; + } else { + ++threshold; + } + } + CHECK(!err) << "Error creating pixel buffer: " << err; + count_ = threshold; + return MakeCFHolderAdopting(buffer); +} + +std::string CvPixelBufferPoolWrapper::GetDebugString() const { + auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); + return [(__bridge NSString*)*description UTF8String]; +} + +void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } + +CFHolder<CVPixelBufferRef> CvPixelBufferPoolWrapper::CreateBufferWithoutPool( + const internal::GpuBufferSpec& spec) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + CVPixelBufferRef buffer; + CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, + cv_format, &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + return MakeCFHolderAdopting(buffer); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h new file mode 100644 index 000000000..4d71adbf2 --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -0,0 +1,66 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers.
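+// A minimal usage sketch of the wrapper declared below (illustrative values only, not part of this patch): +// CvPixelBufferPoolWrapper pool(/*width=*/640, /*height=*/480, GpuBufferFormat::kBGRA32, /*maxAge=*/0.1, /*texture_caches=*/nullptr); +// CFHolder<CVPixelBufferRef> buffer = pool.GetBuffer();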
+ +#ifndef MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ +#define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/gpu/cv_texture_cache_manager.h" +#include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/multi_pool.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvPixelBufferPoolWrapper { + public: + CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, + CFTimeInterval maxAge, + CvTextureCacheManager* texture_caches); + + static std::shared_ptr<CvPixelBufferPoolWrapper> Create( + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options, + CvTextureCacheManager* texture_caches = nullptr) { + return std::make_shared<CvPixelBufferPoolWrapper>( + spec.width, spec.height, spec.format, options.max_inactive_buffer_age, + texture_caches); + } + + CFHolder<CVPixelBufferRef> GetBuffer(); + + int GetBufferCount() const { return count_; } + std::string GetDebugString() const; + + void Flush(); + + static CFHolder<CVPixelBufferRef> CreateBufferWithoutPool( + const internal::GpuBufferSpec& spec); + + private: + CFHolder<CVPixelBufferPoolRef> pool_; + int count_ = 0; + CvTextureCacheManager* texture_caches_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ diff --git a/mediapipe/gpu/cv_texture_cache_manager.cc b/mediapipe/gpu/cv_texture_cache_manager.cc new file mode 100644 index 000000000..b977a8993 --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.cc @@ -0,0 +1,55 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "mediapipe/gpu/cv_texture_cache_manager.h" + +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +void CvTextureCacheManager::FlushTextureCaches() { + absl::MutexLock lock(&mutex_); + for (const auto& cache : texture_caches_) { +#if TARGET_OS_OSX + CVOpenGLTextureCacheFlush(*cache, 0); +#else + CVOpenGLESTextureCacheFlush(*cache, 0); +#endif // TARGET_OS_OSX + } +} + +void CvTextureCacheManager::RegisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == + texture_caches_.end()) + << "Attempting to register a texture cache twice"; + texture_caches_.emplace_back(cache); +} + +void CvTextureCacheManager::UnregisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); + CHECK(it != texture_caches_.end()) + << "Attempting to unregister an unknown texture cache"; + texture_caches_.erase(it); +} + +CvTextureCacheManager::~CvTextureCacheManager() { + CHECK_EQ(texture_caches_.size(), 0) + << "Failed to unregister texture caches before deleting manager"; +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_texture_cache_manager.h b/mediapipe/gpu/cv_texture_cache_manager.h new file mode 100644 index 000000000..17e44fc6e --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.h @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ +#define MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvTextureCacheManager { + public: + ~CvTextureCacheManager(); + + // TODO: add tests for the texture cache registration. + + // Inform the pool of a cache that should be flushed when it is low on + // reusable buffers. + void RegisterTextureCache(CVTextureCacheType cache); + + // Remove a texture cache from the list of caches to be flushed. 
+ void UnregisterTextureCache(CVTextureCacheType cache); + + void FlushTextureCaches(); + + private: + absl::Mutex mutex_; + std::vector<CFHolder<CVTextureCacheType>> texture_caches_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index ba1423977..9b217ddfd 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -20,38 +20,37 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_service.h" namespace mediapipe { -// The constructor and destructor need to be defined here so that -// std::unique_ptr can see the full definition of GlCalculatorHelperImpl. -// In the header, it is an incomplete type. GlCalculatorHelper::GlCalculatorHelper() {} GlCalculatorHelper::~GlCalculatorHelper() {} +void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, + GpuResources* gpu_resources) { + gpu_resources_ = gpu_resources; + gl_context_ = gpu_resources_->gl_context(cc); +} + absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { CHECK(cc); auto gpu_service = cc->Service(kGpuService); RET_CHECK(gpu_service.IsAvailable()) << "GPU service not available. Did you forget to call " "GlCalculatorHelper::UpdateContract?"; - // TODO return error from impl_ (needs two-stage init) - impl_ = - absl::make_unique<GlCalculatorHelperImpl>(cc, &gpu_service.GetObject()); + InitializeInternal(cc, &gpu_service.GetObject()); return absl::OkStatus(); } void GlCalculatorHelper::InitializeForTest(GpuSharedData* gpu_shared) { - impl_ = absl::make_unique<GlCalculatorHelperImpl>( - nullptr, gpu_shared->gpu_resources.get()); + InitializeInternal(nullptr, gpu_shared->gpu_resources.get()); } void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { - impl_ = absl::make_unique<GlCalculatorHelperImpl>(nullptr, gpu_resources); + InitializeInternal(nullptr, gpu_resources); } // static @@ -88,44 +87,109 @@ absl::Status GlCalculatorHelper::SetupInputSidePackets( return absl::OkStatus(); } +absl::Status GlCalculatorHelper::RunInGlContext( + std::function<absl::Status(void)> gl_func, + CalculatorContext* calculator_context) { + if (calculator_context) { + return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), + calculator_context->InputTimestamp()); + } else { + return gl_context_->Run(std::move(gl_func)); + } +} + absl::Status GlCalculatorHelper::RunInGlContext( std::function<absl::Status(void)> gl_func) { - if (!impl_) return absl::InternalError("helper not initialized"); + if (!Initialized()) return absl::InternalError("helper not initialized"); // TODO: Remove LegacyCalculatorSupport from MediaPipe OSS. auto calculator_context = LegacyCalculatorSupport::Scoped<CalculatorContext>::current(); - return impl_->RunInGlContext(gl_func, calculator_context); + return RunInGlContext(gl_func, calculator_context); } -GLuint GlCalculatorHelper::framebuffer() const { return impl_->framebuffer(); } +GLuint GlCalculatorHelper::framebuffer() const { return framebuffer_; } + +void GlCalculatorHelper::CreateFramebuffer() { + // Our framebuffer will have a color attachment but no depth attachment, + // so it's important that the depth test be off. It is disabled by default, + // but we wanted to be explicit. + // TODO: move this to glBindFramebuffer? Or just remove.
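+ // Note (based on the kUtilityFramebuffer attachment added in gl_context.cc later in this patch, phrasing is an editor assumption): Attachment::Get() lazily creates one framebuffer per GlContext and returns the same id on later calls, so helpers sharing a context share a single framebuffer instead of each generating its own.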
+ glDisable(GL_DEPTH_TEST); + framebuffer_ = kUtilityFramebuffer.Get(*gl_context_); +} void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) { - return impl_->BindFramebuffer(dst); +#ifdef __ANDROID__ + // On (some?) Android devices, attaching a new texture to the frame buffer + // does not seem to detach the old one. As a result, using that texture + // for texturing can produce incorrect output. See b/32091368 for details. + // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 + // or glFramebufferTexture2D with a texture ID of 0. + glBindFramebuffer(GL_FRAMEBUFFER, 0); +#endif + if (!framebuffer_) { + CreateFramebuffer(); + } + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, dst.width(), dst.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), + dst.name(), 0); + +#ifndef NDEBUG + GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + if (status != GL_FRAMEBUFFER_COMPLETE) { + VLOG(2) << "incomplete framebuffer: " << status; + } +#endif } -GlTexture GlCalculatorHelper::CreateSourceTexture( - const GpuBuffer& pixel_buffer) { - return impl_->CreateSourceTexture(pixel_buffer); +GlTexture GlCalculatorHelper::MapGpuBuffer(const GpuBuffer& gpu_buffer, + GlTextureView view) { + if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { + // TODO: do the params need to be reset here?? + glBindTexture(view.target(), view.name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + gpu_buffer.format(), view.plane(), GetGlVersion()); + gl_context_->SetStandardTextureParams(view.target(), + info.gl_internal_format); + glBindTexture(view.target(), 0); + } + + return GlTexture(std::move(view), gpu_buffer); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer) { + return CreateSourceTexture(gpu_buffer, 0); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer, + int plane) { + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView<GlTextureView>(plane)); +} GlTexture GlCalculatorHelper::CreateSourceTexture( const ImageFrame& image_frame) { - return impl_->CreateSourceTexture(image_frame); -} - -GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer, - int plane) { - return impl_->CreateSourceTexture(pixel_buffer, plane); + auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView<GlTextureView>(0)); } GpuBuffer GlCalculatorHelper::GpuBufferWithImageFrame( std::shared_ptr<ImageFrame> image_frame) { - return impl_->GpuBufferWithImageFrame(std::move(image_frame)); + return GpuBuffer( + std::make_shared<GpuBufferStorageImageFrame>(std::move(image_frame))); } GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( const ImageFrame& image_frame) { - return impl_->GpuBufferCopyingImageFrame(image_frame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); + // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only + // deals with absl::Status in MediaPipe OSS.
+ CHECK_OK(maybe_buffer.status()); + return GpuBuffer(std::move(maybe_buffer).value()); +#else + return GpuBuffer(GlTextureBuffer::Create(image_frame)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, @@ -136,23 +200,36 @@ void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, *height = pixel_buffer.height(); } -GlTexture GlCalculatorHelper::CreateDestinationTexture(int output_width, - int output_height, +GlTexture GlCalculatorHelper::CreateDestinationTexture(int width, int height, GpuBufferFormat format) { - return impl_->CreateDestinationTexture(output_width, output_height, format); -} + if (!framebuffer_) { + CreateFramebuffer(); + } -GlContext& GlCalculatorHelper::GetGlContext() const { - return impl_->GetGlContext(); -} - -GlVersion GlCalculatorHelper::GetGlVersion() const { - return impl_->GetGlVersion(); + GpuBuffer gpu_buffer = + gpu_resources_->gpu_buffer_pool().GetBuffer(width, height, format); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView<GlTextureView>(0)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const mediapipe::Image& image) { - return impl_->CreateSourceTexture(image.GetGpuBuffer()); + return CreateSourceTexture(image.GetGpuBuffer()); +} + +template <> +std::unique_ptr<ImageFrame> GlTexture::GetFrame<ImageFrame>() const { + view_->DoneWriting(); + std::shared_ptr<const ImageFrame> view = + gpu_buffer_.GetReadView<ImageFrame>(); + auto copy = absl::make_unique<ImageFrame>(); + copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); + return copy; +} + +template <> +std::unique_ptr<GpuBuffer> GlTexture::GetFrame<GpuBuffer>() const { + view_->DoneWriting(); + return absl::make_unique<GpuBuffer>(gpu_buffer_); } template <> diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index e44523202..af897bbe9 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -17,6 +17,7 @@ #include +#include "absl/base/attributes.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" @@ -33,7 +34,6 @@ namespace mediapipe { -class GlCalculatorHelperImpl; class GlTexture; class GpuResources; struct GpuSharedData; @@ -62,6 +62,7 @@ class GlCalculatorHelper { // Can be used to initialize the helper outside of a calculator. Useful for // testing. void InitializeForTest(GpuResources* gpu_resources); + ABSL_DEPRECATED("Use InitializeForTest(GpuResources)") void InitializeForTest(GpuSharedData* gpu_shared); // This method can be called from GetContract to set up the needed GPU @@ -70,6 +71,7 @@ class GlCalculatorHelper { // This method can be called from FillExpectations to set the correct types // for the shared GL input side packet(s). + ABSL_DEPRECATED("Use UpdateContract") static absl::Status SetupInputSidePackets(PacketTypeSet* input_side_packets); // Execute the provided function within the helper's GL context. On some @@ -161,15 +163,30 @@ class GlCalculatorHelper { // TODO: do we need an unbind method too? void BindFramebuffer(const GlTexture& dst); - GlContext& GetGlContext() const; + GlContext& GetGlContext() const { return *gl_context_; } - GlVersion GetGlVersion() const; + GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } // Check if the calculator helper has been previously initialized.
- bool Initialized() { return impl_ != nullptr; } + bool Initialized() { return gpu_resources_ != nullptr; } private: - std::unique_ptr<GlCalculatorHelperImpl> impl_; + void InitializeInternal(CalculatorContext* cc, GpuResources* gpu_resources); + + absl::Status RunInGlContext(std::function<absl::Status(void)> gl_func, + CalculatorContext* calculator_context); + + // Makes a GpuBuffer accessible as a texture in the GL context. + GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); + + // Create the framebuffer for rendering. + void CreateFramebuffer(); + + std::shared_ptr<GlContext> gl_context_; + + GLuint framebuffer_ = 0; + + GpuResources* gpu_resources_ = nullptr; }; // Represents an OpenGL texture, and is a 'view' into the memory pool. @@ -201,9 +218,13 @@ class GlTexture { void Release() { view_ = std::make_shared<GlTextureView>(); } private: - explicit GlTexture(GlTextureView view) - : view_(std::make_shared<GlTextureView>(std::move(view))) {} - friend class GlCalculatorHelperImpl; + explicit GlTexture(GlTextureView view, GpuBuffer gpu_buffer) + : gpu_buffer_(std::move(gpu_buffer)), + view_(std::make_shared<GlTextureView>(std::move(view))) {} + friend class GlCalculatorHelper; + // We store the GpuBuffer to support GetFrame, and to ensure that the storage + // outlives the view. + GpuBuffer gpu_buffer_; std::shared_ptr<GlTextureView> view_; }; @@ -217,12 +238,14 @@ class GlTexture { // it is better to keep const-safety and accept having two versions of the // same thing. template <typename T> +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(const T& collection, const std::string& tag, int index) -> decltype(collection.Tag(tag)) { return collection.UsesTags() ? collection.Tag(tag) : collection.Index(index); } template <typename T> +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(T* collection, const std::string& tag, int index) -> decltype(collection->Tag(tag)) { return collection->UsesTags() ? collection->Tag(tag) @@ -230,12 +253,14 @@ auto TagOrIndex(T* collection, const std::string& tag, int index) } template <typename T> +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(const T& collection, const std::string& tag, int index) { return collection.UsesTags() ? collection.HasTag(tag) : index < collection.NumEntries(); } template <typename T> +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(T* collection, const std::string& tag, int index) { return collection->UsesTags() ? collection->HasTag(tag) : index < collection->NumEntries(); diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h deleted file mode 100644 index 72b3265fe..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -#ifndef MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ -#define MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ - -#include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" - -#ifdef __OBJC__ -#import -#import -#endif // __OBJC__ - -#ifdef __ANDROID__ -#include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif - -namespace mediapipe { - -// This class implements the GlCalculatorHelper for iOS and Android. -// See GlCalculatorHelper for details on these methods. -class GlCalculatorHelperImpl { - public: - explicit GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources); - ~GlCalculatorHelperImpl(); - - absl::Status RunInGlContext(std::function<absl::Status(void)> gl_func, - CalculatorContext* calculator_context); - - GlTexture CreateSourceTexture(const ImageFrame& image_frame); - GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer); - - // Note: multi-plane support is currently only available on iOS. - GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer, int plane); - - // Creates a framebuffer and returns the texture that it is bound to. - GlTexture CreateDestinationTexture(int output_width, int output_height, - GpuBufferFormat format); - - GpuBuffer GpuBufferWithImageFrame(std::shared_ptr<ImageFrame> image_frame); - GpuBuffer GpuBufferCopyingImageFrame(const ImageFrame& image_frame); - - GLuint framebuffer() const { return framebuffer_; } - void BindFramebuffer(const GlTexture& dst); - - GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } - - GlContext& GetGlContext() const; - - // For internal use. - static void ReadTexture(const GlTextureView& view, void* output, size_t size); - - private: - // Makes a GpuBuffer accessible as a texture in the GL context. - GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); - - // Create the framebuffer for rendering. - void CreateFramebuffer(); - - std::shared_ptr<GlContext> gl_context_; - - GLuint framebuffer_ = 0; - - GpuResources& gpu_resources_; -}; - -} // namespace mediapipe - -#endif // MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc deleted file mode 100644 index c5c028d4f..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" -#include "mediapipe/gpu/gpu_buffer_format.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -#include "mediapipe/gpu/image_frame_view.h" - -namespace mediapipe { - -GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources) - : gpu_resources_(*gpu_resources) { - gl_context_ = gpu_resources_.gl_context(cc); -} - -GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { - RunInGlContext( - [this] { - if (framebuffer_) { - glDeleteFramebuffers(1, &framebuffer_); - framebuffer_ = 0; - } - return absl::OkStatus(); - }, - /*calculator_context=*/nullptr) - .IgnoreError(); -} - -GlContext& GlCalculatorHelperImpl::GetGlContext() const { return *gl_context_; } - -absl::Status GlCalculatorHelperImpl::RunInGlContext( - std::function<absl::Status(void)> gl_func, - CalculatorContext* calculator_context) { - if (calculator_context) { - return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), - calculator_context->InputTimestamp()); - } else { - return gl_context_->Run(std::move(gl_func)); - } -} - -void GlCalculatorHelperImpl::CreateFramebuffer() { - // Our framebuffer will have a color attachment but no depth attachment, - // so it's important that the depth test be off. It is disabled by default, - // but we wanted to be explicit. - // TODO: move this to glBindFramebuffer? - glDisable(GL_DEPTH_TEST); - glGenFramebuffers(1, &framebuffer_); -} - -void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) { -#ifdef __ANDROID__ - // On (some?) Android devices, attaching a new texture to the frame buffer - // does not seem to detach the old one. As a result, using that texture - // for texturing can produce incorrect output. See b/32091368 for details. - // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 - // or glFramebufferTexture2D with a texture ID of 0. - glBindFramebuffer(GL_FRAMEBUFFER, 0); -#endif - if (!framebuffer_) { - CreateFramebuffer(); - } - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - glViewport(0, 0, dst.width(), dst.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), - dst.name(), 0); - -#ifndef NDEBUG - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); - if (status != GL_FRAMEBUFFER_COMPLETE) { - VLOG(2) << "incomplete framebuffer: " << status; - } -#endif -} - -GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, - GlTextureView view) { - if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { - // TODO: do the params need to be reset here??
- glBindTexture(view.target(), view.name()); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - gpu_buffer.format(), view.plane(), GetGlVersion()); - gl_context_->SetStandardTextureParams(view.target(), - info.gl_internal_format); - glBindTexture(view.target(), 0); - } - - return GlTexture(std::move(view)); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer) { - return CreateSourceTexture(gpu_buffer, 0); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer, int plane) { - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView<GlTextureView>(plane)); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const ImageFrame& image_frame) { - auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView<GlTextureView>(0)); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferWithImageFrame( - std::shared_ptr<ImageFrame> image_frame) { - return GpuBuffer( - std::make_shared<GpuBufferStorageImageFrame>(std::move(image_frame))); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferCopyingImageFrame( - const ImageFrame& image_frame) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); - // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only - // deals with absl::Status in MediaPipe OSS. - CHECK_OK(maybe_buffer.status()); - return GpuBuffer(std::move(maybe_buffer).value()); -#else - return GpuBuffer(GlTextureBuffer::Create(image_frame)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -} - -template <> -std::unique_ptr<ImageFrame> GlTexture::GetFrame<ImageFrame>() const { - view_->DoneWriting(); - std::shared_ptr<const ImageFrame> view = - view_->gpu_buffer().GetReadView<ImageFrame>(); - auto copy = absl::make_unique<ImageFrame>(); - copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); - return copy; -} - -template <> -std::unique_ptr<GpuBuffer> GlTexture::GetFrame<GpuBuffer>() const { - auto gpu_buffer = view_->gpu_buffer(); -#ifdef __EMSCRIPTEN__ - // When WebGL is used, the GL context may be spontaneously lost which can - // cause GpuBuffer allocations to fail. In that case, return a dummy buffer - // to allow processing of the current frame complete. - if (!gpu_buffer) { - return std::make_unique<GpuBuffer>(); - } -#endif // __EMSCRIPTEN__ - view_->DoneWriting(); - return absl::make_unique<GpuBuffer>(gpu_buffer); -} - -GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( - int width, int height, GpuBufferFormat format) { - if (!framebuffer_) { - CreateFramebuffer(); - } - - GpuBuffer gpu_buffer = - gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView<GlTextureView>(0)); -} - -} // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 7f7ba0e23..99b995dda 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -290,8 +290,15 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // some Emscripten cases), there might be some existing tripped error. ForceClearExistingGlErrors(); - absl::string_view version_string( - reinterpret_cast<const char*>(glGetString(GL_VERSION))); + absl::string_view version_string; + const GLubyte* version_string_ptr = glGetString(GL_VERSION); + if (version_string_ptr != nullptr) { + version_string = reinterpret_cast<const char*>(version_string_ptr); + } else { + // This may happen when using SwiftShader, but the numeric versions are + // available and will be used instead.
+ LOG(WARNING) << "failed to get GL_VERSION string"; + } // We will decide later whether we want to use the version numbers we query // for, or instead derive that information from the context creation result, @@ -333,7 +340,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { } LOG(INFO) << "GL version: " << gl_major_version_ << "." << gl_minor_version_ - << " (" << glGetString(GL_VERSION) << ")"; + << " (" << version_string << ")"; { auto status = GetGlExtensions(); if (!status.ok()) { @@ -826,10 +833,14 @@ std::shared_ptr GlContext::CreateSyncToken() { return token; } -bool GlContext::IsAnyContextCurrent() { +PlatformGlContext GlContext::GetCurrentNativeContext() { ContextBinding ctx; GetCurrentContextBinding(&ctx); - return ctx.context != kPlatformGlContextNone; + return ctx.context; +} + +bool GlContext::IsAnyContextCurrent() { + return GetCurrentNativeContext() != kPlatformGlContextNone; } std::shared_ptr @@ -1043,4 +1054,16 @@ void GlContext::SetStandardTextureParams(GLenum target, GLint internal_format) { glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); } +const GlContext::Attachment kUtilityFramebuffer( + [](GlContext&) -> GlContext::Attachment::Ptr { + GLuint framebuffer; + glGenFramebuffers(1, &framebuffer); + if (!framebuffer) return nullptr; + return {new GLuint(framebuffer), [](void* ptr) { + GLuint* fb = static_cast(ptr); + glDeleteFramebuffers(1, fb); + delete fb; + }}; + }); + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 957cb510f..4f2390404 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -307,6 +307,10 @@ class GlContext : public std::enable_shared_from_this { // the GlContext class, is current. static bool IsAnyContextCurrent(); + // Returns the current native context, whether managed by this class or not. + // Useful as a cross-platform way to get the current PlatformGlContext. + static PlatformGlContext GetCurrentNativeContext(); + // Creates a synchronization token for the current, non-GlContext-owned // context. This can be passed to MediaPipe so it can synchronize with the // commands issued in the external context up to this point. @@ -470,6 +474,12 @@ class GlContext : public std::enable_shared_from_this { bool destructing_ = false; }; +// A framebuffer that the framework can use to attach textures for rendering +// etc. +// This could just be a member of GlContext, but it serves as a basic example +// of an attachment. +ABSL_CONST_INIT extern const GlContext::Attachment kUtilityFramebuffer; + // For backward compatibility. TODO: migrate remaining callers. ABSL_DEPRECATED( "Prefer passing an explicit GlVersion argument (use " diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index 31500ed9a..ad867c2be 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -37,7 +37,6 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes }; // VIDEO or index 0: GpuBuffers to be rendered. // Side inputs: // SURFACE: unique_ptr to an EglSurfaceHolder to draw to. -// GPU_SHARED: shared GPU resources. // // See GlSurfaceSinkCalculatorOptions for options. 
class GlSurfaceSinkCalculator : public Node { diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index fbb91a8f5..69b9889c7 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -15,9 +15,15 @@ #include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/gl_texture_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + namespace mediapipe { std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap( @@ -250,39 +256,46 @@ void GlTextureBuffer::WaitForConsumersOnGpu() { // precisely, on only one GL context. } -GlTextureView GlTextureBuffer::GetReadView( - internal::types<GlTextureView>, std::shared_ptr<GpuBuffer> gpu_buffer, - int plane) const { +GlTextureView GlTextureBuffer::GetReadView(internal::types<GlTextureView>, + int plane) const { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); - GlTextureView::DetachFn detach = [this](GlTextureView& texture) { - // Inform the GlTextureBuffer that we have finished accessing its - // contents, and create a consumer sync point. - DidRead(texture.gl_context()->CreateSyncToken()); - }; + GlTextureView::DetachFn detach = + [texbuf = shared_from_this()](GlTextureView& texture) { + // Inform the GlTextureBuffer that we have finished accessing its - // contents, and create a consumer sync point. + texbuf->DidRead(texture.gl_context()->CreateSyncToken()); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, std::move(detach), - nullptr); + plane, std::move(detach), nullptr); } -GlTextureView GlTextureBuffer::GetWriteView( - internal::types<GlTextureView>, std::shared_ptr<GpuBuffer> gpu_buffer, - int plane) { +GlTextureView GlTextureBuffer::GetWriteView(internal::types<GlTextureView>, + int plane) { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); Reuse(); // TODO: the producer wait should probably be part of Reuse in the // case when there are no consumers.
GlTextureView::DoneWritingFn done_writing = - [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; + [texbuf = shared_from_this()](const GlTextureView& texture) { + texbuf->ViewDoneWriting(texture); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, nullptr, - std::move(done_writing)); + plane, nullptr, std::move(done_writing)); } void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { @@ -321,8 +334,8 @@ void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { #endif // __ANDROID__ } -static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, - void* output, size_t size) { +static void ReadTexture(GlContext& ctx, const GlTextureView& view, + GpuBufferFormat format, void* output, size_t size) { // TODO: check buffer size? We could use glReadnPixels where available // (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read // won't overflow the buffer with glReadPixels, we'd also need to check or @@ -332,13 +345,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, GlTextureInfo info = GlTextureInfoForGpuBufferFormat( format, view.plane(), view.gl_context()->GetGlVersion()); - GLint previous_fbo; - glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo); - - // We use a temp fbo to avoid depending on the app having an existing one. - // TODO: keep a utility fbo around in the context? - GLuint fbo = 0; - glGenFramebuffers(1, &fbo); + GLuint fbo = kUtilityFramebuffer.Get(ctx); glBindFramebuffer(GL_FRAMEBUFFER, fbo); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), view.name(), 0); @@ -346,9 +353,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format, output); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0, 0); - // TODO: just set the binding to 0 to avoid the get call? 
- glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo); - glDeleteFramebuffers(1, &fbo); + glBindFramebuffer(GL_FRAMEBUFFER, 0); } static std::shared_ptr ConvertToImageFrame( @@ -358,9 +363,11 @@ static std::shared_ptr ConvertToImageFrame( auto output = absl::make_unique(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); - buf->GetProducerContext()->Run([buf, &output] { - auto view = buf->GetReadView(internal::types{}, nullptr, 0); - ReadTexture(view, buf->format(), output->MutablePixelData(), + auto ctx = GlContext::GetCurrent(); + if (!ctx) ctx = buf->GetProducerContext(); + ctx->Run([buf, &output, &ctx] { + auto view = buf->GetReadView(internal::types{}, /*plane=*/0); + ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); return std::make_shared(std::move(output)); @@ -380,4 +387,30 @@ static auto kConverterRegistration2 = .RegisterConverter( ConvertFromImageFrame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +static std::shared_ptr ConvertToCvPixelBuffer( + std::shared_ptr buf) { + auto output = absl::make_unique( + buf->width(), buf->height(), buf->format()); + auto ctx = GlContext::GetCurrent(); + if (!ctx) ctx = buf->GetProducerContext(); + ctx->Run([buf, &output] { + TempGlFramebuffer framebuffer; + auto src = buf->GetReadView(internal::types{}, /*plane=*/0); + auto dst = + output->GetWriteView(internal::types{}, /*plane=*/0); + CopyGlTexture(src, dst); + glFlush(); + }); + return output; +} + +static auto kConverterRegistrationCvpb = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertToCvPixelBuffer); + +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 124a0ec2f..f785571a1 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -35,7 +35,8 @@ class GlCalculatorHelperImpl; // Implements a GPU memory buffer as an OpenGL texture. For internal use. class GlTextureBuffer : public internal::GpuBufferStorageImpl< - GlTextureBuffer, internal::ViewProvider> { + GlTextureBuffer, internal::ViewProvider>, + public std::enable_shared_from_this { public: // This is called when the texture buffer is deleted. It is passed a sync // token created at that time on the GlContext. If the GlTextureBuffer has @@ -71,6 +72,11 @@ class GlTextureBuffer // Create a texture with a copy of the data in image_frame. static std::unique_ptr Create(const ImageFrame& image_frame); + static std::unique_ptr Create( + const internal::GpuBufferSpec& spec) { + return Create(spec.width, spec.height, spec.format); + } + // Wraps an existing texture, but does not take ownership of it. // deletion_callback is invoked when the GlTextureBuffer is released, so // the caller knows that the texture is no longer in use. 
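For reference, a sketch of the new spec-based factory added above (safe to infer, since Create(spec) simply forwards to the width/height/format overload; assumes a current GL context, as texture allocation requires one):

// Sketch: the GpuBufferSpec overload is shorthand for the existing factory.
internal::GpuBufferSpec spec(640, 480, GpuBufferFormat::kBGRA32);
std::unique_ptr<GlTextureBuffer> buf = GlTextureBuffer::Create(spec);
// Equivalent to: GlTextureBuffer::Create(640, 480, GpuBufferFormat::kBGRA32).
// Note: per the new DCHECKs in GetReadView/GetWriteView, move the result into
// a shared_ptr before requesting a GlTextureView from it.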
@@ -90,10 +96,8 @@ class GlTextureBuffer GpuBufferFormat format() const { return format_; } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; // If this texture is going to be used outside of the context that produced @@ -138,6 +142,10 @@ class GlTextureBuffer return producer_context_; } +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + static constexpr bool kDisableGpuBufferRegistration = true; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + private: // Creates a texture of dimensions width x height and allocates space for it. // If data is provided, it is uploaded to the texture; otherwise, it can be diff --git a/mediapipe/gpu/gl_texture_buffer_pool.cc b/mediapipe/gpu/gl_texture_buffer_pool.cc index 3d5a8cdaa..599381a34 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.cc +++ b/mediapipe/gpu/gl_texture_buffer_pool.cc @@ -16,79 +16,4 @@ #include "absl/synchronization/mutex.h" -namespace mediapipe { - -GlTextureBufferPool::GlTextureBufferPool(int width, int height, - GpuBufferFormat format, int keep_count) - : width_(width), - height_(height), - format_(format), - keep_count_(keep_count) {} - -GlTextureBufferSharedPtr GlTextureBufferPool::GetBuffer() { - std::unique_ptr buffer; - bool reuse = false; - - { - absl::MutexLock lock(&mutex_); - if (available_.empty()) { - buffer = GlTextureBuffer::Create(width_, height_, format_); - if (!buffer) return nullptr; - } else { - buffer = std::move(available_.back()); - available_.pop_back(); - reuse = true; - } - - ++in_use_count_; - } - - // This needs to wait on consumer sync points, therefore it should not be - // done while holding the mutex. - if (reuse) { - buffer->Reuse(); - } - - // Return a shared_ptr with a custom deleter that adds the buffer back - // to our available list. - std::weak_ptr weak_pool(shared_from_this()); - return std::shared_ptr( - buffer.release(), [weak_pool](GlTextureBuffer* buf) { - auto pool = weak_pool.lock(); - if (pool) { - pool->Return(absl::WrapUnique(buf)); - } else { - delete buf; - } - }); -} - -std::pair GlTextureBufferPool::GetInUseAndAvailableCounts() { - absl::MutexLock lock(&mutex_); - return {in_use_count_, available_.size()}; -} - -void GlTextureBufferPool::Return(std::unique_ptr buf) { - std::vector> trimmed; - { - absl::MutexLock lock(&mutex_); - --in_use_count_; - available_.emplace_back(std::move(buf)); - TrimAvailable(&trimmed); - } - // The trimmed buffers will be released without holding the lock. 
-} - -void GlTextureBufferPool::TrimAvailable( - std::vector>* trimmed) { - int keep = std::max(keep_count_ - in_use_count_, 0); - if (available_.size() > keep) { - auto trim_it = std::next(available_.begin(), keep); - if (trimmed) { - std::move(trim_it, available_.end(), std::back_inserter(*trimmed)); - } - available_.erase(trim_it, available_.end()); - } -} - -} // namespace mediapipe +namespace mediapipe {} // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index 4dcad305e..726d0528d 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -23,11 +23,12 @@ #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gl_texture_buffer.h" +#include "mediapipe/gpu/multi_pool.h" +#include "mediapipe/gpu/reusable_pool.h" namespace mediapipe { -class GlTextureBufferPool - : public std::enable_shared_from_this { +class GlTextureBufferPool : public ReusablePool { public: // Creates a pool. This pool will manage buffers of the specified dimensions, // and will keep keep_count buffers around for reuse. @@ -36,42 +37,32 @@ class GlTextureBufferPool static std::shared_ptr Create(int width, int height, GpuBufferFormat format, int keep_count) { - return std::shared_ptr( - new GlTextureBufferPool(width, height, format, keep_count)); + return Create({width, height, format}, {.keep_count = keep_count}); } - // Obtains a buffers. May either be reused or created anew. - // A GlContext must be current when this is called. - GlTextureBufferSharedPtr GetBuffer(); + static std::shared_ptr Create( + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options) { + return std::shared_ptr( + new GlTextureBufferPool(spec, options)); + } - int width() const { return width_; } - int height() const { return height_; } - GpuBufferFormat format() const { return format_; } + int width() const { return spec_.width; } + int height() const { return spec_.height; } + GpuBufferFormat format() const { return spec_.format; } - // This method is meant for testing. - std::pair GetInUseAndAvailableCounts(); + static GlTextureBufferSharedPtr CreateBufferWithoutPool( + const internal::GpuBufferSpec& spec) { + return GlTextureBuffer::Create(spec); + } - private: - GlTextureBufferPool(int width, int height, GpuBufferFormat format, - int keep_count); + protected: + GlTextureBufferPool(const internal::GpuBufferSpec& spec, + const MultiPoolOptions& options) + : ReusablePool( + [this] { return GlTextureBuffer::Create(spec_); }, options), + spec_(spec) {} - // Return a buffer to the pool. - void Return(std::unique_ptr buf); - - // If the total number of buffers is greater than keep_count, destroys any - // surplus buffers that are no longer in use. 
- void TrimAvailable(std::vector>* trimmed) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - const int width_; - const int height_; - const GpuBufferFormat format_; - const int keep_count_; - - absl::Mutex mutex_; - int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0; - std::vector> available_ - ABSL_GUARDED_BY(mutex_); + const internal::GpuBufferSpec spec_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_util.cc b/mediapipe/gpu/gl_texture_util.cc new file mode 100644 index 000000000..603e82a46 --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.cc @@ -0,0 +1,30 @@ +#include "mediapipe/gpu/gl_texture_util.h" + +namespace mediapipe { + +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { + glViewport(0, 0, src.width(), src.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), + src.name(), 0); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(dst.target(), dst.name()); + glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); + + glBindTexture(dst.target(), 0); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, + 0); +} + +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, + float a) { + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glClearColor(r, g, b, a); + glClear(GL_COLOR_BUFFER_BIT); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, + 0); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_util.h b/mediapipe/gpu/gl_texture_util.h new file mode 100644 index 000000000..73ac37ade --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.h @@ -0,0 +1,34 @@ +#ifndef MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ +#define MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ + +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_texture_view.h" + +namespace mediapipe { + +// Copies a texture to another. +// Assumes a framebuffer is already set up +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst); + +// Fills a texture with a color. +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, float a); + +// RAII class to set up a temporary framebuffer. Mainly for test use. 
+class TempGlFramebuffer { + public: + TempGlFramebuffer() { + glGenFramebuffers(1, &framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + } + ~TempGlFramebuffer() { + glBindFramebuffer(GL_FRAMEBUFFER, 0); + glDeleteFramebuffers(1, &framebuffer_); + } + + private: + GLuint framebuffer_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ diff --git a/mediapipe/gpu/gl_texture_view.cc b/mediapipe/gpu/gl_texture_view.cc index 5d1862ddc..cae4039a4 100644 --- a/mediapipe/gpu/gl_texture_view.cc +++ b/mediapipe/gpu/gl_texture_view.cc @@ -7,7 +7,6 @@ void GlTextureView::Release() { if (detach_) detach_(*this); detach_ = nullptr; gl_context_ = nullptr; - gpu_buffer_ = nullptr; plane_ = 0; name_ = 0; width_ = 0; diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 8b47d620b..8a257cf53 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -25,8 +25,6 @@ namespace mediapipe { class GlContext; -class GlTextureViewManager; -class GpuBuffer; class GlTextureView { public: @@ -43,7 +41,6 @@ class GlTextureView { name_ = other.name_; width_ = other.width_; height_ = other.height_; - gpu_buffer_ = std::move(other.gpu_buffer_); plane_ = other.plane_; detach_ = std::exchange(other.detach_, nullptr); done_writing_ = std::exchange(other.done_writing_, nullptr); @@ -55,26 +52,23 @@ class GlTextureView { int height() const { return height_; } GLenum target() const { return target_; } GLuint name() const { return name_; } - const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; } int plane() const { return plane_; } using DetachFn = std::function<void(GlTextureView&)>; using DoneWritingFn = std::function<void(const GlTextureView&)>; private: - friend class GpuBuffer; friend class GlTextureBuffer; friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageAhwb; GlTextureView(GlContext* context, GLenum target, GLuint name, int width, - int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane, - DetachFn detach, DoneWritingFn done_writing) + int height, int plane, DetachFn detach, + DoneWritingFn done_writing) : gl_context_(context), target_(target), name_(name), width_(width), height_(height), - gpu_buffer_(std::move(gpu_buffer)), plane_(plane), detach_(std::move(detach)), done_writing_(std::move(done_writing)) {} @@ -93,7 +87,6 @@ class GlTextureView { // Note: when scale is not 1, we still give the nominal size of the image. int width_ = 0; int height_ = 0; - std::shared_ptr<GpuBuffer> gpu_buffer_; // using shared_ptr temporarily int plane_ = 0; DetachFn detach_; mutable DoneWritingFn done_writing_; @@ -112,12 +105,8 @@ class ViewProvider<GlTextureView> { // the same view implement the same signature. // Note that we allow different views to have custom signatures, providing // additional view-specific arguments that may be needed.
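To make the signature change in the following hunks concrete, a usage sketch from the GpuBuffer side (hedged: it relies on the templated GetReadView shown later in gpu_buffer.h, assumes a current GL context, and assumes the buffer was previously written, as in the tests below):

GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32);
auto view = buffer.GetReadView<GlTextureView>(/*plane=*/0);
// The view no longer retains the GpuBuffer; the detach/done-writing callbacks
// now capture a shared_ptr to the GlTextureBuffer storage instead, so
// releasing the view still creates the consumer sync token safely.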
- virtual GlTextureView GetReadView(types, - std::shared_ptr gpu_buffer, - int plane) const = 0; - virtual GlTextureView GetWriteView(types, - std::shared_ptr gpu_buffer, - int plane) = 0; + virtual GlTextureView GetReadView(types, int plane) const = 0; + virtual GlTextureView GetWriteView(types, int plane) = 0; }; } // namespace internal diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc index e570ce8ba..388960b11 100644 --- a/mediapipe/gpu/gpu_buffer.cc +++ b/mediapipe/gpu/gpu_buffer.cc @@ -1,6 +1,7 @@ #include "mediapipe/gpu/gpu_buffer.h" #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const { "]"); } -internal::GpuBufferStorage& GpuBuffer::GetStorageForView( +internal::GpuBufferStorage* GpuBuffer::GetStorageForView( TypeId view_provider_type, bool for_writing) const { const std::shared_ptr* chosen_storage = nullptr; @@ -45,45 +46,58 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView( // TODO: choose best conversion. if (!chosen_storage) { for (const auto& s : storages_) { - auto converter = internal::GpuBufferStorageRegistry::Get() - .StorageConverterForViewProvider(view_provider_type, - s->storage_type()); - if (converter) { - storages_.push_back(converter(s)); - chosen_storage = &storages_.back(); + if (auto converter = internal::GpuBufferStorageRegistry::Get() + .StorageConverterForViewProvider( + view_provider_type, s->storage_type())) { + if (auto new_storage = converter(s)) { + storages_.push_back(new_storage); + chosen_storage = &storages_.back(); + break; + } } } } if (for_writing) { - if (!chosen_storage) { - // Allocate a new storage supporting the requested view. - auto factory = internal::GpuBufferStorageRegistry::Get() - .StorageFactoryForViewProvider(view_provider_type); - if (factory) { - storages_ = {factory(width(), height(), format())}; - chosen_storage = &storages_.back(); - } - } else { + if (chosen_storage) { // Discard all other storages. storages_ = {*chosen_storage}; chosen_storage = &storages_.back(); + } else { + // Allocate a new storage supporting the requested view. + if (auto factory = + internal::GpuBufferStorageRegistry::Get() + .StorageFactoryForViewProvider(view_provider_type)) { + if (auto new_storage = factory(width(), height(), format())) { + storages_ = {std::move(new_storage)}; + chosen_storage = &storages_.back(); + } + } } } + return chosen_storage ? 
chosen_storage->get() : nullptr; +} +internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie( + TypeId view_provider_type, bool for_writing) const { + auto* chosen_storage = + GpuBuffer::GetStorageForView(view_provider_type, for_writing); CHECK(chosen_storage) << "no view provider found for requested view " << view_provider_type.name() << "; storages available: " << absl::StrJoin(storages_, ", ", StorageTypeFormatter()); - DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); - return **chosen_storage; + DCHECK(chosen_storage->can_down_cast_to(view_provider_type)); + return *chosen_storage; } #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer) { - auto p = buffer.internal_storage(); - if (p) return **p; + if (buffer.GetStorageForView( + kTypeId>, + /*for_writing=*/false) != nullptr) { + return *buffer.GetReadView(); + } return nullptr; } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 57e077151..56507d92f 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -105,18 +105,16 @@ class GpuBuffer { // specific view type; see the corresponding ViewProvider. template decltype(auto) GetReadView(Args... args) const { - return GetViewProvider(false)->GetReadView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + return GetViewProviderOrDie(false).GetReadView( + internal::types{}, std::forward(args)...); } // Gets a write view of the specified type. The arguments depend on the // specific view type; see the corresponding ViewProvider. template decltype(auto) GetWriteView(Args... args) { - return GetViewProvider(true)->GetWriteView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + return GetViewProviderOrDie(true).GetWriteView( + internal::types{}, std::forward(args)...); } // Attempts to access an underlying storage object of the specified type. @@ -147,13 +145,17 @@ class GpuBuffer { GpuBufferFormat format_ = GpuBufferFormat::kUnknown; }; - internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type, + internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type, bool for_writing) const; + internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type, + bool for_writing) const; + template - internal::ViewProvider* GetViewProvider(bool for_writing) const { + internal::ViewProvider& GetViewProviderOrDie(bool for_writing) const { using VP = internal::ViewProvider; - return GetStorageForView(kTypeId, for_writing).template down_cast(); + return *GetStorageForViewOrDie(kTypeId, for_writing) + .template down_cast(); } std::shared_ptr& no_storage() const { @@ -175,6 +177,10 @@ class GpuBuffer { // This is mutable because view methods that do not change the contents may // still need to allocate new storages. 
mutable std::vector> storages_; + +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; inline bool GpuBuffer::operator==(std::nullptr_t other) const { diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 45f054d31..06c5a0439 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -153,6 +153,34 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) { #endif // __APPLE__ +namespace internal { + +struct GpuBufferSpec { + GpuBufferSpec(int w, int h, GpuBufferFormat f) + : width(w), height(h), format(f) {} + + template + friend H AbslHashValue(H h, const GpuBufferSpec& spec) { + return H::combine(std::move(h), spec.width, spec.height, + static_cast(spec.format)); + } + + int width; + int height; + GpuBufferFormat format; +}; + +// BufferSpec equality operators +inline bool operator==(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return lhs.width == rhs.width && lhs.height == rhs.height && + lhs.format == rhs.format; +} +inline bool operator!=(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace internal + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_FORMAT_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 6e4fd38ea..e2ed523e4 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -16,204 +16,7 @@ #include -#include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -#include "CoreFoundation/CFBase.h" -#include "mediapipe/objc/CFHolder.h" -#include "mediapipe/objc/util.h" -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -namespace mediapipe { - -// Keep this many buffers allocated for a given frame size. -static constexpr int kKeepCount = 2; -// The maximum size of the GpuBufferMultiPool. When the limit is reached, the -// oldest BufferSpec will be dropped. -static constexpr int kMaxPoolCount = 10; -// Time in seconds after which an inactive buffer can be dropped from the pool. -// Currently only used with CVPixelBufferPool. -static constexpr float kMaxInactiveBufferAge = 0.25; -// Skip allocating a buffer pool until at least this many requests have been -// made for a given BufferSpec. -static constexpr int kMinRequestsBeforePool = 2; -// Do a deeper flush every this many requests. 
-static constexpr int kRequestCountScrubInterval = 50; - -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( - const GpuBufferMultiPool::BufferSpec& spec, CFTimeInterval maxAge) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - pool_ = MakeCFHolderAdopting( - /* keep count is 0 because the age param keeps buffers around anyway */ - CreateCVPixelBufferPool(spec.width, spec.height, cv_format, 0, maxAge)); -} - -GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { - CVPixelBufferRef buffer; - int threshold = 1; - NSMutableDictionary* auxAttributes = - [NSMutableDictionary dictionaryWithCapacity:1]; - CVReturn err; - bool tried_flushing = false; - while (1) { - auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); - err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( - kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, - &buffer); - if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush && !tried_flushing) { - // Call the flush function to potentially release old holds on buffers - // and try again to create a pixel buffer. - // This is used to flush CV texture caches, which may retain buffers until - // flushed. - flush(); - tried_flushing = true; - } else { - ++threshold; - } - } - CHECK(!err) << "Error creating pixel buffer: " << err; - count_ = threshold; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - -std::string CvPixelBufferPoolWrapper::GetDebugString() const { - auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); - return [(__bridge NSString*)*description UTF8String]; -} - -void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } - -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec) { - return std::make_shared(spec, - kMaxInactiveBufferAge); -} - -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - CVPixelBufferRef buffer; - CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, - cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - -void GpuBufferMultiPool::FlushTextureCaches() { - absl::MutexLock lock(&mutex_); - for (const auto& cache : texture_caches_) { -#if TARGET_OS_OSX - CVOpenGLTextureCacheFlush(*cache, 0); -#else - CVOpenGLESTextureCacheFlush(*cache, 0); -#endif // TARGET_OS_OSX - } -} - -// Turning this on disables the pixel buffer pools when using the simulator. -// It is no longer necessary, since the helper code now supports non-contiguous -// buffers. We leave the code in for now for the sake of documentation. -#define FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR 0 - -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { -#if TARGET_IPHONE_SIMULATOR && FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR - // On the simulator, syncing the texture with the pixelbuffer does not work, - // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not - // available in OpenGL ES 2, we should create the buffer so the pixels are - // contiguous. - // - // TODO: verify if we can use kIOSurfaceBytesPerRow to force the - // pool to give us contiguous data. 
- return GetBufferWithoutPool(spec); -#else - return pool->GetBuffer([this]() { FlushTextureCaches(); }); -#endif // TARGET_IPHONE_SIMULATOR -} - -#else - -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const BufferSpec& spec) { - return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - kKeepCount); -} - -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - return GpuBuffer( - GlTextureBuffer::Create(spec.width, spec.height, spec.format)); -} - -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool->GetBuffer()); -} - -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::RequestPool( - const BufferSpec& spec) { - SimplePool pool; - std::vector evicted; - { - absl::MutexLock lock(&mutex_); - pool = - cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { - return (request_count >= kMinRequestsBeforePool) - ? MakeSimplePool(spec) - : nullptr; - }); - evicted = cache_.Evict(kMaxPoolCount, kRequestCountScrubInterval); - } - // Evicted pools, and their buffers, will be released without holding the - // lock. - return pool; -} - -GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, - GpuBufferFormat format) { - BufferSpec key(width, height, format); - SimplePool pool = RequestPool(key); - if (pool) { - // Note: we release our multipool lock before accessing the simple pool. - return GetBufferFromSimplePool(key, pool); - } else { - return GetBufferWithoutPool(key); - } -} - -GpuBufferMultiPool::~GpuBufferMultiPool() { -#ifdef __APPLE__ - CHECK_EQ(texture_caches_.size(), 0) - << "Failed to unregister texture caches before deleting pool"; -#endif // defined(__APPLE__) -} - -#ifdef __APPLE__ -void GpuBufferMultiPool::RegisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) - << "Attempting to register a texture cache twice"; - texture_caches_.emplace_back(cache); -} - -void GpuBufferMultiPool::UnregisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); - CHECK(it != texture_caches_.end()) - << "Attempting to unregister an unknown texture cache"; - texture_caches_.erase(it); -} -#endif // defined(__APPLE__) - -} // namespace mediapipe +namespace mediapipe {} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 5ea6e314f..827cf514a 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -22,120 +22,35 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ -#include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" -#include "mediapipe/util/resource_cache.h" +#include "mediapipe/gpu/multi_pool.h" -#ifdef __APPLE__ -#include "mediapipe/gpu/pixel_buffer_pool_util.h" -#endif // __APPLE__ - -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" +#else #include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER namespace mediapipe { -struct GpuSharedData; class 
CvPixelBufferPoolWrapper; -class GpuBufferMultiPool { - public: - GpuBufferMultiPool() {} - explicit GpuBufferMultiPool(void* ignored) {} - ~GpuBufferMultiPool(); - - // Obtains a buffer. May either be reused or created anew. - GpuBuffer GetBuffer(int width, int height, - GpuBufferFormat format = GpuBufferFormat::kBGRA32); - -#ifdef __APPLE__ - // TODO: add tests for the texture cache registration. - - // Inform the pool of a cache that should be flushed when it is low on - // reusable buffers. - void RegisterTextureCache(CVTextureCacheType cache); - - // Remove a texture cache from the list of caches to be flushed. - void UnregisterTextureCache(CVTextureCacheType cache); - - void FlushTextureCaches(); -#endif // defined(__APPLE__) - - // This class is not intended as part of the public api of this class. It is - // public only because it is used as a map key type, and the map - // implementation needs access to, e.g., the equality operator. - struct BufferSpec { - BufferSpec(int w, int h, mediapipe::GpuBufferFormat f) - : width(w), height(h), format(f) {} - - template - friend H AbslHashValue(H h, const BufferSpec& spec) { - return H::combine(std::move(h), spec.width, spec.height, - static_cast(spec.format)); - } - - int width; - int height; - mediapipe::GpuBufferFormat format; - }; - - private: +class GpuBufferMultiPool : public MultiPool< #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - using SimplePool = std::shared_ptr; + CvPixelBufferPoolWrapper, #else - using SimplePool = std::shared_ptr; + GlTextureBufferPool, #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - - SimplePool MakeSimplePool(const BufferSpec& spec); - // Requests a simple buffer pool for the given spec. This may return nullptr - // if we have not yet reached a sufficient number of requests to allocate a - // pool, in which case the caller should invoke GetBufferWithoutPool instead - // of GetBufferFromSimplePool. - SimplePool RequestPool(const BufferSpec& spec); - GpuBuffer GetBufferFromSimplePool(BufferSpec spec, const SimplePool& pool); - GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); - - absl::Mutex mutex_; - mediapipe::ResourceCache> - cache_ ABSL_GUARDED_BY(mutex_); - -#ifdef __APPLE__ - // Texture caches used with this pool. 
- std::vector> texture_caches_ - ABSL_GUARDED_BY(mutex_); -#endif // defined(__APPLE__) -}; - -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -class CvPixelBufferPoolWrapper { + internal::GpuBufferSpec, GpuBuffer> { public: - CvPixelBufferPoolWrapper(const GpuBufferMultiPool::BufferSpec& spec, - CFTimeInterval maxAge); - GpuBuffer GetBuffer(std::function flush); + using MultiPool::MultiPool; - int GetBufferCount() const { return count_; } - std::string GetDebugString() const; - - void Flush(); - - private: - CFHolder pool_; - int count_ = 0; + GpuBuffer GetBuffer(int width, int height, + GpuBufferFormat format = GpuBufferFormat::kBGRA32) { + return Get(internal::GpuBufferSpec(width, height, format)); + } }; -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -// BufferSpec equality operators -inline bool operator==(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return lhs.width == rhs.width && lhs.height == rhs.height && - lhs.format == rhs.format; -} -inline bool operator!=(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return !operator==(lhs, rhs); -} } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 3d872eb66..19661d930 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -13,22 +13,57 @@ #include "mediapipe/gpu/gpu_buffer_format.h" namespace mediapipe { -class GpuBuffer; namespace internal { template struct types {}; +// This template must be specialized for each view type V. Each specialization +// should define a pair of virtual methods called GetReadView and GetWriteView, +// whose first argument is a types tag object. The result type and optional +// further arguments will depend on the view type. +// +// Example: +// template <> +// class ViewProvider { +// public: +// virtual ~ViewProvider() = default; +// virtual MyView GetReadView(types) const = 0; +// virtual MyView GetWriteView(types) = 0; +// }; +// +// The additional arguments and result type are reflected in GpuBuffer's +// GetReadView and GetWriteView methods. +// +// Using a type tag for the first argument allows the methods to be overloaded, +// so that a single storage can implement provider methods for multiple views. +// Since these methods are not template methods, they can (and should) be +// virtual, which allows storage classes to override them, enforcing that all +// storages providing a given view type implement the same interface. template class ViewProvider; -// Interface for a backing storage for GpuBuffer. +// Generic interface for a backing storage for GpuBuffer. +// +// GpuBuffer is an opaque handle to an image. Its contents are handled by +// Storage classes. Application code does not interact with the storages +// directly; to access the data, it asks the GpuBuffer for a View, and in turn +// GpuBuffer looks for a storage that can provide that view. +// This architecture decouples application code from the underlying storage, +// making it possible to use platform-specific optimized storage systems, e.g. +// for zero-copy data sharing between CPU and GPU. +// +// Storage implementations should inherit from GpuBufferStorageImpl. See that +// class for details. class GpuBufferStorage { public: virtual ~GpuBufferStorage() = default; + + // Concrete storage types should override the following three accessors. 
virtual int width() const = 0; virtual int height() const = 0; virtual GpuBufferFormat format() const = 0; + // We can't use dynamic_cast since we want to support building without RTTI. // The public methods delegate to the type-erased private virtual method. template @@ -72,19 +107,33 @@ class GpuBufferStorageRegistry { return *registry; } + // Registers a storage type by automatically creating a factory for it. + // This is normally called by GpuBufferImpl. template RegistryToken Register() { - return Register( + return RegisterFactory( [](int width, int height, GpuBufferFormat format) -> std::shared_ptr { return CreateStorage(overload_priority<10>{}, width, height, format); - }, - Storage::GetProviderTypes()); + }); } + // Registers a new factory for a storage type. + template + RegistryToken RegisterFactory(F&& factory) { + if constexpr (kDisableRegistration) { + return {}; + } + return Register(factory, Storage::GetProviderTypes()); + } + + // Registers a new converter from storage type StorageFrom to StorageTo. template RegistryToken RegisterConverter(F&& converter) { + if constexpr (kDisableRegistration) { + return {}; + } return Register( [converter](std::shared_ptr source) -> std::shared_ptr { @@ -115,6 +164,13 @@ class GpuBufferStorageRegistry { return std::make_shared(args...); } + // Temporary workaround: a Storage class can define a static constexpr + // kDisableGpuBufferRegistration member to true to prevent registering any + // factory of converter that would produce it. + // TODO: better solution for storage priorities. + template + static constexpr bool kDisableRegistration = false; + RegistryToken Register(StorageFactory factory, std::vector provider_hashes); RegistryToken Register(StorageConverter converter, @@ -126,6 +182,13 @@ class GpuBufferStorageRegistry { converter_for_view_provider_and_existing_storage_; }; +// Putting this outside the class body to work around a GCC bug. +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=71954 +template +constexpr bool GpuBufferStorageRegistry::kDisableRegistration< + Storage, std::void_t> = + Storage::kDisableGpuBufferRegistration; + // Defining a member of this type causes P to be ODR-used, which forces its // instantiation if it's a static member of a template. template @@ -138,21 +201,41 @@ struct ForceStaticInstantiation { #endif // _MSC_VER }; -// T: storage type -// U...: ViewProvider +// Inherit from this class to define a new storage type. The storage type itself +// should be passed as the first template argument (CRTP), followed by one or +// more specializations of ViewProvider. +// +// Concrete storage types should implement the basic accessors from +// GpuBufferStorage, plus the view read/write getters for each ViewProvider they +// implement. This class handles the rest. +// +// Arguments: +// T: storage type +// U...: ViewProvider +// Example: +// class MyStorage : public GpuBufferStorageImpl< +// MyStorage, ViewProvider> template class GpuBufferStorageImpl : public GpuBufferStorage, public U... { public: static const std::vector& GetProviderTypes() { - static std::vector kHashes{kTypeId...}; - return kHashes; + static std::vector kProviderIds{kTypeId...}; + return kProviderIds; + } + + // Exposing this as a function allows dependent initializers to call this to + // ensure proper ordering. 
+ static GpuBufferStorageRegistry::RegistryToken RegisterOnce() { + static auto registration = GpuBufferStorageRegistry::Get().Register(); + return registration; } private: - virtual const void* down_cast(TypeId to) const override { + // Allows a down_cast to any of the view provider types in U. + const void* down_cast(TypeId to) const final { return down_cast_impl(to, types{}); } - TypeId storage_type() const override { return kTypeId; } + TypeId storage_type() const final { return kTypeId; } const void* down_cast_impl(TypeId to, types<>) const { return nullptr; } template @@ -161,8 +244,7 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... { return down_cast_impl(to, types{}); } - inline static auto registration = - GpuBufferStorageRegistry::Get().Register(); + inline static auto registration = RegisterOnce(); using RequireStatics = ForceStaticInstantiation<®istration>; }; diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index d68ac0db0..014cc1c69 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -26,8 +26,7 @@ GpuBufferStorageCvPixelBuffer::GpuBufferStorageCvPixelBuffer( } GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( - std::shared_ptr gpu_buffer, int plane, - GlTextureView::DoneWritingFn done_writing) const { + int plane, GlTextureView::DoneWritingFn done_writing) const { CVReturn err; auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); @@ -60,39 +59,20 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( cv_texture.adopt(cv_texture_temp); return GlTextureView( gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture), - CVOpenGLESTextureGetName(*cv_texture), width(), height(), - std::move(gpu_buffer), plane, + CVOpenGLESTextureGetName(*cv_texture), width(), height(), plane, [cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ }, done_writing); #endif // TARGET_OS_OSX } GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { - return GetTexture(std::move(gpu_buffer), plane, nullptr); + internal::types, int plane) const { + return GetTexture(plane, nullptr); } -GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { - return GetTexture( - std::move(gpu_buffer), plane, - [this](const mediapipe::GlTextureView& view) { ViewDoneWriting(view); }); -} - -std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer) const { - return CreateImageFrameForCVPixelBuffer(**this); -} -std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer) { - return CreateImageFrameForCVPixelBuffer(**this); -} - -void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { #if TARGET_IPHONE_SIMULATOR - CVPixelBufferRef pixel_buffer = **this; +static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, + const GlTextureView& view) { CHECK(pixel_buffer); CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); CHECK(err == kCVReturnSuccess) @@ -130,7 +110,30 @@ void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) { err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); CHECK(err == kCVReturnSuccess) << "CVPixelBufferUnlockBaseAddress failed: " << err; -#endif +} +#endif // TARGET_IPHONE_SIMULATOR + +GlTextureView 
GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types, int plane) { + return GetTexture(plane, +#if TARGET_IPHONE_SIMULATOR + [pixel_buffer = CFHolder(*this)]( + const mediapipe::GlTextureView& view) { + ViewDoneWritingSimulatorWorkaround(*pixel_buffer, view); + } +#else + nullptr +#endif // TARGET_IPHONE_SIMULATOR + ); +} + +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types) const { + return CreateImageFrameForCVPixelBuffer(**this); +} +std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types) { + return CreateImageFrameForCVPixelBuffer(**this); } static std::shared_ptr ConvertFromImageFrame( diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index 017771dc7..8723a1087 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -12,10 +12,25 @@ namespace mediapipe { class GlContext; +namespace internal { + +template <> +class ViewProvider { + public: + virtual ~ViewProvider() = default; + virtual CFHolder GetReadView( + internal::types) const = 0; + virtual CFHolder GetWriteView( + internal::types) = 0; +}; + +} // namespace internal + class GpuBufferStorageCvPixelBuffer : public internal::GpuBufferStorageImpl< GpuBufferStorageCvPixelBuffer, internal::ViewProvider, - internal::ViewProvider>, + internal::ViewProvider, + internal::ViewProvider>, public CFHolder { public: using CFHolder::CFHolder; @@ -33,24 +48,32 @@ class GpuBufferStorageCvPixelBuffer CVPixelBufferGetPixelFormatType(**this)); } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override; + internal::types) const override; std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override; + internal::types) override; + CFHolder GetReadView( + internal::types) const override; + CFHolder GetWriteView( + internal::types) override; private: - GlTextureView GetTexture(std::shared_ptr gpu_buffer, int plane, + GlTextureView GetTexture(int plane, GlTextureView::DoneWritingFn done_writing) const; - void ViewDoneWriting(const GlTextureView& view); }; +inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( + internal::types) const { + return *this; +} +inline CFHolder GpuBufferStorageCvPixelBuffer::GetWriteView( + internal::types) { + return *this; +} + namespace internal { // These functions enable backward-compatible construction of a GpuBuffer from // CVPixelBufferRef without having to expose that type in the main GpuBuffer diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index 2cea3445e..ab547b9ea 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -29,13 +29,11 @@ class GpuBufferStorageImageFrame std::shared_ptr image_frame() const { return image_frame_; } std::shared_ptr image_frame() { return image_frame_; } std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override { + internal::types) const override { return image_frame_; } std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override { + internal::types) override { return image_frame_; } diff --git a/mediapipe/gpu/gpu_buffer_test.cc 
b/mediapipe/gpu/gpu_buffer_test.cc index 3fd519b21..145b71806 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -14,10 +14,13 @@ #include "mediapipe/gpu/gpu_buffer.h" +#include + #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" @@ -41,47 +44,6 @@ void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) { } } -// Assumes a framebuffer is already set up -void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { - glViewport(0, 0, src.width(), src.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), - src.name(), 0); - - glActiveTexture(GL_TEXTURE0); - glBindTexture(dst.target(), dst.name()); - glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); - - glBindTexture(dst.target(), 0); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, - 0); -} - -void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, - float a) { - glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); - glClearColor(r, g, b, a); - glClear(GL_COLOR_BUFFER_BIT); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, - 0); -} - -class TempGlFramebuffer { - public: - TempGlFramebuffer() { - glGenFramebuffers(1, &framebuffer_); - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - } - ~TempGlFramebuffer() { - glBindFramebuffer(GL_FRAMEBUFFER, 0); - glDeleteFramebuffers(1, &framebuffer_); - } - - private: - GLuint framebuffer_; -}; - class GpuBufferTest : public GpuTestBase {}; TEST_F(GpuBufferTest, BasicTest) { @@ -127,7 +89,7 @@ TEST_F(GpuBufferTest, GlTextureView) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view")); } @@ -162,7 +124,7 @@ TEST_F(GpuBufferTest, ImageFrame) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view")); } @@ -196,7 +158,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view")); } @@ -230,7 +192,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame green(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(green, 0, 255, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, green, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view")); } @@ 
-240,11 +202,31 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame blue(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(blue, 0, 0, 255, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, blue, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view")); } } +TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr view = buffer.GetWriteView(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + RunInGlContext([buffer = std::move(buffer)]() mutable { + // This is not a recommended pattern, but let's make sure that we don't + // crash if the buffer is released before the view. The view can hold + // callbacks into its underlying storage. + auto view = buffer.GetReadView(0); + buffer = nullptr; + }); + // We're really checking that we haven't crashed. + EXPECT_TRUE(true); +} + } // anonymous namespace } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index a8bf0c3a3..203a8dfd1 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -21,7 +21,7 @@ #include "mediapipe/gpu/graph_support.h" #if __APPLE__ -#import "mediapipe/gpu/MPPGraphGPUData.h" +#include "mediapipe/gpu/metal_shared_resources.h" #endif // __APPLE__ namespace mediapipe { @@ -80,28 +80,40 @@ GpuResources::StatusOrGpuResources GpuResources::Create( return gpu_resources; } -GpuResources::GpuResources(std::shared_ptr gl_context) { +GpuResources::GpuResources(std::shared_ptr gl_context) +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + : texture_caches_(std::make_shared()), + gpu_buffer_pool_( + [tc = texture_caches_](const internal::GpuBufferSpec& spec, + const MultiPoolOptions& options) { + return CvPixelBufferPoolWrapper::Create(spec, options, tc.get()); + }) +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +{ gl_key_context_[SharedContextKey()] = gl_context; named_executors_[kGpuExecutorName] = std::make_shared(gl_context.get()); #if __APPLE__ - gpu_buffer_pool().RegisterTextureCache(gl_context->cv_texture_cache()); - ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() - multiPool:&gpu_buffer_pool_]; +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + metal_shared_ = std::make_unique(); #endif // __APPLE__ } GpuResources::~GpuResources() { #if __APPLE__ - // Note: on Apple platforms, this object contains Objective-C objects. The - // destructor will release them, but ARC must be on. + // Note: on Apple platforms, this object contains Objective-C objects. + // The destructor will release them, but ARC must be on. #if !__has_feature(objc_arc) #error This file must be built with ARC. 
#endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER for (auto& kv : gl_key_context_) { - gpu_buffer_pool().UnregisterTextureCache(kv.second->cv_texture_cache()); + texture_caches_->UnregisterTextureCache(kv.second->cv_texture_cache()); } -#endif +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // __APPLE__ } absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { @@ -174,17 +186,43 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext( GlContext::Create(*gl_key_context_[SharedContextKey()], kGlContextUseDedicatedThread)); it = gl_key_context_.emplace(key, new_context).first; -#if __APPLE__ - gpu_buffer_pool_.RegisterTextureCache(it->second->cv_texture_cache()); -#endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(it->second->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } return it->second; } GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {} -#if __APPLE__ -MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; } -#endif // __APPLE__ +extern const GraphService kGpuService; + +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +static std::shared_ptr GetGlTextureBufferFromPool( + int width, int height, GpuBufferFormat format) { + std::shared_ptr texture_buffer; + const auto cc = LegacyCalculatorSupport::Scoped::current(); + + if (cc && cc->Service(kGpuService).IsAvailable()) { + GpuBufferMultiPool* pool = + &cc->Service(kGpuService).GetObject().gpu_buffer_pool(); + // Note that the "gpu_buffer_pool" serves GlTextureBuffers on non-Apple + // platforms. TODO: refactor into storage pools. + texture_buffer = pool->GetBuffer(width, height, format) + .internal_storage(); + } else { + texture_buffer = GlTextureBuffer::Create(width, height, format); + } + return texture_buffer; +} + +static auto kGlTextureBufferPoolRegistration = [] { + // Ensure that the GlTextureBuffer's own factory is already registered, so we + // can override it. + GlTextureBuffer::RegisterOnce(); + return internal::GpuBufferStorageRegistry::Get() + .RegisterFactory(GetGlTextureBufferFromPool); +}(); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 62d6bb27e..3f7c67e2e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -30,15 +30,15 @@ #include "mediapipe/gpu/gpu_buffer_multi_pool.h" #ifdef __APPLE__ -#ifdef __OBJC__ -@class MPPGraphGPUData; -#else -struct MPPGraphGPUData; -#endif // __OBJC__ +#include "mediapipe/gpu/cv_texture_cache_manager.h" #endif // defined(__APPLE__) namespace mediapipe { +#ifdef __APPLE__ +class MetalSharedResources; +#endif // defined(__APPLE__) + // TODO: rename to GpuService or GpuManager or something. class GpuResources { public: @@ -55,9 +55,7 @@ class GpuResources { // Shared GL context for calculators. // TODO: require passing a context or node identifier. 
- const std::shared_ptr<GlContext>& gl_context() { - return gl_context(nullptr); - }; + const std::shared_ptr<GlContext>& gl_context() { return gl_context(nullptr); } const std::shared_ptr<GlContext>& gl_context(CalculatorContext* cc); @@ -65,7 +63,7 @@ GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; } #ifdef __APPLE__ - MPPGraphGPUData* ios_gpu_data(); + MetalSharedResources& metal_shared() { return *metal_shared_; } #endif // defined(__APPLE__) absl::Status PrepareGpuNode(CalculatorNode* node); @@ -86,13 +84,16 @@ std::map<std::string, std::string> node_key_; std::map<std::string, std::shared_ptr<GlContext>> gl_key_context_; +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + std::shared_ptr<CvTextureCacheManager> texture_caches_; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + // The pool must be destructed before the gl_context, but after the // ios_gpu_data, so the declaration order is important. GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ - // Note that this is an Objective-C object. - MPPGraphGPUData* ios_gpu_data_; + std::unique_ptr<MetalSharedResources> metal_shared_; #endif // defined(__APPLE__) std::map<std::string, std::shared_ptr<Executor>> named_executors_; diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h index e9fd64725..6ec53603b 100644 --- a/mediapipe/gpu/gpu_test_base.h +++ b/mediapipe/gpu/gpu_test_base.h @@ -24,13 +24,14 @@ namespace mediapipe { class GpuTestBase : public ::testing::Test { protected: - GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); } + GpuTestBase() { helper_.InitializeForTest(gpu_resources_.get()); } void RunInGlContext(std::function<void(void)> gl_func) { helper_.RunInGlContext(std::move(gl_func)); } GpuSharedData gpu_shared_; + std::shared_ptr<GpuResources> gpu_resources_ = gpu_shared_.gpu_resources; GlCalculatorHelper helper_; }; diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 2a8331db8..c67fb0c62 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -12,73 +12,63 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" -#ifdef __APPLE__ -#include "mediapipe/objc/util.h" -#endif namespace mediapipe { +namespace api2 { -// Convert ImageFrame to GpuBuffer.
-class ImageFrameToGpuBufferCalculator : public CalculatorBase { +class ImageFrameToGpuBufferCalculator + : public RegisteredNode<ImageFrameToGpuBufferCalculator> { public: - ImageFrameToGpuBufferCalculator() {} + static constexpr Input<ImageFrame> kIn{""}; + static constexpr Output<GpuBuffer> kOut{""}; - static absl::Status GetContract(CalculatorContract* cc); + MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut); + + static absl::Status UpdateContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; private: -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlCalculatorHelper helper_; -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; -REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); // static -absl::Status ImageFrameToGpuBufferCalculator::GetContract( +absl::Status ImageFrameToGpuBufferCalculator::UpdateContract( CalculatorContract* cc) { - cc->Inputs().Index(0).Set<ImageFrame>(); - cc->Outputs().Index(0).Set<GpuBuffer>(); // Note: we call this method even on platforms where we don't use the helper, // to ensure the calculator's contract is the same. In particular, the helper // enables support for the legacy side packet, which several graphs still use. - MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); - return absl::OkStatus(); + return GlCalculatorHelper::UpdateContract(cc); } absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { - // Inform the framework that we always output at the same timestamp - // as we receive a packet at. - cc->SetOffset(TimestampDiff(0)); -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - CFHolder<CVPixelBufferRef> buffer; - MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( - cc->Inputs().Index(0).Value(), &buffer)); - cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp()); -#else - const auto& input = cc->Inputs().Index(0).Get<ImageFrame>(); - helper_.RunInGlContext([this, &input, &cc]() { - auto src = helper_.CreateSourceTexture(input); - auto output = src.GetFrame<GpuBuffer>(); - glFlush(); - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - src.Release(); - }); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto image_frame = std::const_pointer_cast<ImageFrame>( + mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet())); + auto gpu_buffer = api2::MakePacket<GpuBuffer>( + std::make_shared<mediapipe::GpuBufferStorageImageFrame>( + std::move(image_frame))) + .At(cc->InputTimestamp()); + // This calculator's behavior has been to do the texture upload eagerly, and + // some graphs may rely on running this on a separate GL context to avoid + // blocking another context with the read operation. So let's request GPU + // access here to ensure that the behavior stays the same. + // TODO: have a better way to do this, or defer until later.
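+ // The read view requested below is what triggers the upload: asking the + // ImageFrame-backed storage for a GL texture view converts it on this + // calculator's GL context.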
+ helper_.RunInGlContext( + [&gpu_buffer] { auto view = gpu_buffer->GetReadView<GlTextureView>(0); }); + kOut(cc).Send(std::move(gpu_buffer)); return absl::OkStatus(); } +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/gpu/image_frame_view.h b/mediapipe/gpu/image_frame_view.h index 2fc6f2495..b7e58a824 100644 --- a/mediapipe/gpu/image_frame_view.h +++ b/mediapipe/gpu/image_frame_view.h @@ -12,9 +12,8 @@ class ViewProvider<ImageFrame> { public: virtual ~ViewProvider() = default; virtual std::shared_ptr<const ImageFrame> GetReadView( - types<ImageFrame>, std::shared_ptr<GpuBuffer> gpu_buffer) const = 0; - virtual std::shared_ptr<ImageFrame> GetWriteView( - types<ImageFrame>, std::shared_ptr<GpuBuffer> gpu_buffer) = 0; + types<ImageFrame>) const = 0; + virtual std::shared_ptr<ImageFrame> GetWriteView(types<ImageFrame>) = 0; }; } // namespace internal diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h new file mode 100644 index 000000000..e677c3bbf --- /dev/null +++ b/mediapipe/gpu/multi_pool.h @@ -0,0 +1,119 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_MULTI_POOL_H_ +#define MEDIAPIPE_GPU_MULTI_POOL_H_ + +#include "mediapipe/util/resource_cache.h" + +namespace mediapipe { + +struct MultiPoolOptions { + // Keep this many buffers allocated for a given frame size. + int keep_count = 2; + // The maximum size of the GpuBufferMultiPool. When the limit is reached, the + // oldest BufferSpec will be dropped. + int max_pool_count = 10; + // Time in seconds after which an inactive buffer can be dropped from the + // pool. Currently only used with CVPixelBufferPool. + float max_inactive_buffer_age = 0.25; + // Skip allocating a buffer pool until at least this many requests have been + // made for a given BufferSpec. + int min_requests_before_pool = 2; + // Do a deeper flush every this many requests. + int request_count_scrub_interval = 50; +}; + +static constexpr MultiPoolOptions kDefaultMultiPoolOptions; + +// MultiPool is a generic class for vending reusable resources of type Item, +// which are assumed to be relatively expensive to create, so that reusing them +// is beneficial. +// Items are classified by Spec; when an item with a given Spec is requested, +// an old Item with the same Spec can be reused, if available; otherwise a new +// Item will be created. When user code is done with an Item, it is returned +// to the pool for reuse. +// In order to manage this, a MultiPool contains a map of Specs to SimplePool; +// each SimplePool manages Items with the same Spec, which are thus considered +// interchangeable. +// Item retention and eviction policies are controlled by options. +// A concrete example would be a pool of GlTextureBuffer, grouped by dimensions +// and format.
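+// +// Illustrative usage sketch (MySimplePool, MySpec, and MyItem are hypothetical +// placeholder types, not defined in this header): +// +//   MultiPool<MySimplePool, MySpec, MyItem> pool; +//   MyItem item = pool.Get(MySpec{/*...*/});  // reused if available, else new +//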
+template <class SimplePool, class Spec, class Item> +class MultiPool { + public: + using SimplePoolFactory = std::function<std::shared_ptr<SimplePool>( + const Spec& spec, const MultiPoolOptions& options)>; + + MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool, + MultiPoolOptions options = kDefaultMultiPoolOptions) + : create_simple_pool_(factory), options_(options) {} + explicit MultiPool(MultiPoolOptions options) + : MultiPool(DefaultMakeSimplePool, options) {} + + // Obtains an item. May either be reused or created anew. + Item Get(const Spec& spec); + + private: + static std::shared_ptr<SimplePool> DefaultMakeSimplePool( + const Spec& spec, const MultiPoolOptions& options) { + return SimplePool::Create(spec, options); + } + + // Requests a simple buffer pool for the given spec. This may return nullptr + // if we have not yet reached a sufficient number of requests to allocate a + // pool, in which case the caller should invoke CreateBufferWithoutPool. + std::shared_ptr<SimplePool> RequestPool(const Spec& spec); + + absl::Mutex mutex_; + mediapipe::ResourceCache<Spec, std::shared_ptr<SimplePool>> cache_ + ABSL_GUARDED_BY(mutex_); + SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool; + MultiPoolOptions options_; +}; + +template <class SimplePool, class Spec, class Item> +std::shared_ptr<SimplePool> MultiPool<SimplePool, Spec, Item>::RequestPool( + const Spec& spec) { + std::shared_ptr<SimplePool> pool; + std::vector<std::shared_ptr<SimplePool>> evicted; + { + absl::MutexLock lock(&mutex_); + pool = cache_.Lookup(spec, [this](const Spec& spec, int request_count) { + return (request_count >= options_.min_requests_before_pool) + ? create_simple_pool_(spec, options_) + : nullptr; + }); + evicted = cache_.Evict(options_.max_pool_count, + options_.request_count_scrub_interval); + } + // Evicted pools, and their buffers, will be released without holding the + // lock. + return pool; +} + +template <class SimplePool, class Spec, class Item> +Item MultiPool<SimplePool, Spec, Item>::Get(const Spec& spec) { + std::shared_ptr<SimplePool> pool = RequestPool(spec); + if (pool) { + // Note: we release our multipool lock before accessing the simple pool. + return Item(pool->GetBuffer()); + } else { + return Item(SimplePool::CreateBufferWithoutPool(spec)); + } +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_MULTI_POOL_H_ diff --git a/mediapipe/gpu/reusable_pool.h b/mediapipe/gpu/reusable_pool.h new file mode 100644 index 000000000..ddeaa5ba7 --- /dev/null +++ b/mediapipe/gpu/reusable_pool.h @@ -0,0 +1,145 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Consider this file an implementation detail. None of this is part of the +// public API. + +#ifndef MEDIAPIPE_GPU_REUSABLE_POOL_H_ +#define MEDIAPIPE_GPU_REUSABLE_POOL_H_ + +#include <memory> +#include <vector> + +#include "absl/functional/any_invocable.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/gpu/multi_pool.h" + +namespace mediapipe { + +template <class Item> +class ReusablePool : public std::enable_shared_from_this<ReusablePool<Item>> { + public: + using ItemFactory = absl::AnyInvocable<std::unique_ptr<Item>() const>; + + // Creates a pool. This pool will manage buffers of the specified dimensions, + // and will keep keep_count buffers around for reuse.
+ // We enforce creation as a shared_ptr so that we can use a weak reference in + // the buffers' deleters. + static std::shared_ptr<ReusablePool<Item>> Create( + ItemFactory item_factory, const MultiPoolOptions& options) { + return std::shared_ptr<ReusablePool<Item>>( + new ReusablePool(std::move(item_factory), options)); + } + + // Obtains a buffer. May either be reused or created anew. + // A GlContext must be current when this is called. + std::shared_ptr<Item> GetBuffer(); + + // This method is meant for testing. + std::pair<int, int> GetInUseAndAvailableCounts(); + + protected: + ReusablePool(ItemFactory item_factory, const MultiPoolOptions& options) + : item_factory_(std::move(item_factory)), + keep_count_(options.keep_count) {} + + private: + // Return a buffer to the pool. + void Return(std::unique_ptr<Item> buf); + + // If the total number of buffers is greater than keep_count, destroys any + // surplus buffers that are no longer in use. + void TrimAvailable(std::vector<std::unique_ptr<Item>>* trimmed) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + const ItemFactory item_factory_; + const int keep_count_; + + absl::Mutex mutex_; + int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0; + std::vector<std::unique_ptr<Item>> available_ ABSL_GUARDED_BY(mutex_); +}; + +template <class Item> +inline std::shared_ptr<Item> ReusablePool<Item>::GetBuffer() { + std::unique_ptr<Item> buffer; + bool reuse = false; + + { + absl::MutexLock lock(&mutex_); + if (available_.empty()) { + buffer = item_factory_(); + if (!buffer) return nullptr; + } else { + buffer = std::move(available_.back()); + available_.pop_back(); + reuse = true; + } + + ++in_use_count_; + } + + // This needs to wait on consumer sync points, therefore it should not be + // done while holding the mutex. + if (reuse) { + buffer->Reuse(); + } + + // Return a shared_ptr with a custom deleter that adds the buffer back + // to our available list. + std::weak_ptr<ReusablePool<Item>> weak_pool(this->shared_from_this()); + return std::shared_ptr<Item>(buffer.release(), [weak_pool](Item* buf) { + auto pool = weak_pool.lock(); + if (pool) { + pool->Return(absl::WrapUnique(buf)); + } else { + delete buf; + } + }); +} + +template <class Item> +inline std::pair<int, int> ReusablePool<Item>::GetInUseAndAvailableCounts() { + absl::MutexLock lock(&mutex_); + return {in_use_count_, available_.size()}; +} + +template <class Item> +void ReusablePool<Item>::Return(std::unique_ptr<Item> buf) { + std::vector<std::unique_ptr<Item>> trimmed; + { + absl::MutexLock lock(&mutex_); + --in_use_count_; + available_.emplace_back(std::move(buf)); + TrimAvailable(&trimmed); + } + // The trimmed buffers will be released without holding the lock.
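+ // (They are destroyed when `trimmed` goes out of scope, after the MutexLock + // above has already been released.)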
+} + +template <class Item> +void ReusablePool<Item>::TrimAvailable( + std::vector<std::unique_ptr<Item>>* trimmed) { + int keep = std::max(keep_count_ - in_use_count_, 0); + if (available_.size() > keep) { + auto trim_it = std::next(available_.begin(), keep); + if (trimmed) { + std::move(trim_it, available_.end(), std::back_inserter(*trimmed)); + } + available_.erase(trim_it, available_.end()); + } +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_REUSABLE_POOL_H_ diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index d93eea7b5..04265cab5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -55,7 +55,11 @@ public class PacketCreator { public Packet createRgbImage(ByteBuffer buffer, int width, int height) { int widthStep = (((width * 3) + 3) / 4) * 4; if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + widthStep * height + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height)); } @@ -123,7 +127,11 @@ public class PacketCreator { */ public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) { if (width * height * 4 != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + width * height * 4); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height)); } @@ -136,7 +144,7 @@ public class PacketCreator { */ public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) { if (width * height != buffer.capacity()) { - throw new RuntimeException( + throw new IllegalArgumentException( "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); } return Packet.create( @@ -150,7 +158,11 @@ public class PacketCreator { */ public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); } @@ -163,7 +175,11 @@ public class PacketCreator { */ public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); } @@ -354,25 +370,24 @@ public class PacketCreator { *

For 3 and 4 channel images, the pixel rows should have 4-byte alignment. */ public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) { + int widthStep; if (numChannels == 4) { - if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); - } + widthStep = width * 4; } else if (numChannels == 3) { - int widthStep = (((width * 3) + 3) / 4) * 4; - if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); - } + widthStep = (((width * 3) + 3) / 4) * 4; } else if (numChannels == 1) { - if (width * height != buffer.capacity()) { - throw new RuntimeException( - "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); - } + widthStep = width; } else { - throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels); + throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels); + } + int expectedSize = widthStep * height; + if (buffer.capacity() != expectedSize) { + throw new IllegalArgumentException( + "The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity()); } return Packet.create( - nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels)); + nativeCreateCpuImage( + mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels)); } /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */ @@ -430,7 +445,7 @@ public class PacketCreator { long context, int name, int width, int height, TextureReleaseCallback releaseCallback); private native long nativeCreateCpuImage( - long context, ByteBuffer buffer, int width, int height, int numChannels); + long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels); private native long nativeCreateInt32Array(long context, int[] data); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index bb3be318d..d9508c1f7 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -20,9 +20,7 @@ android_library( name = "image", srcs = glob(["*.java"]), manifest = "AndroidManifest.xml", - visibility = [ - "//mediapipe:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ "//third_party:androidx_legacy_support_v4", "//third_party:autovalue", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 6a67c01cb..23bd553af 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -231,8 +231,6 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { *graph_config(), absl::StrCat("egl_surface_sink_", output_stream_name))); sink_node->set_calculator("GlSurfaceSinkCalculator"); sink_node->add_input_stream(output_stream_name); - sink_node->add_input_side_packet( - absl::StrCat(kGpuSharedTagName, ":", kGpuSharedSidePacketName)); const std::string input_side_packet_name = mediapipe::tool::GetUnusedSidePacketName( diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 250d7c938..46ea1ce41 100644 --- 
a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/formats/image.h" @@ -107,55 +109,31 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer( // Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java // ByteBuffer. -std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer( - JNIEnv* env, jobject byte_buffer, jint width, jint height, - mediapipe::ImageFormat::Format format) { - switch (format) { - case mediapipe::ImageFormat::SRGBA: - case mediapipe::ImageFormat::SRGB: - case mediapipe::ImageFormat::GRAY8: - break; - default: - LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8."; - return nullptr; - } - - auto image_frame = std::make_unique<mediapipe::ImageFrame>( - format, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - +absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> +CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, + jint height, jint width_step, + mediapipe::ImageFormat::Format format) { const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - const int num_channels = image_frame->NumberOfChannels(); - const int expected_buffer_size = - num_channels == 1 ? width * height : image_frame->PixelDataSize(); - - if (buffer_size != expected_buffer_size) { - if (num_channels != 1) - LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; - return nullptr; + const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); } - // Copy buffer data to image frame's pixel_data_. - if (num_channels == 1) { - const int width_step = image_frame->WidthStep(); - const char* src_row = - reinterpret_cast<const char*>(env->GetDirectBufferAddress(byte_buffer)); - char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData()); - for (int i = height; i > 0; --i) { - std::memcpy(dst_row, src_row, width); - src_row += width; - dst_row += width_step; - } - } else { - // 3 and 4 channels. - const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); - std::memcpy(image_frame->MutablePixelData(), buffer_data, - image_frame->PixelDataSize()); - } + const int expected_buffer_size = height * width_step; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; + + auto image_frame = std::make_unique<mediapipe::ImageFrame>(); + // TODO: we could retain the buffer with a special deleter and use + // the data directly without a copy. May need a new Java API since existing + // code might expect to be able to overwrite the buffer after creating an + // ImageFrame from it.
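+ // CopyPixelData allocates the frame's own storage with the requested + // alignment and copies from the source buffer, honoring its row stride + // width_step.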
+ image_frame->CopyPixelData( + format, width, height, width_step, static_cast<const uint8*>(buffer_data), + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); return image_frame; } @@ -176,77 +154,83 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); - if (nullptr == image_frame) return 0L; + // We require 4-byte alignment. See Java method. + constexpr int kAlignment = 4; + int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1; + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, + width_step, mediapipe::ImageFormat::SRGB); + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } +absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba( + JNIEnv* env, jobject byte_buffer, jint width, jint height) { + const uint8_t* rgba_data = + static_cast<const uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); + } + + const int expected_buffer_size = width * height * 4; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; + + auto image_frame = absl::make_unique<mediapipe::ImageFrame>( + mediapipe::ImageFormat::SRGB, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, + image_frame->MutablePixelData(), + image_frame->WidthStep()); + return image_frame; +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const uint8_t* rgba_data = - static_cast<const uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); - auto image_frame = absl::make_unique<mediapipe::ImageFrame>( - mediapipe::ImageFormat::SRGB, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != width * height * 4) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << width * height * 4 - << ", Image width: " << width; - return 0L; - } - mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, - image_frame->MutablePixelData(), - image_frame->WidthStep()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); + return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); - if (nullptr
== image_frame) return 0L; + auto image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8); + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const void* data = env->GetDirectBufferAddress(byte_buffer); - auto image_frame = absl::make_unique<mediapipe::ImageFrame>( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != image_frame->PixelDataSize()) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - return 0L; - } - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::VEC32F1); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); - if (nullptr == image_frame) return 0L; - - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::SRGBA); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -291,6 +275,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)( jint num_samples) { const uint8_t* audio_sample = reinterpret_cast<const uint8_t*>(env->GetDirectBufferAddress(data)); + if (!audio_sample) { + ThrowIfError(env, absl::InvalidArgumentError( + "Cannot get direct access to the input buffer.
It " + "should be created using allocateDirect.")); + return 0L; + } mediapipe::Packet packet = createAudioPacket(audio_sample, num_samples, num_channels); return CreatePacketWithContext(context, packet); @@ -360,8 +350,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data) { if (env->GetArrayLength(data) != rows * cols) { - LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " - << rows * cols; + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Please check the matrix data size, has to be rows * cols = ", + rows * cols))); return 0L; } std::unique_ptr matrix(new mediapipe::Matrix(rows, cols)); @@ -379,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels) { + jint height, jint width_step, jint num_channels) { mediapipe::ImageFormat::Format format; switch (num_channels) { case 4: @@ -392,16 +384,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( format = mediapipe::ImageFormat::GRAY8; break; default: - LOG(ERROR) << "Channels must be either 1, 3, or 4."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Channels must be either 1, 3, or 4, but are ", + num_channels))); return 0L; } - auto image_frame = - CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); - if (nullptr == image_frame) return 0L; + auto image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width_step, format); + if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = - mediapipe::MakePacket(std::move(image_frame)); + mediapipe::MakePacket(*std::move(image_frame_or)); return CreatePacketWithContext(context, packet); } @@ -502,7 +496,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( jbyte* data_ref = env->GetByteArrayElements(data, nullptr); auto options = absl::make_unique(); if (!options->ParseFromArray(data_ref, count)) { - LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Parsing binary-encoded CalculatorOptions failed."))); return 0L; } mediapipe::Packet packet = mediapipe::Adopt(options.release()); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index d6f44b0a3..b3b1043fb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels); + jint height, jint width_step, jint num_channels); JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index c215dd929..737f6db72 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ 
b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" @@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( : GetFromNativeHandle<mediapipe::ImageFrame>(packet); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } // Assume byte buffer stores pixel data contiguously. const int expected_buffer_size = image.Width() * image.Height() * image.ByteDepth() * image.NumberOfChannels(); if (buffer_size != expected_buffer_size) { - LOG(ERROR) << "Expected buffer size " << expected_buffer_size - << " got: " << buffer_size << ", width " << image.Width() - << ", height " << image.Height() << ", channels " - << image.NumberOfChannels(); + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Expected buffer size ", expected_buffer_size, + " got: ", buffer_size, ", width ", image.Width(), ", height ", + image.Height(), ", channels ", image.NumberOfChannels()))); return false; } switch (image.ByteDepth()) { case 1: { - uint8* data = - static_cast<uint8*>(env->GetDirectBufferAddress(byte_buffer)); + uint8* data = static_cast<uint8*>(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 2: { - uint16* data = - static_cast<uint16*>(env->GetDirectBufferAddress(byte_buffer)); + uint16* data = static_cast<uint16*>(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } case 4: { - float* data = - static_cast<float*>(env->GetDirectBufferAddress(byte_buffer)); + float* data = static_cast<float*>(buffer_data); image.CopyToBuffer(data, expected_buffer_size); break; } @@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( uint8_t* rgba_data = static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } if (buffer_size != image.Width() * image.Height() * 4) { - LOG(ERROR) << "Buffer size has to be width*height*4\n" - << "Image width: " << image.Width() - << ", Image height: " << image.Height() - << ", Buffer size: " << buffer_size << ", Buffer size needed: " - << image.Width() * image.Height() * 4; + ThrowIfError(env, + absl::InvalidArgumentError(absl::StrCat( + "Buffer size has to be width*height*4\n" + "Image width: ", + image.Width(), ", Image height: ", image.Height(), + ", Buffer size: ", buffer_size, ", Buffer size needed: ", + image.Width() * image.Height() * 4))); return false; } mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(), diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 200726864..f376edffa 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -91,6 +91,10 @@ class Classifier(custom_model.CustomModel): self._history = self._model.fit( x=train_dataset, epochs=self._hparams.epochs, + # `steps_per_epoch` is intentionally set to
None in case the dataset + # is not repeated. Otherwise, the training process will stop when the + # dataset is exhausted even if there are epochs remaining. + steps_per_epoch=None, validation_data=validation_dataset, callbacks=self._callbacks) diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 12fef631f..492bba0a9 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -45,6 +45,7 @@ py_test( name = "model_util_test", srcs = ["model_util_test.py"], deps = [ + ":file_util", ":model_util", ":quantization", ":test_util", diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index bef9c8a97..f0020db25 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -13,10 +13,13 @@ # limitations under the License. import os +from typing import Optional +from unittest import mock as unittest_mock from absl.testing import parameterized import tensorflow as tf +from mediapipe.model_maker.python.core.utils import file_util from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util @@ -24,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_keras_model(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_keras_model(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') model.save(saved_model_path) + # model_util.load_keras_model takes in a relative path to files within the + # model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = saved_model_path loaded_model = model_util.load_keras_model(saved_model_path) input_tensors = test_util.create_random_sample(size=[1, input_dim]) @@ -36,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) - def test_load_tflite_model_buffer(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_tflite_model_buffer(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_model = model_util.convert_to_tflite(model) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) - + # model_util.load_tflite_model_buffer takes in a relative path to files + # within the model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = tflite_file tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) test_util.test_tflite( keras_model=model, @@ -76,8 +86,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]), expected_steps_per_epoch=2)) - def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, - expected_steps_per_epoch): + def test_get_steps_per_epoch(self, steps_per_epoch: 
Optional[int], batch_size: Optional[int], train_data: Optional[tf.data.Dataset], expected_steps_per_epoch: int): estimated_steps_per_epoch = model_util.get_steps_per_epoch( steps_per_epoch=steps_per_epoch, batch_size=batch_size, @@ -130,7 +142,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): testcase_name='float16_quantize', config=quantization.QuantizationConfig.for_float16(), model_size=1468)) - def test_convert_to_tflite_quantized(self, config, model_size): + def test_convert_to_tflite_quantized(self, + config: quantization.QuantizationConfig, + model_size: int): input_dim = 16 num_classes = 2 max_input_value = 5 @@ -157,5 +171,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): test_util.test_tflite_file( keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + if __name__ == '__main__': tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 0c35e7966..7bb41351e 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -21,9 +21,16 @@ package( licenses(["notice"]) +###################################################################### +# Public target of the MediaPipe Model Maker TextClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/text/text_classifier/customize for +# more information about the MediaPipe Model Maker TextClassifier APIs. +###################################################################### py_library( name = "text_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":model_options", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index b7d334d9c..256447a8d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -24,9 +24,9 @@ package( # TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory. filegroup( - name = "test_data", + name = "testdata", srcs = glob([ - "test_data/**", + "testdata/**", ]), ) @@ -53,7 +53,7 @@ py_test( name = "dataset_test", srcs = ["dataset_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], deps = [ @@ -103,9 +103,16 @@ py_library( ], ) +###################################################################### +# Public target of the MediaPipe Model Maker GestureRecognizer APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer/customize +# for more information about the MediaPipe Model Maker GestureRecognizer APIs.
+###################################################################### py_library( name = "gesture_recognizer_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":gesture_recognizer", @@ -129,7 +136,7 @@ py_test( size = "large", srcs = ["gesture_recognizer_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, @@ -144,7 +151,7 @@ py_test( name = "metadata_writer_test", srcs = ["metadata_writer_test.py"], data = [ - ":test_data", + ":testdata", ], deps = [ ":metadata_writer", @@ -157,7 +164,7 @@ py_binary( name = "gesture_recognizer_demo", srcs = ["gesture_recognizer_demo.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], python_version = "PY3", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py index 06075fbc6..1cf9f0619 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -31,7 +31,7 @@ FLAGS = flags.FLAGS # TODO: Move hand gesture recognizer demo dataset to an # open-sourced directory. -TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data' +TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data' def define_flags(): diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index eb2b1d171..280fc6a82 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import tempfile from unittest import mock as unittest_mock import zipfile @@ -24,7 +25,8 @@ from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' +tf.keras.backend.experimental.enable_tf_random_generator() class GestureRecognizerTest(tf.test.TestCase): @@ -40,30 +42,36 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - self._model_options = gesture_recognizer.ModelOptions() - self._hparams = gesture_recognizer.HParams(epochs=2) - self._gesture_recognizer_options = ( - gesture_recognizer.GestureRecognizerOptions( - model_options=self._model_options, hparams=self._hparams)) + tf.keras.utils.set_random_seed(87654321) all_data = self._load_data() - # Splits data, 90% data for training, 10% for testing - self._train_data, self._test_data = all_data.split(0.9) + # Splits data, 90% data for training, 10% for validation + self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( 
train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) self._test_accuracy(model) def test_export_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) model.export_model() - model_bundle_file = os.path.join(self._hparams.export_dir, + model_bundle_file = os.path.join(hparams.export_dir, 'gesture_recognizer.task') with zipfile.ZipFile(model_bundle_file) as zf: self.assertEqual( @@ -87,10 +95,11 @@ class GestureRecognizerTest(tf.test.TestCase): tflite_file=gesture_classifier_tflite_file, size=[1, model.embedding_size]) - def _test_accuracy(self, model, threshold=0.5): - _, accuracy = model.evaluate(self._test_data) - tf.compat.v1.logging.info(f'accuracy: {accuracy}') - self.assertGreaterEqual(accuracy, threshold) + def _test_accuracy(self, model, threshold=0.0): + # Test on _train_data because of our limited dataset size + _, accuracy = model.evaluate(self._train_data) + tf.compat.v1.logging.info(f'train accuracy: {accuracy}') + self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( gesture_recognizer.hyperparameters, @@ -102,27 +111,32 @@ class GestureRecognizerTest(tf.test.TestCase): 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) - def test_create_hparams_and_model_options_if_none_in_image_classifier_options( + def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( self, mock_hparams, mock_model_options): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=options) mock_hparams.assert_called_once() mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, - options=self._gesture_recognizer_options) + validation_data=self._validation_data, + options=gesture_recognizer_options) self._test_accuracy(model) self.assertRegex(mock_stdout.getvalue(), 'Resuming from') diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index e1101e066..83998141d 100644 --- 
a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -23,7 +23,7 @@ from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writ from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata" +_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata" _EXPECTED_JSON = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json")) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg 
b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg rename to 
mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg 
b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg rename to 
mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg 
b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg rename to 
mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index c581d9fbc..29ae189e9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -21,9 +21,16 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) +###################################################################### +# Public target of the MediaPipe Model Maker ImageClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize for +# more information about the MediaPipe Model Maker ImageClassifier APIs. 
+###################################################################### py_library( name = "image_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":hyperparameters", diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 1ff6132b4..df71a8fef 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -177,7 +177,7 @@ class ImageClassifier(classifier.Classifier): Args: model_name: File name to save TFLite model with metadata. The full export - path is {self._hparams.model_dir}/{model_name}. + path is {self._hparams.export_dir}/{model_name}. quantization_config: The configuration for model quantization. """ if not tf.io.gfile.exists(self._hparams.export_dir): diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py index 5832ea53a..f382e28aa 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str, data = image_classifier.Dataset.from_folder(data_dir) train_data, rest_data = data.split(0.8) validation_data, test_data = rest_data.split(0.5) - + model_options = image_classifier.ImageClassifierOptions( + supported_model=model_spec, + hparams=image_classifier.HParams(export_dir=export_dir), + ) model = image_classifier.ImageClassifier.create( - model_spec=model_spec, train_data=train_data, validation_data=validation_data, - hparams=image_classifier.HParams(model_dir=export_dir)) + options=model_options) _, acc = model.evaluate(test_data) print('Test accuracy: %f' % acc) @@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str, raise ValueError(f'Quantization: {quantization} is not recognized') model.export_model(quantization_config=quantization_config) - model.export_labels(export_dir) def main(_) -> None: diff --git a/mediapipe/modules/holistic_landmark/calculators/BUILD b/mediapipe/modules/holistic_landmark/calculators/BUILD index c3c091924..bc00b697c 100644 --- a/mediapipe/modules/holistic_landmark/calculators/BUILD +++ b/mediapipe/modules/holistic_landmark/calculators/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "hand_detections_from_pose_to_rects_calculator", srcs = ["hand_detections_from_pose_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "roi_tracking_calculator_proto", srcs = ["roi_tracking_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -49,7 +47,6 @@ mediapipe_proto_library( cc_library( name = "roi_tracking_calculator", srcs = ["roi_tracking_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":roi_tracking_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 48c9b181a..fafdfee8a 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -83,11 
+83,11 @@ objc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/framework/port:threadpool", - "//mediapipe/gpu:MPPGraphGPUData", "//mediapipe/gpu:gl_base", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", + "//mediapipe/gpu:metal_shared_resources", "//mediapipe/gpu:pixel_buffer_pool_util", "//mediapipe/util:cpu_util", "@com_google_absl//absl/base:core_headers", @@ -147,7 +147,7 @@ objc_library( visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", ], @@ -173,7 +173,7 @@ objc_library( deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", ], ) diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 080cca20f..1bd177e80 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -24,7 +24,6 @@ #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/graph_service.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/objc/util.h" diff --git a/mediapipe/objc/MPPLayerRenderer.m b/mediapipe/objc/MPPLayerRenderer.m index 7c3027fb6..edd2216ee 100644 --- a/mediapipe/objc/MPPLayerRenderer.m +++ b/mediapipe/objc/MPPLayerRenderer.m @@ -54,10 +54,11 @@ glGenRenderbuffers(1, &renderbuffer_); glBindRenderbuffer(GL_RENDERBUFFER, renderbuffer_); glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_RENDERBUFFER, renderbuffer_); - BOOL success = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER fromDrawable:_layer]; + BOOL success __unused = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER + fromDrawable:_layer]; NSAssert(success, @"could not create renderbuffer storage for layer with bounds %@", NSStringFromCGRect(_layer.bounds)); - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + GLenum status __unused = glCheckFramebufferStatus(GL_FRAMEBUFFER); NSAssert(status == GL_FRAMEBUFFER_COMPLETE, @"failed to make complete framebuffer object %x", status); } diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index bebcbe97c..1b8b173f7 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MediaPipe solution drawing utils.""" import math @@ -135,15 +134,14 @@ def draw_landmarks( the image. connections: A list of landmark index tuples that specifies how landmarks to be connected in the drawing. - landmark_drawing_spec: Either a DrawingSpec object or a mapping from - hand landmarks to the DrawingSpecs that specifies the landmarks' drawing - settings such as color, line thickness, and circle radius. - If this argument is explicitly set to None, no landmarks will be drawn. - connection_drawing_spec: Either a DrawingSpec object or a mapping from - hand connections to the DrawingSpecs that specifies the - connections' drawing settings such as color and line thickness. 
- If this argument is explicitly set to None, no landmark connections will - be drawn. + landmark_drawing_spec: Either a DrawingSpec object or a mapping from hand + landmarks to the DrawingSpecs that specifies the landmarks' drawing + settings such as color, line thickness, and circle radius. If this + argument is explicitly set to None, no landmarks will be drawn. + connection_drawing_spec: Either a DrawingSpec object or a mapping from hand + connections to the DrawingSpecs that specifies the connections' drawing + settings such as color and line thickness. If this argument is explicitly + set to None, no landmark connections will be drawn. Raises: ValueError: If one of the followings: @@ -197,14 +195,13 @@ def draw_landmarks( drawing_spec.color, drawing_spec.thickness) -def draw_axis( - image: np.ndarray, - rotation: np.ndarray, - translation: np.ndarray, - focal_length: Tuple[float, float] = (1.0, 1.0), - principal_point: Tuple[float, float] = (0.0, 0.0), - axis_length: float = 0.1, - axis_drawing_spec: DrawingSpec = DrawingSpec()): +def draw_axis(image: np.ndarray, + rotation: np.ndarray, + translation: np.ndarray, + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + axis_length: float = 0.1, + axis_drawing_spec: DrawingSpec = DrawingSpec()): """Draws the 3D axis on the image. Args: @@ -214,8 +211,8 @@ def draw_axis( focal_length: camera focal length along x and y directions. principal_point: camera principal point in x and y. axis_length: length of the axis in the drawing. - axis_drawing_spec: A DrawingSpec object that specifies the xyz axis - drawing settings such as line thickness. + axis_drawing_spec: A DrawingSpec object that specifies the xyz axis drawing + settings such as line thickness. Raises: ValueError: If one of the followings: @@ -226,7 +223,7 @@ def draw_axis( image_rows, image_cols, _ = image.shape # Create axis points in camera coordinate frame. axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - axis_cam = np.matmul(rotation, axis_length*axis_world.T).T + translation + axis_cam = np.matmul(rotation, axis_length * axis_world.T).T + translation x = axis_cam[..., 0] y = axis_cam[..., 1] z = axis_cam[..., 2] @@ -274,8 +271,9 @@ def plot_landmarks(landmark_list: landmark_pb2.NormalizedLandmarkList, connections' drawing settings such as color and line thickness. elevation: The elevation from which to view the plot. azimuth: the azimuth angle to rotate the plot. + Raises: - ValueError: If any connetions contain invalid landmark index. + ValueError: If any connection contains an invalid landmark index. 
""" if not landmark_list: return diff --git a/mediapipe/tasks/BUILD b/mediapipe/tasks/BUILD index 242a88cfc..98ddd5777 100644 --- a/mediapipe/tasks/BUILD +++ b/mediapipe/tasks/BUILD @@ -21,3 +21,10 @@ package_group( "//mediapipe/tasks/...", ], ) + +package_group( + name = "users", + includes = [ + ":internal", + ], +) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 1955adfe7..f61472413 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -16,6 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Classifier +# https://developers.google.com/mediapipe/solutions/audio/audio_classifier +cc_library( + name = "audio_classifier", + srcs = ["audio_classifier.cc"], + hdrs = ["audio_classifier.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_classifier_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_classifier_graph", srcs = ["audio_classifier_graph.cc"], @@ -26,7 +55,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", @@ -52,28 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_classifier", - srcs = ["audio_classifier.cc"], - hdrs = ["audio_classifier.h"], - deps = [ - ":audio_classifier_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - 
], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index b982ef39a..6a0f627b2 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -16,6 +16,36 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Embedder +# https://developers.google.com/mediapipe/solutions/audio/audio_embedder +cc_library( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_embedder_graph", srcs = ["audio_embedder_graph.cc"], @@ -26,7 +56,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", @@ -51,29 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_embedder", - srcs = ["audio_embedder.cc"], - hdrs = ["audio_embedder.h"], - deps = [ - ":audio_embedder_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:embedding_result", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedder_options", - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:cosine_similarity", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc 
b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc index 7667feaa3..187f11f7f 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc @@ -100,6 +100,46 @@ void ConfigureAudioToTensorCalculator( } } // namespace +// An "AudioEmbedderGraph" performs embedding extraction. +// - Accepts CPU audio buffer and outputs embedding results on CPU. +// +// Inputs: +// AUDIO - Matrix +// Audio buffer to perform classification on. +// SAMPLE_RATE - double @Optional +// The sample rate of the corresponding audio data in the "AUDIO" stream. +// If sample rate is not provided, the "AUDIO" stream must carry a time +// series stream header with sample rate info. +// +// Outputs: +// EMBEDDINGS - EmbeddingResult @Optional +// The embedding results aggregated by head. Only produces results if +// the graph's 'use_stream_mode' option is true. +// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional +// The embedding result aggregated by timestamp, then by head. Only +// produces results if the graph's 'use_stream_mode' option is false. +// +// Example: +// node { +// calculator: "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph" +// input_stream: "AUDIO:audio_in" +// input_stream: "SAMPLE_RATE:sample_rate_in" +// output_stream: "EMBEDDINGS:embeddings_out" +// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out" +// options { +// [mediapipe.tasks.audio.audio_embedder.proto.AudioEmbedderGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// embedder_options { +// l2_normalize: true +// } +// } +// } +// } class AudioEmbedderGraph : public core::ModelTaskGraph { public: absl::StatusOr<CalculatorGraphConfig> GetConfig( @@ -158,10 +198,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions<components::processors::proto::EmbeddingPostprocessingGraphOptions>())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), - &postprocessing + .GetOptions<components::processors::proto::EmbeddingPostprocessingGraphOptions>())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio embedding on // audio files. Disables timestamp aggregation by not connecting the diff --git a/mediapipe/tasks/cc/audio/core/BUILD b/mediapipe/tasks/cc/audio/core/BUILD index 93362fd3d..016faa10f 100644 --- a/mediapipe/tasks/cc/audio/core/BUILD +++ b/mediapipe/tasks/cc/audio/core/BUILD @@ -19,6 +19,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD deleted file mode 100644 index c90349ab2..000000000 --- a/mediapipe/tasks/cc/components/BUILD +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -licenses(["notice"]) - -mediapipe_proto_library( - name = "image_preprocessing_options_proto", - srcs = ["image_preprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -cc_library( - name = "image_preprocessing", - srcs = ["image_preprocessing.cc"], - hdrs = ["image_preprocessing.h"], - deps = [ - ":image_preprocessing_options_cc_proto", - "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/calculators/image:image_clone_calculator", - "//mediapipe/calculators/image:image_clone_calculator_cc_proto", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/gpu:gpu_origin_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - -# TODO: Enable this test - -# TODO: Investigate rewriting the build rule to only link -# the Bert Preprocessor if it's needed. 
-cc_library( - name = "text_preprocessing_graph", - srcs = ["text_preprocessing_graph.cc"], - hdrs = ["text_preprocessing_graph.h"], - deps = [ - "//mediapipe/calculators/tensor:bert_preprocessor_calculator", - "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:text_to_tensor_calculator", - "//mediapipe/framework:subgraph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 1f726a018..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 1a83fdad2..ad2c668c3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,14 +25,12 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into either a ClassificationResult object // representing the classification results aggregated by classifier head, or @@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications; // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example without timestamp aggregation: // node { @@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node { ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); std::vector ConvertToTimestampedClassificationResults( CalculatorContext* cc); - // TODO: deprecate this function once migration is over. 
- ClassificationResult LegacyConvertToClassificationResult( - CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } - // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if - // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is - // not connected. All dependent tasks must be updated to use these outputs - // first. + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected()); + } else { + RET_CHECK(kClassificationsOut(cc).IsConnected()); + } return absl::OkStatus(); } @@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process( if (kTimestampsIn(cc).IsEmpty()) { return absl::OkStatus(); } - classification_result = LegacyConvertToClassificationResult(cc); kTimestampedClassificationsOut(cc).Send( ConvertToTimestampedClassificationResults(cc)); } else { - classification_result = LegacyConvertToClassificationResult(cc); kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } kClassificationResultOut(cc).Send(classification_result); @@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( return results; } -ClassificationResult -ClassificationAggregationCalculator::LegacyConvertToClassificationResult( - CalculatorContext* cc) { - ClassificationResult result; - Timestamp first_timestamp(0); - std::vector<Timestamp> timestamps; - if (time_aggregation_enabled_) { - timestamps = kTimestampsIn(cc).Get(); - first_timestamp = timestamps[0]; - } else { - timestamps = {cc->InputTimestamp()}; - } - for (Timestamp timestamp : timestamps) { - int count = cached_classifications_[timestamp.Value()].size(); - for (int i = 0; i < count; ++i) { - Classifications* c; - if (result.classifications_size() <= i) { - c = result.add_classifications(); - if (!head_names_.empty()) { - c->set_head_index(i); - c->set_head_name(head_names_[i]); - } - } else { - c = result.mutable_classifications(i); - } - auto* entry = c->add_entries(); - for (const auto& elem : - cached_classifications_[timestamp.Value()][i].classification()) { - auto* category = entry->add_categories(); - if (elem.has_index()) { - category->set_index(elem.index()); - } - if (elem.has_score()) { - category->set_score(elem.score()); - } - if (elem.has_label()) { - category->set_category_name(elem.label()); - } - if (elem.has_display_name()) { - category->set_display_name(elem.display_name()); - } - } - entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / - 1000); - } - } - return result; -} - MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator); } // namespace api2 diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index bd66a0f28..35d3f4785 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
-package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 7b455c0c4..27d2357b5 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], -) - mediapipe_proto_library( name = "classifications_proto", srcs = ["classifications.proto"], deps = [ - ":category_proto", "//mediapipe/framework/formats:classification_proto", ], ) diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto deleted file mode 100644 index 412e71428..000000000 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto2"; - -package mediapipe.tasks.components.containers.proto; - -option java_package = "com.google.mediapipe.tasks.components.containers.proto"; -option java_outer_classname = "CategoryProto"; - -// TODO: deprecate this message once migration is over. -// A single classification result. -message Category { - // The index of the category in the corresponding label map, usually packed in - // the TFLite Model Metadata [1]. - // - // [1]: https://www.tensorflow.org/lite/convert/metadata - optional int32 index = 1; - // The score for this category, e.g. (but not necessarily) a probability in - // [0,1]. - optional float score = 2; - // A human readable name of the category filled from the label map. - optional string display_name = 3; - // An ID for the category, not necessarily human-readable, e.g. a Google - // Knowledge Graph ID [1], filled from the label map. - // - // [1]: https://developers.google.com/knowledge-graph - optional string category_name = 4; -} diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index f098ed0e4..2b2306829 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -18,27 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; import "mediapipe/framework/formats/classification.proto"; -import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; -// TODO: deprecate this message once migration is over. -// List of predicted categories with an optional timestamp. 
-message ClassificationEntry { - // The array of predicted categories, usually sorted by descending scores, - // e.g., from high to low probability. - repeated Category categories = 1; - // The optional timestamp (in milliseconds) associated to the classifcation - // entry. This is useful for time series use cases, e.g., audio - // classification. - optional int64 timestamp_ms = 2; -} - // Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { - // TODO: deprecate this field once migration is over. - repeated ClassificationEntry entries = 1; // The classification results for this head. optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. This is useful @@ -48,6 +33,8 @@ message Classifications { // name. // TODO: Add github link to metadata_schema.fbs. optional string head_name = 3; + // Reserved fields. + reserved 1; } // Classifications for a given classifier model. diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 7845a3dae..185bf231b 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -20,6 +20,7 @@ cc_library( name = "classifier_options", srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], ) @@ -67,6 +68,7 @@ cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], hdrs = ["embedder_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"], ) @@ -98,3 +100,62 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "image_preprocessing_graph", + srcs = ["image_preprocessing_graph.cc"], + hdrs = ["image_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_clone_calculator", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO: Enable this test + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. 
+cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 0fb62afaf..5a0472f5c 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; @@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; // Struct holding the different output streams produced by the graph. struct ClassificationPostprocessingOutputStreams { - Source classification_result; Source classifications; Source> timestamped_classifications; }; @@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph( // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Connects output. 
ClassificationPostprocessingOutputStreams output_streams{ - /*classification_result=*/result_aggregation - [Output<ClassificationResult>(kClassificationResultTag)], /*classifications=*/ result_aggregation[Output<ClassificationResult>(kClassificationsTag)], /*timestamped_classifications=*/ diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 48575ceb0..03ae91130 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -58,9 +58,6 @@ namespace processors { // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index d4728e725..8eb6f3c3b 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsName[] = "tensors"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kClassificationResultName[] = "classification_result"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsName[] = "classifications"; constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; @@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { })pb")})); } -// TODO: remove these tests once migration is over. -class LegacyPostprocessingTest : public tflite_shims::testing::Test { - protected: - absl::StatusOr<OutputStreamPoller> BuildGraph( - absl::string_view model_name, const proto::ClassifierOptions& options, - bool connect_timestamps = false) { - ASSIGN_OR_RETURN(auto model_resources, - CreateModelResourcesForModel(model_name)); - - Graph graph; - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors."
- "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( - *model_resources, options, - &postprocessing - .GetOptions())); - graph[Input>(kTensorsTag)].SetName(kTensorsName) >> - postprocessing.In(kTensorsTag); - if (connect_timestamps) { - graph[Input>(kTimestampsTag)].SetName( - kTimestampsName) >> - postprocessing.In(kTimestampsTag); - } - postprocessing.Out(kClassificationResultTag) - .SetName(kClassificationResultName) >> - graph[Output(kClassificationResultTag)]; - - MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); - ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( - kClassificationResultName)); - MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); - return poller; - } - - template - void AddTensor( - const std::vector& tensor, const Tensor::ElementType& element_type, - const Tensor::QuantizationParameters& quantization_parameters = {}) { - tensors_->emplace_back(element_type, - Tensor::Shape{1, static_cast(tensor.size())}, - quantization_parameters); - auto view = tensors_->back().GetCpuWriteView(); - T* buffer = view.buffer(); - std::copy(tensor.begin(), tensor.end(), buffer); - } - - absl::Status Run( - std::optional> aggregation_timestamps = std::nullopt, - int timestamp = 0) { - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); - // Reset tensors for future calls. - tensors_ = absl::make_unique>(); - if (aggregation_timestamps.has_value()) { - auto packet = absl::make_unique>(); - for (const auto& timestamp : *aggregation_timestamps) { - packet->emplace_back(Timestamp(timestamp)); - } - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); - } - return absl::OkStatus(); - } - - absl::StatusOr GetClassificationResult( - OutputStreamPoller& poller) { - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); - MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); - - Packet packet; - if (!poller.Next(&packet)) { - return absl::InternalError("Unable to get output packet"); - } - auto result = packet.Get(); - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); - return result; - } - - private: - CalculatorGraph calculator_graph_; - std::unique_ptr> tensors_ = - absl::make_unique>(); -}; - -TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - options.set_score_threshold(0.5); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 18; - tensor[2] = 16; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto(R"pb(classifications { - entries { - categories { index: 1 score: 0.8 } - categories { index: 2 score: 0.6 } - timestamp_ms: 0 - } - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); - // Build input tensors. 
- std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.8 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.6899744811 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6456563062 - category_name: "great white shark" - } - categories { - index: 2 - score: 0.5986876601 - category_name: "goldfish" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor_0(kTwoHeadsNumClasses[0], 0); - tensor_0[1] = 0.2; - tensor_0[2] = 0.4; - tensor_0[3] = 0.6; - std::vector tensor_1(kTwoHeadsNumClasses[1], 0); - tensor_1[1] = 0.2; - tensor_1[2] = 0.4; - tensor_1[3] = 0.6; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kFloat32); - AddTensor(tensor_1, Tensor::ElementType::kFloat32); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Narration, monologue" - } - categories { - index: 2 - score: 0.4 - category_name: "Conversation" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "yamnet_classification" - } - classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Azara\'s Spinetail" - } - categories { - index: 2 - score: 0.4 - category_name: "House Sparrow" - } - timestamp_ms: 0 - } - head_index: 1 - head_name: "bird_classification" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, - /*connect_timestamps=*/true)); - // Build input tensors. 
- std::vector tensor_0(kMobileNetNumClasses, 0); - tensor_0[1] = 12; - tensor_0[2] = 14; - tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); - tensor_1[5] = 12; - tensor_1[6] = 14; - tensor_1[7] = 16; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - AddTensor(tensor_1, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run( - /*aggregation_timestamps=*/std::optional>({0, 1000}), - /*timestamp=*/1000)); - - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - entries { - categories { index: 7 score: 0.6 category_name: "stingray" } - categories { - index: 6 - score: 0.4 - category_name: "electric ray" - } - timestamp_ms: 1 - } - head_index: 0 - head_name: "probability" - })pb")); -} - } // namespace } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 880aec5d7..ad4881e12 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -150,7 +150,7 @@ absl::StatusOr> GetHeadNames( } // namespace -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options) { @@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing( // timestamp aggregation is required. // // The recommended way of using this graph is through the GraphBuilder API using -// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more -// details. +// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for +// more details. class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 58606ed80..889992463 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -58,7 +58,7 @@ namespace processors { // The embedding result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. 
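//
// Example usage (a minimal sketch mirroring the graph test updated in this
// change; `graph`, `model_resources` and `embedder_options` are assumed to
// be provided by the caller):
//
//   auto& postprocessing = graph.AddNode(
//       "mediapipe.tasks.components.processors."
//       "EmbeddingPostprocessingGraph");
//   MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph(
//       model_resources, embedder_options,
//       &postprocessing
//            .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));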
-absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options); diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 84d84d648..163e46ee8 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { options_in.set_quantize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors." "EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( *model_resources, options, &postprocessing .GetOptions())); diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc similarity index 90% rename from mediapipe/tasks/cc/components/image_preprocessing.cc rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index ef447df97..b24b7f0cb 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include #include @@ -33,7 +33,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" @@ -42,6 +42,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::Tensor; @@ -144,9 +145,9 @@ bool DetermineImagePreprocessingGpuBackend( return acceleration.has_gpu(); } -absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, - bool use_gpu, - ImagePreprocessingOptions* options) { +absl::Status ConfigureImagePreprocessingGraph( + const ModelResources& model_resources, bool use_gpu, + proto::ImagePreprocessingGraphOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator( @@ -154,9 +155,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { - options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::GPU_BACKEND); } else { - options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); + options->set_backend(proto::ImagePreprocessingGraphOptions::CPU_BACKEND); } return absl::OkStatus(); } @@ -170,8 +171,7 @@ Source AddDataConverter(Source image_in, Graph& graph, return image_converter[Output("")]; } -// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image -// preprocessing. +// An ImagePreprocessingGraph performs image preprocessing. // - Accepts CPU input images and outputs CPU tensors. // // Inputs: @@ -192,7 +192,7 @@ Source AddDataConverter(Source image_in, Graph& graph, // An std::array representing the letterbox padding from the 4 // sides ([left, top, right, bottom]) of the output image, normalized to // [0.f, 1.f] by the output dimensions. The padding values are non-zero only -// when the "keep_aspect_ratio" is true in ImagePreprocessingOptions. +// when the "keep_aspect_ratio" is true in ImagePreprocessingGraphOptions. // IMAGE_SIZE - std::pair @Optional // The size of the original input image as a pair. // IMAGE - Image @Optional @@ -200,15 +200,15 @@ Source AddDataConverter(Source image_in, Graph& graph, // GPU). // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureImagePreprocessing()' function. See header file for more -// details. -class ImagePreprocessingSubgraph : public Subgraph { +// using the 'ConfigureImagePreprocessingGraph()' function. See header file for +// more details. 
+class ImagePreprocessingGraph : public Subgraph { public: absl::StatusOr<CalculatorGraphConfig> GetConfig( SubgraphContext* sc) override { Graph graph; auto output_streams = BuildImagePreprocessing( - sc->Options<ImagePreprocessingOptions>(), + sc->Options<proto::ImagePreprocessingGraphOptions>(), graph[Input<Image>(kImageTag)], graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph); output_streams.tensors >> graph[Output<std::vector<Tensor>>(kTensorsTag)]; @@ -233,24 +233,25 @@ class ImagePreprocessingSubgraph : public Subgraph { // - the image that has pixel data stored on the target storage // (mediapipe::Image). // - // options: the mediapipe tasks ImagePreprocessingOptions. + // options: the mediapipe tasks ImagePreprocessingGraphOptions. // image_in: (mediapipe::Image) stream to preprocess. // graph: the mediapipe builder::Graph instance to be updated. ImagePreprocessingOutputStreams BuildImagePreprocessing( - const ImagePreprocessingOptions& options, Source<Image> image_in, - Source<NormalizedRect> norm_rect_in, Graph& graph) { + const proto::ImagePreprocessingGraphOptions& options, + Source<Image> image_in, Source<NormalizedRect> norm_rect_in, + Graph& graph) { // Convert image to tensor. auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator"); image_to_tensor.GetOptions<mediapipe::ImageToTensorCalculatorOptions>() .CopyFrom(options.image_to_tensor_options()); switch (options.backend()) { - case ImagePreprocessingOptions::CPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::CPU_BACKEND: { auto cpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/false); cpu_image >> image_to_tensor.In(kImageTag); break; } - case ImagePreprocessingOptions::GPU_BACKEND: { + case proto::ImagePreprocessingGraphOptions::GPU_BACKEND: { auto gpu_image = AddDataConverter(image_in, graph, /*output_on_gpu=*/true); gpu_image >> image_to_tensor.In(kImageTag); @@ -284,8 +285,9 @@ class ImagePreprocessingSubgraph : public Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ImagePreprocessingSubgraph); + ::mediapipe::tasks::components::processors::ImagePreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h similarity index 72% rename from mediapipe/tasks/cc/components/image_preprocessing.h rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h index 6963b6556..455a9b316 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h @@ -13,35 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures an ImagePreprocessing subgraph using the provided model resources +// Configures an ImagePreprocessingGraph using the provided model resources // When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. // // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.ImagePreprocessingGraph"); // core::proto::Acceleration acceleration; // acceleration.mutable_xnnpack(); // bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); -// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( +// MP_RETURN_IF_ERROR(ConfigureImagePreprocessingGraph( // model_resources, // use_gpu, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ImagePreprocessing subgraph has the following I/O: +// The resulting ImagePreprocessingGraph has the following I/O: // Inputs: // IMAGE - Image // The image to preprocess. @@ -61,17 +62,18 @@ namespace components { // IMAGE - Image @Optional // The image that has the pixel data stored on the target storage (CPU vs // GPU). -absl::Status ConfigureImagePreprocessing( +absl::Status ConfigureImagePreprocessingGraph( const core::ModelResources& model_resources, bool use_gpu, - ImagePreprocessingOptions* options); + proto::ImagePreprocessingGraphOptions* options); -// Determine if the image preprocessing subgraph should use GPU as the backend +// Determine if the image preprocessing graph should use GPU as the backend // according to the given acceleration setting. bool DetermineImagePreprocessingGpuBackend( const core::proto::Acceleration& acceleration); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc new file mode 100644 index 000000000..6c094c6bc --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc @@ -0,0 +1,343 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::ModelResources; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::DecodeImageFromFile; +using ::testing::ContainerEq; +using ::testing::HasSubstr; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite"; +constexpr char kMobileNetFloatWithoutMetadata[] = + "mobilenet_v1_0.25_224_1_default_1.tflite"; +constexpr char kMobileNetQuantizedWithMetadata[] = + "mobilenet_v1_0.25_224_quant.tflite"; +constexpr char kMobileNetQuantizedWithoutMetadata[] = + "mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; + +constexpr char kTestImage[] = "burger.jpg"; +constexpr int kTestImageWidth = 480; +constexpr int kTestImageHeight = 325; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; +constexpr std::array kIdentityMatrix = {1, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 1, 0, 0, 0, 0, 1}; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kMatrixName[] = "matrix_out"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kTensorsName[] = "tensors_out"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kImageSizeName[] = "image_size_out"; +constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; +constexpr char kLetterboxPaddingName[] = "letterbox_padding_out"; + +constexpr float kLetterboxMaxAbsError = 1e-5; + +// Helper function to 
get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +// Helper function to create a TaskRunner from ModelResources. +absl::StatusOr> CreateTaskRunner( + const ModelResources& model_resources, bool keep_aspect_ratio) { + Graph graph; + + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + auto& options = + preprocessing.GetOptions(); + options.mutable_image_to_tensor_options()->set_keep_aspect_ratio( + keep_aspect_ratio); + MP_RETURN_IF_ERROR( + ConfigureImagePreprocessingGraph(model_resources, false, &options)); + graph[Input(kImageTag)].SetName(kImageName) >> + preprocessing.In(kImageTag); + preprocessing.Out(kTensorsTag).SetName(kTensorsName) >> + graph[Output>(kTensorsTag)]; + preprocessing.Out(kMatrixTag).SetName(kMatrixName) >> + graph[Output>(kMatrixTag)]; + preprocessing.Out(kImageSizeTag).SetName(kImageSizeName) >> + graph[Output>(kImageSizeTag)]; + preprocessing.Out(kLetterboxPaddingTag).SetName(kLetterboxPaddingName) >> + graph[Output>(kLetterboxPaddingTag)]; + + return TaskRunner::Create(graph.GetConfig()); +} + +class ConfigureTest : public tflite_shims::testing::Test {}; + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 192 + output_tensor_height: 192 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, false, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithQuantizedModelFallbacksCpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + 
ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_uint_range { min: 0 max: 255 } + gpu_origin: TOP_LEFT + } + backend: CPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithMetadata)); + + proto::ImagePreprocessingGraphOptions options; + core::proto::Acceleration acceleration; + acceleration.mutable_gpu(); + bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); + EXPECT_TRUE(use_gpu); + MP_EXPECT_OK( + ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options)); + + EXPECT_THAT(options, EqualsProto( + R"pb(image_to_tensor_options { + output_tensor_width: 224 + output_tensor_height: 224 + output_tensor_float_range { min: -1 max: 1 } + gpu_origin: TOP_LEFT + } + backend: GPU_BACKEND)pb")); +} + +TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileNetFloatWithoutMetadata)); + + proto::ImagePreprocessingGraphOptions options; + auto status = + ConfigureImagePreprocessingGraph(*model_resources, false, &options); + + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_THAT(status.message(), + HasSubstr("requires specifying NormalizationOptions metadata")); +} + +// Struct holding the parameters for parameterized PreprocessingTest class. +struct PreprocessingParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // If true, keep test image aspect ratio. + bool keep_aspect_ratio; + // The expected output tensor type. + Tensor::ElementType expected_type; + // The expected output tensor shape. 
+ std::vector expected_shape; + // The expected output letterbox padding; + std::array expected_letterbox_padding; +}; + +class PreprocessingTest : public testing::TestWithParam {}; + +TEST_P(PreprocessingTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kTestImage))); + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(GetParam().input_model_name)); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, + CreateTaskRunner(*model_resources, GetParam().keep_aspect_ratio)); + + auto output_packets = + task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + MP_ASSERT_OK(output_packets); + + const std::vector& tensors = + (*output_packets)[kTensorsName].Get>(); + EXPECT_EQ(tensors.size(), 1); + EXPECT_EQ(tensors[0].element_type(), GetParam().expected_type); + EXPECT_THAT(tensors[0].shape().dims, ContainerEq(GetParam().expected_shape)); + auto& matrix = (*output_packets)[kMatrixName].Get>(); + if (!GetParam().keep_aspect_ratio) { + for (int i = 0; i < matrix.size(); ++i) { + EXPECT_FLOAT_EQ(matrix[i], kIdentityMatrix[i]); + } + } + auto& image_size = + (*output_packets)[kImageSizeName].Get>(); + EXPECT_EQ(image_size.first, kTestImageWidth); + EXPECT_EQ(image_size.second, kTestImageHeight); + std::array letterbox_padding = + (*output_packets)[kLetterboxPaddingName].Get>(); + for (int i = 0; i < letterbox_padding.size(); ++i) { + EXPECT_NEAR(letterbox_padding[i], GetParam().expected_letterbox_padding[i], + kLetterboxMaxAbsError); + } +} + +INSTANTIATE_TEST_SUITE_P( + PreprocessingTest, PreprocessingTest, + Values( + PreprocessingParams{.test_name = "kMobileNetQuantizedWithMetadata", + .input_model_name = kMobileNetQuantizedWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetQuantizedWithoutMetadata", + .input_model_name = kMobileNetQuantizedWithoutMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 192, 192, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{.test_name = "kMobileNetFloatWithMetadata", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetFloatWithMetadataKeepAspectRatio", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = true, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {/*left*/ 0, + /*top*/ 0.161458, + /*right*/ 0, + /*bottom*/ 0.161458}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 23ebbe008..f48c4bad8 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -49,3 +49,22 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", ], ) + +mediapipe_proto_library( + name = "image_preprocessing_graph_options_proto", + srcs = 
["image_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +mediapipe_proto_library( + name = "text_preprocessing_graph_options_proto", + srcs = ["text_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/image_preprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/components/image_preprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto index d1685c319..bf4fc9067 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto @@ -15,14 +15,14 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto"; import "mediapipe/framework/calculator.proto"; -message ImagePreprocessingOptions { +message ImagePreprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ImagePreprocessingOptions ext = 456882436; + optional ImagePreprocessingGraphOptions ext = 456882436; } // Options for the ImageToTensor calculator encapsulated by the diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto similarity index 96% rename from mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto rename to mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index 926e3d7fb..a67cfd8a9 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc similarity index 94% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.cc rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index 6aad8fdd5..de16375bd 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include @@ -25,13 +25,14 @@ limitations under the License. 
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -41,7 +42,8 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; +using ::mediapipe::tasks::components::processors::proto:: + TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; @@ -169,7 +171,7 @@ absl::StatusOr GetMaxSeqLen(const tflite::SubGraph& model_graph) { } } // namespace -absl::Status ConfigureTextPreprocessingSubgraph( +absl::Status ConfigureTextPreprocessingGraph( const ModelResources& model_resources, TextPreprocessingGraphOptions& options) { if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { @@ -200,8 +202,7 @@ absl::Status ConfigureTextPreprocessingSubgraph( return absl::OkStatus(); } -// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text -// preprocessing. +// A TextPreprocessingGraph performs text preprocessing. // - Accepts a std::string input and outputs CPU tensors. // // Inputs: @@ -216,9 +217,9 @@ absl::Status ConfigureTextPreprocessingSubgraph( // Vector containing the preprocessed input tensors for the TFLite model. // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureTextPreprocessing()' function. See header file for more -// details. -class TextPreprocessingSubgraph : public mediapipe::Subgraph { +// using the 'ConfigureTextPreprocessingGraph()' function. See header file for +// more details. +class TextPreprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -267,8 +268,9 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::TextPreprocessingSubgraph); + ::mediapipe::tasks::components::processors::TextPreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h similarity index 67% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.h rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h index b031a5550..43d57be29 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h @@ -13,26 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" -// Configures a TextPreprocessing subgraph using the provided `model_resources` +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +// Configures a TextPreprocessingGraph using the provided `model_resources` // and TextPreprocessingGraphOptions. // - Accepts a std::string input and outputs CPU tensors. // // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingGraph"); -// MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph( -// model_resources, -// &preprocessing.GetOptions()); +// MP_RETURN_IF_ERROR(ConfigureTextPreprocessingGraph( +// model_resources, +// preprocessing.GetOptions<proto::TextPreprocessingGraphOptions>())); // -// The resulting TextPreprocessing subgraph has the following I/O: +// The resulting TextPreprocessingGraph has the following I/O: // Inputs: // TEXT - std::string // The text to preprocess. @@ -43,16 +48,13 @@ limitations under the License. // Outputs: // TENSORS - std::vector<Tensor> // Vector containing the preprocessed input tensors for the TFLite model. -namespace mediapipe { -namespace tasks { -namespace components { - -absl::Status ConfigureTextPreprocessingSubgraph( - const tasks::core::ModelResources& model_resources, - tasks::components::proto::TextPreprocessingGraphOptions& options); +absl::Status ConfigureTextPreprocessingGraph( + const core::ModelResources& model_resources, + proto::TextPreprocessingGraphOptions& options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 4534a1652..569023753 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -22,12 +22,3 @@ mediapipe_proto_library( name = "segmenter_options_proto", srcs = ["segmenter_options.proto"], ) - -mediapipe_proto_library( - name = "text_preprocessing_graph_options_proto", - srcs = ["text_preprocessing_graph_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 291dd29fe..202f3ea3c 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -22,6 +22,7 @@ cc_library( name = "base_options", srcs = ["base_options.cc"], hdrs = ["base_options.h"], + visibility = ["//visibility:public"], deps = [ ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 618761f32..d5c12ee95 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -99,11 
+99,21 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { +#ifdef __EMSCRIPTEN__ + // In browsers, the model file may require a custom ResourceProviderFn to + // provide the model content. The open() method may not work in this case. + // Thus, loading the model content from the model file path in advance with + // the help of GetResourceContents. + MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); +#else // If the model file name is a relative path, searches the file in a // platform-specific location and returns the absolute path on success. ASSIGN_OR_RETURN(std::string path_to_resource, mediapipe::PathToResourceAsFile(model_file_->file_name())); model_file_->set_file_name(path_to_resource); +#endif // __EMSCRIPTEN__ } ASSIGN_OR_RETURN( model_file_handler_, diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 52b0c0e4b..3c9c3fc0e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -16,35 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "text_classifier_graph", - srcs = ["text_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_resources_calculator", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Text Classifier +# https://developers.google.com/mediapipe/solutions/text/text_classifier cc_library( name = "text_classifier", srcs = ["text_classifier.cc"], hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":text_classifier_graph", "//mediapipe/framework:packet", @@ -65,6 +43,31 @@ cc_library( ], ) +cc_library( + name = "text_classifier_graph", + srcs = ["text_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator_cpu", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + 
"//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_calculator", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + cc_test( name = "text_classifier_test", srcs = ["text_classifier_test.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 36ff68a07..3be92f309 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -25,8 +25,8 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -// TODO: remove once Java API migration is over. -// Struct holding the different output streams produced by the text classifier. -struct TextClassifierOutputStreams { - Source classification_result; - Source classifications; -}; - } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -72,10 +64,6 @@ struct TextClassifierOutputStreams { // Outputs: // CLASSIFICATIONS - ClassificationResult @Optional // The classification results aggregated by classifier head. -// TODO: remove once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). 
// // Example: // node { @@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto output_streams, + auto classifications, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; - output_streams.classifications >> - graph[Output(kClassificationsTag)]; + classifications >> graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -124,18 +109,18 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildTextClassifierTask( + absl::StatusOr> BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. @@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
- return TextClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], - /*classifications=*/postprocessing[Output( - kClassificationsTag)]}; + return postprocessing[Output(kClassificationsTag)]; } }; diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index e2e16c9c1..4c970159e 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -16,10 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Embedder +# https://developers.google.com/mediapipe/solutions/text/text_embedder cc_library( name = "text_embedder", srcs = ["text_embedder.cc"], hdrs = ["text_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_graph", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", @@ -45,17 +48,17 @@ cc_library( name = "text_embedder_graph", srcs = ["text_embedder_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index 79eedb6b5..225ef07bd 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. @@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the embedding result. diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index e8e197a1d..1f5ab5faf 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -19,11 +19,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( name = "image_processing_options", hdrs = ["image_processing_options.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/cc/components/containers:rect", ], diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 75289b1e8..d473a8dc3 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -18,6 +18,51 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Gesture Recognizer +# https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":gesture_recognizer_graph", + ":gesture_recognizer_result", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + 
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + cc_library( name = "handedness_util", srcs = ["handedness_util.cc"], @@ -59,10 +104,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_cache", @@ -127,51 +169,9 @@ cc_library( cc_library( name = "gesture_recognizer_result", hdrs = ["gesture_recognizer_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) - -cc_library( - name = "gesture_recognizer", - srcs = ["gesture_recognizer.cc"], - hdrs = ["gesture_recognizer.h"], - deps = [ - ":gesture_recognizer_graph", - ":gesture_recognizer_result", - ":hand_gesture_recognizer_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", - 
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 8d555b12c..e7fcf6fd9 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 7b6a8c79d..d7e983d81 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -29,8 +29,6 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" -#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 71cef6270..55162d09b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -46,7 +46,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 06bb2e549..c24548c9b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -226,21 +226,23 @@ class HandDetectorGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); auto& image_to_tensor_options = *preprocessing - .GetOptions() + .GetOptions() .mutable_image_to_tensor_options(); image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In("IMAGE"); norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5c5073fc2..03ec45f7d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -18,6 +18,48 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Hand Landmarker +# https://developers.google.com/mediapipe/solutions/vision/hand_landmarker +cc_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.cc"], + hdrs = ["hand_landmarker.h"], + visibility = ["//visibility:public"], + deps = [ + ":hand_landmarker_graph", + ":hand_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + 
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "hand_landmark", + hdrs = ["hand_landmark.h"], + visibility = ["//visibility:public"], +) + cc_library( name = "hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], @@ -52,7 +94,7 @@ cc_library( "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", @@ -113,44 +155,11 @@ cc_library( cc_library( name = "hand_landmarker_result", hdrs = ["hand_landmarker_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) -cc_library( - name = "hand_landmarker", - srcs = ["hand_landmarker.cc"], - hdrs = ["hand_landmarker.h"], - deps = [ - ":hand_landmarker_graph", - ":hand_landmarker_result", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], -) - # TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h new file mode 100644 index 000000000..c8dbc9254 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ + +namespace mediapipe::tasks::vision::hand_landmarker { + +// The 21 hand landmarks. +enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +}; + +} // namespace mediapipe::tasks::vision::hand_landmarker + +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 3a9ed5bc2..2b818b2e5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -22,7 +22,6 @@ limitations under the License. #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 1f127deb8..014830ba2 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -281,14 +281,15 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In("IMAGE"); hand_rect >> preprocessing.In("NORM_RECT"); auto image_size = preprocessing[Output>("IMAGE_SIZE")]; diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index b59d8d682..514e601ef 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_classifier_graph", - srcs = ["image_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Classifier +# https://developers.google.com/mediapipe/solutions/vision/image_classifier cc_library( name = "image_classifier", srcs = ["image_classifier.cc"], hdrs = ["image_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_graph", "//mediapipe/framework:packet", @@ -69,4 +49,27 @@ cc_library( ], ) +cc_library( + name = "image_classifier_graph", + srcs = ["image_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + 
"//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8fa1a0d2a..2d0379c66 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -23,10 +23,10 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS"; // Struct holding the different output streams produced by the image classifier // subgraph. struct ImageClassifierOutputStreams { - Source classification_result; Source classifications; Source image; }; @@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams { // The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. -// TODO: remove this output once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. 
// // Example: // node { @@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -142,14 +135,15 @@ class ImageClassifierGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); @@ -174,8 +168,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. return ImageClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], /*classifications=*/ postprocessing[Output(kClassificationsTag)], /*image=*/preprocessing[Output(kImageTag)]}; diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index ea7f40261..d729eaf1a 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_embedder_graph", - srcs = ["image_embedder_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Embedder +# https://developers.google.com/mediapipe/solutions/vision/image_embedder cc_library( name = "image_embedder", srcs = ["image_embedder.cc"], hdrs = ["image_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_graph", "//mediapipe/framework/api2:builder", @@ -67,4 +47,27 @@ cc_library( ], ) +cc_library( + name = 
"image_embedder_graph", + srcs = ["image_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index 11e25144c..81ccb5361 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -20,10 +20,10 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -130,14 +130,15 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); @@ -151,10 +152,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { // inference results. 
auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the embedding results. diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 4adba5ab7..72b3e7ee3 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto"; +option java_outer_classname = "ImageEmbedderGraphOptionsProto"; + message ImageEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional ImageEmbedderGraphOptions ext = 476348187; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 4c43a07f5..2124fe6e0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,10 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Image Segmenter +# https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], + visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", "//mediapipe/framework/api2:builder", @@ -53,10 +56,10 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 43bf5b7e6..511d3b9c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -98,7 +98,7 @@ struct ImageSegmenterOptions { // - list of segmented masks. // - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. // - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `cahnnels`. +// `channels`. 
// - batch is always 1 // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 44742e043..d5eb5af0d 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -27,8 +27,8 @@ limitations under the License. #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -243,14 +243,15 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions<tasks::components::proto::ImagePreprocessingOptions>())); + &preprocessing.GetOptions<components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 8220d8b7f..c2dd9995d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -16,50 +16,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "object_detector_graph", - srcs = ["object_detector_graph.cc"], - deps = [ - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", - "//mediapipe/calculators/util:detection_projection_calculator", - "//mediapipe/calculators/util:detection_transformation_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - 
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Object Detector +# https://developers.google.com/mediapipe/solutions/vision/object_detector cc_library( name = "object_detector", srcs = ["object_detector.cc"], hdrs = ["object_detector.h"], + visibility = [ + "//mediapipe/tasks:users", + ], deps = [ ":object_detector_graph", "//mediapipe/calculators/core:concatenate_vector_calculator", @@ -86,4 +51,44 @@ cc_library( ], ) +cc_library( + name = "object_detector_graph", + srcs = ["object_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator", + "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", + "//mediapipe/calculators/util:detection_projection_calculator", + "//mediapipe/calculators/util:detection_transformation_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:label_map_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index b149cea0f..f5dc7e061 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc 
@@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -561,14 +561,15 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions<tasks::components::proto::ImagePreprocessingOptions>())); + &preprocessing.GetOptions<components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 6771335ad..2d29ccf23 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -66,10 +66,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", @@ -92,12 +92,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 0f3374175..d78685fe3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -266,7 +266,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /* * Sends audio data (a block in a continuous audio stream) to perform audio classification, and - * the results will be available via the {@link ResultListener} provided in the + * the results will be available via the {@link ResultListener} provided in the * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with * the audio stream mode. * @@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *

If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *

Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /** * Validates and builds the {@link AudioClassifierOptions} instance. * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the audio classifier - * is in the audio stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final AudioClassifierOptions build() { AudioClassifierOptions options = autoBuild(); @@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi { "The audio classifier is in the audio clips mode, a user-defined result listener" + " shouldn't be provided in AudioClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index c0bc04a4e..4bc505d84 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -309,10 +309,24 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score - * threshold, number of results, etc. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native <code>L2_NORMALIZATION</code> TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + * <p>False by default. + */ - public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + * <p>False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -354,7 +368,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional<EmbedderOptions> embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional<ResultListener<AudioEmbedderResult, AudioData>> resultListener(); @@ -362,7 +378,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -372,12 +390,14 @@ public final class AudioEmbedder extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java index ee4df0198..a986048f0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java @@ -65,8 +65,8 @@ public abstract class AudioEmbedderResult implements TaskResult { /** * Contains one set of results per classifier head. A {@link EmbeddingResult} usually represents - * one audio embedding result in an audio stream, and s only available when running with the audio - * stream mode. + * one audio embedding result in an audio stream, and is only available when running with the + * audio stream mode. */ public abstract Optional<EmbeddingResult> embeddingResult(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 8eaf0adcb..2782f8d36 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable { defaultSampleRate = sampleRate; } } + /** * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener.
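[Editor's note, not part of the patch: a minimal Java sketch of how the classifier options inlined above are meant to be used. BaseOptions.builder(), setModelAssetPath(), createFromOptions(), and the model path are assumptions drawn from the surrounding Tasks API and may differ.]

// Hypothetical usage of the new AudioClassifierOptions builder surface; each
// setter below maps onto the ClassifierOptions proto assembled in
// convertToProto() above. "context" is assumed to be an android.content.Context.
AudioClassifierOptions options =
    AudioClassifierOptions.builder()
        .setBaseOptions(
            BaseOptions.builder().setModelAssetPath("classifier.tflite").build())
        .setMaxResults(5)         // build() throws unless > 0
        .setScoreThreshold(0.3f)  // overrides any threshold from model metadata
        .setCategoryAllowlist(java.util.Arrays.asList("Speech"))  // exclusive with denylist
        .build();
AudioClassifier classifier = AudioClassifier.createFromOptions(context, options);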
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java index f0a123810..a778eae46 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java @@ -20,7 +20,7 @@ package com.google.mediapipe.tasks.audio.core; * * <ul> *   <li>AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips. *   <li>AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from - *       microphone. + *       a microphone. * </ul> */ public enum RunningMode { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..ad17d5552 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -83,6 +83,15 @@ android_library( ], ) +android_library( + name = "normalized_landmark", + srcs = ["NormalizedLandmark.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index e45866190..7fb1b99d0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -18,16 +18,16 @@ import com.google.auto.value.AutoValue; import java.util.Objects; /** - * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the - * landmark coordinates is normalized respect to the dimension of image, and the coordinates values - * are in the range of [0,1]. Otherwise, it represenet a point in world coordinates. + * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in + * meters. z represents the landmark depth, and the smaller the value the closer the world landmark + * is to the camera. */ @AutoValue public abstract class Landmark { private static final float TOLERANCE = 1e-6f; - public static Landmark create(float x, float y, float z, boolean normalized) { - return new AutoValue_Landmark(x, y, z, normalized); + public static Landmark create(float x, float y, float z) { + return new AutoValue_Landmark(x, y, z); } // The x coordinates of the landmark. @@ -39,28 +39,24 @@ public abstract class Landmark { // The z coordinates of the landmark. public abstract float z(); - // Whether this landmark is normalized with respect to the image size. - public abstract boolean normalized(); - @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { return false; } Landmark other = (Landmark) o; - return other.normalized() == this.normalized() - && Math.abs(other.x() - this.x()) < TOLERANCE + return Math.abs(other.x() - this.x()) < TOLERANCE && Math.abs(other.x() - this.y()) < TOLERANCE && Math.abs(other.x() - this.z()) < TOLERANCE; } @Override public final int hashCode() { - return Objects.hash(x(), y(), z(), normalized()); + return Objects.hash(x(), y(), z()); } @Override public final String toString() { - return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + " normalized=" + normalized() + ")>"; + return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>"; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java new file mode 100644 index 000000000..e77f3c3d4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -0,0 +1,63 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are + * normalized to [0.0, 1.0] by the image width and height respectively. z represents the landmark + * depth, and the smaller the value the closer the landmark is to the camera. The magnitude of z + * uses roughly the same scale as x. + */ +@AutoValue +public abstract class NormalizedLandmark { + private static final float TOLERANCE = 1e-6f; + + public static NormalizedLandmark create(float x, float y, float z) { + return new AutoValue_NormalizedLandmark(x, y, z); + } + + // The x coordinates of the normalized landmark. + public abstract float x(); + + // The y coordinates of the normalized landmark. + public abstract float y(); + + // The z coordinates of the normalized landmark. + public abstract float z(); + + @Override + public final boolean equals(Object o) { + if (!(o instanceof NormalizedLandmark)) { + return false; + } + NormalizedLandmark other = (NormalizedLandmark) o; + return Math.abs(other.x() - this.x()) < TOLERANCE + && Math.abs(other.x() - this.y()) < TOLERANCE + && Math.abs(other.x() - this.z()) < TOLERANCE; + } + + @Override + public final int hashCode() { + return Objects.hash(x(), y(), z()); + } + + @Override + public final String toString() { + return ""; + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index e61e59390..1f99f1612 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -29,19 +29,6 @@ android_library( ], ) -android_library( - name = "embedderoptions", - srcs = ["EmbedderOptions.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) - # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java deleted file mode 100644 index 3cd197234..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.processors; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; - -/** Embedder options shared across MediaPipe Java embedding tasks. */ -@AutoValue -public abstract class EmbedderOptions { - - /** Builder for {@link EmbedderOptions} */ - @AutoValue.Builder - public abstract static class Builder { - /** - * Sets whether L2 normalization should be performed on the returned embeddings. Use this option - * only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. - * In most cases, this is already the case and L2 norm is thus achieved through TF Lite - * inference. - * - *
False by default. - */ - public abstract Builder setL2Normalize(boolean l2Normalize); - - /** - * Sets whether the returned embedding should be quantized to bytes via scalar quantization. - * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed - * to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} if this is - * not the case. - * - *
False by default. - */ - public abstract Builder setQuantize(boolean quantize); - - public abstract EmbedderOptions build(); - } - - public abstract boolean l2Normalize(); - - public abstract boolean quantize(); - - public static Builder builder() { - return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false); - } - - /** - * Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions} - * protobuf message. - */ - public EmbedderOptionsProto.EmbedderOptions convertToProto() { - return EmbedderOptionsProto.EmbedderOptions.newBuilder() - .setL2Normalize(l2Normalize()) - .setQuantize(quantize()) - .build(); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 2b648bc43..d91c03cc2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -18,7 +18,6 @@ load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build load("@build_bazel_rules_android//android:rules.bzl", "android_library") _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", @@ -42,6 +41,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", @@ -286,9 +286,9 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", ] + select({ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 0e72878ab..5b10e9aab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -48,12 +48,11 @@ android_library( deps = [ "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", - 
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//third_party:autovalue", @@ -75,11 +74,11 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 341d6bf91..0ea91a9f8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.TaskInfo; @@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); - public abstract TextClassifierOptions build(); + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
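Taken together, the new setters replace the removed ClassifierOptions wrapper directly on TextClassifierOptions. A minimal sketch of how they compose; the model path and values below are placeholders, not part of this change:

BaseOptions baseOptions =
    BaseOptions.builder().setModelAssetPath("bert_text_classifier.tflite").build(); // placeholder path
TextClassifierOptions options =
    TextClassifierOptions.builder()
        .setBaseOptions(baseOptions)
        .setMaxResults(3) // build() enforces > 0 when set
        .setScoreThreshold(0.5f) // overrides any threshold in the model metadata
        .build();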
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract TextClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link TextClassifierOptions} instance. + * + * @throws IllegalArgumentException if any of the set options are invalid. + */ + public final TextClassifierOptions build() { + TextClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } + return options; + } } abstract BaseOptions baseOptions(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); public static Builder builder() { - return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + return new AutoValue_TextClassifier_TextClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 95fa1f087..9b464d0e8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -25,7 +25,7 @@ import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.Embedding; import 
com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; @@ -41,7 +41,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; /** * Performs embedding extraction on text. @@ -218,20 +217,38 @@ public final class TextEmbedder implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
False by default. + */ - public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions); + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *
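With EmbedderOptions removed, both embedder flags now sit directly on TextEmbedderOptions. A minimal sketch, assuming a placeholder model path:

TextEmbedderOptions options =
    TextEmbedderOptions.builder()
        .setBaseOptions(
            BaseOptions.builder().setModelAssetPath("text_embedder.tflite").build()) // placeholder path
        .setL2Normalize(true) // both flags default to false
        .setQuantize(false)
        .build();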
False by default. + */ + public abstract Builder setQuantize(boolean quantize); public abstract TextEmbedderOptions build(); } abstract BaseOptions baseOptions(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); public static Builder builder() { - return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder(); + return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder() + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -240,12 +257,14 @@ public final class TextEmbedder implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 4dc4a547e..6161fe032 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -43,6 +43,7 @@ cc_binary( "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -96,12 +97,11 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", @@ -135,6 +135,7 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", @@ -145,6 +146,7 @@ android_library( android_library( name = "handlandmarker", srcs = [ + "handlandmarker/HandLandmark.java", "handlandmarker/HandLandmarker.java", "handlandmarker/HandLandmarkerResult.java", ], @@ -166,6 +168,36 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "imageembedder", + srcs = [ + "imageembedder/ImageEmbedder.java", + "imageembedder/ImageEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imageembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java index ef76bf226..90b92175d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java @@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import 
com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -43,41 +42,36 @@ public abstract class GestureRecognizerResult implements TaskResult { * @param gesturesProto a List of {@link ClassificationList} */ static GestureRecognizerResult create( - List landmarksProto, - List worldLandmarksProto, + List landmarksProto, + List worldLandmarksProto, List handednessesProto, List gesturesProto, long timestampMs) { - List> multiHandLandmarks = - new ArrayList<>(); - List> multiHandWorldLandmarks = - new ArrayList<>(); + List> multiHandLandmarks = new ArrayList<>(); + List> multiHandWorldLandmarks = new ArrayList<>(); List> multiHandHandednesses = new ArrayList<>(); List> multiHandGestures = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + com.google.mediapipe.tasks.components.containers.NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { List handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) { @@ -118,11 +112,10 @@ public abstract class GestureRecognizerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List> landmarks(); + public abstract List> landmarks(); /** Hand landmarks in world coordniates of detected hands. */ - public abstract List> - worldLandmarks(); + public abstract List> worldLandmarks(); /** Handedness of detected hands. */ public abstract List> handednesses(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java new file mode 100644 index 000000000..7b21ebddf --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java @@ -0,0 +1,72 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.handlandmarker; + +import androidx.annotation.IntDef; + +/** The 21 hand landmarks. */ +public final class HandLandmark { + public static final int NUM_LANDMARKS = 21; + + public static final int WRIST = 0; + public static final int THUMB_CMC = 1; + public static final int THUMB_MCP = 2; + public static final int THUMB_IP = 3; + public static final int THUMB_TIP = 4; + public static final int INDEX_FINGER_MCP = 5; + public static final int INDEX_FINGER_PIP = 6; + public static final int INDEX_FINGER_DIP = 7; + public static final int INDEX_FINGER_TIP = 8; + public static final int MIDDLE_FINGER_MCP = 9; + public static final int MIDDLE_FINGER_PIP = 10; + public static final int MIDDLE_FINGER_DIP = 11; + public static final int MIDDLE_FINGER_TIP = 12; + public static final int RING_FINGER_MCP = 13; + public static final int RING_FINGER_PIP = 14; + public static final int RING_FINGER_DIP = 15; + public static final int RING_FINGER_TIP = 16; + public static final int PINKY_MCP = 17; + public static final int PINKY_PIP = 18; + public static final int PINKY_DIP = 19; + public static final int PINKY_TIP = 20; + + /** Represents a hand landmark type. */ + @IntDef({ + WRIST, + THUMB_CMC, + THUMB_MCP, + THUMB_IP, + THUMB_TIP, + INDEX_FINGER_MCP, + INDEX_FINGER_PIP, + INDEX_FINGER_DIP, + INDEX_FINGER_TIP, + MIDDLE_FINGER_MCP, + MIDDLE_FINGER_PIP, + MIDDLE_FINGER_DIP, + MIDDLE_FINGER_TIP, + RING_FINGER_MCP, + RING_FINGER_PIP, + RING_FINGER_DIP, + RING_FINGER_TIP, + PINKY_MCP, + PINKY_PIP, + PINKY_DIP, + PINKY_TIP, + }) + public @interface HandLandmarkType {} + + private HandLandmark() {} +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 2889b0e0b..9092c0a2d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.handlandmarker; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -32,47 +31,41 @@ import java.util.List; public abstract class 
HandLandmarkerResult implements TaskResult { /** - * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and - * handedness protobuf messages. + * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness + * protobuf messages. * * @param landmarksProto a List of {@link NormalizedLandmarkList} * @param worldLandmarksProto a List of {@link LandmarkList} * @param handednessesProto a List of {@link ClassificationList} */ static HandLandmarkerResult create( - List landmarksProto, - List worldLandmarksProto, + List landmarksProto, + List worldLandmarksProto, List handednessesProto, long timestampMs) { - List> multiHandLandmarks = - new ArrayList<>(); - List> multiHandWorldLandmarks = - new ArrayList<>(); + List> multiHandLandmarks = new ArrayList<>(); + List> multiHandWorldLandmarks = new ArrayList<>(); List> multiHandHandednesses = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { - List handWorldLandmarks = - new ArrayList<>(); + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + List handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) { @@ -98,11 +91,10 @@ public abstract class HandLandmarkerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List> landmarks(); + public abstract List> landmarks(); /** Hand landmarks in world coordniates of detected hands. */ - public abstract List> - worldLandmarks(); + public abstract List> worldLandmarks(); /** Handedness of detected hands. 
*/ public abstract List> handednesses(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 5e278804b..8990f46fd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
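The HandLandmark constants added above are plain int indices into each per-hand list returned by HandLandmarkerResult. A hypothetical lookup, assuming a result with at least one detected hand:

List<NormalizedLandmark> firstHand = handLandmarkerResult.landmarks().get(0); // hypothetical result variable
NormalizedLandmark wrist = firstHand.get(HandLandmark.WRIST);
NormalizedLandmark indexTip = firstHand.get(HandLandmark.INDEX_FINGER_TIP);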
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { /** * Validates and builds the {@link ImageClassifierOptions} instance. * * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the image classifier - * is in the live stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final ImageClassifierOptions build() { ImageClassifierOptions options = autoBuild(); @@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { "The image classifier is in the image or video mode, a user-defined result listener" + " shouldn't be provided in ImageClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..ebdb037d6 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java new file mode 100644 index 000000000..af053d860 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -0,0 +1,468 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs embedding extraction on images. + * + *
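The ImageClassifier options gain the same direct setters as TextClassifier above. A sketch with a placeholder model path and illustrative category names, assuming java.util.Arrays is imported:

ImageClassifierOptions options =
    ImageClassifierOptions.builder()
        .setBaseOptions(
            BaseOptions.builder().setModelAssetPath("image_classifier.tflite").build()) // placeholder path
        .setRunningMode(RunningMode.IMAGE)
        .setCategoryAllowlist(Arrays.asList("cat", "dog")) // illustrative categories
        .build(); // build() would throw if a denylist were also supplied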
The API expects a TFLite model with optional, but strongly recommended, TFLite Model Metadata. + * + *
The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + *
+ *   • Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ *       • image input of size {@code [batch x height x width x channels]}.
+ *       • batch inference is not supported ({@code batch} is required to be 1).
+ *       • only RGB inputs are supported ({@code channels} is required to be 3).
+ *       • if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
+ *         metadata for input normalization.
+ *   • At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with shape
+ *     {@code [1 x N]} where N is the number of dimensions in the produced embeddings.
+ */ +public final class ImageEmbedder extends BaseVisionTaskApi { + private static final String TAG = ImageEmbedder.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out")); + private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; + + static { + ProtoUtil.registerTypeName( + EmbeddingsProto.EmbeddingResult.class, + "mediapipe.tasks.components.containers.proto.EmbeddingResult"); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the embedding model in the assets. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from a model file and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the embedding model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ImageEmbedder} instance from a model buffer and default {@link + * ImageEmbedderOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding + * model. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. + */ + public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance. + * + * @param context an Android {@link Context}. + * @param options an {@link ImageEmbedderOptions} instance. + * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation. 
+ */ + public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageEmbedderResult convertToTaskResult(List packets) { + try { + return ImageEmbedderResult.create( + EmbeddingResult.createFromProto( + PacketGetter.getProto( + packets.get(EMBEDDINGS_OUT_STREAM_INDEX), + EmbeddingsProto.EmbeddingResult.getDefaultInstance())), + BaseVisionTaskApi.generateResultTimestampMs( + options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX))); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + options.resultListener().ifPresent(handler::setResultListener); + options.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new ImageEmbedder(runner, options.runningMode()); + } + + /** + * Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs embedding extraction on the provided single image with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *
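Combining createFromFile, embed, and cosineSimilarity from this file gives a minimal image-to-image similarity sketch; the model name and MPImage inputs are placeholders, and an embeddings() accessor on EmbeddingResult is assumed here:

ImageEmbedder embedder = ImageEmbedder.createFromFile(context, "mobilenet_v3_embedder.tflite"); // placeholder asset
ImageEmbedderResult first = embedder.embed(firstImage);
ImageEmbedderResult second = embedder.embed(secondImage);
double similarity =
    ImageEmbedder.cosineSimilarity(
        first.embeddingResult().embeddings().get(0), // embeddings() accessor assumed
        second.embeddingResult().embeddings().get(0));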
{@link ImageEmbedder} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image) { + return embed(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided single image. Only use this method when the + * {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *
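A sketch of this overload with a non-default ImageProcessingOptions; setRotationDegrees is assumed from that class and is not part of this diff:

ImageProcessingOptions processingOptions =
    ImageProcessingOptions.builder().setRotationDegrees(90).build(); // assumed builder method
ImageEmbedderResult result = embedder.embed(mpImage, processingOptions);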
{@link ImageEmbedder} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageEmbedderResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs embedding extraction on the provided video frame with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *
It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) { + return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs embedding extraction on the provided video frame. Only use this method when the {@link + * ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *
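In video mode the caller drives the clock. A sketch, assuming a hypothetical frame decoder and roughly 30 fps timestamps:

for (int i = 0; i < frameCount; i++) {
  MPImage frame = decodeFrame(i); // hypothetical decoder
  ImageEmbedderResult result = embedder.embedForVideo(frame, i * 33L); // timestamps must increase
}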
It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform embedding extraction with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *
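For LIVE_STREAM mode the listener is wired through ImageEmbedderOptions before embedAsync is called. A sketch, assuming ResultListener is usable as a lambda and that baseOptions and cameraFrame are defined elsewhere:

ImageEmbedderOptions options =
    ImageEmbedderOptions.builder()
        .setBaseOptions(baseOptions)
        .setRunningMode(RunningMode.LIVE_STREAM)
        .setResultListener(
            (result, inputImage) -> {
              // Consume each ImageEmbedderResult as frames arrive.
            })
        .build();
ImageEmbedder embedder = ImageEmbedder.createFromOptions(context, options);
embedder.embedAsync(cameraFrame, frameTimestampMs); // timestamps must be monotonically increasing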
It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image embedder. The input timestamps must be monotonically increasing. + * + *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync(MPImage image, long timestampMs) { + embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform embedding extraction, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method + * when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}. + * + *
It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the image embedder. The input timestamps must be monotonically increasing. + * + *
{@link ImageEmbedder} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** + * Utility function to compute cosine + * similarity between two {@link Embedding} objects. + * + * @throws IllegalArgumentException if the embeddings are of different types (float vs. + * quantized), have different sizes, or have an L2-norm of 0. + */ + public static double cosineSimilarity(Embedding u, Embedding v) { + return CosineSimilarity.compute(u, v); + } + + /** Options for setting up an {@link ImageEmbedder}. */ + @AutoValue + public abstract static class ImageEmbedderOptions extends TaskOptions { + + /** Builder for {@link ImageEmbedderOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the image embedder task. */ + public abstract Builder setBaseOptions(BaseOptions baseOptions); + + /** + * Sets the {@link RunningMode} for the image embedder task. Defaults to the image mode. Image + * embedder has three modes: + * + *
+ *   • IMAGE: The mode for performing embedding extraction on single image inputs.
+ *   • VIDEO: The mode for performing embedding extraction on the decoded frames of a video.
+ *   • LIVE_STREAM: The mode for performing embedding extraction on a live stream of input
+ *     data, such as from camera. In this mode, {@code setResultListener} must be called to
+ *     set up a listener to receive the embedding results asynchronously.
+ */ + public abstract Builder setRunningMode(RunningMode runningMode); + + /** + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *
False by default. + */ + public abstract Builder setL2Normalize(boolean l2Normalize); + + /** + * Sets whether the returned embedding should be quantized to bytes via scalar quantization. + * Embeddings are implicitly assumed to be unit-norm and therefore any dimension is + * guaranteed to have a value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} + * if this is not the case. + * + *
False by default. + */ + public abstract Builder setQuantize(boolean quantize); + + /** + * Sets the {@link ResultListener} to receive the embedding results asynchronously when the + * image embedder is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener resultListener); + + /** Sets an optional {@link ErrorListener}. */ + public abstract Builder setErrorListener(ErrorListener errorListener); + + abstract ImageEmbedderOptions autoBuild(); + + /** + * Validates and builds the {@link ImageEmbedderOptions} instance. * + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the image embedder is + * in the live stream mode. + */ + public final ImageEmbedderOptions build() { + ImageEmbedderOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the live stream mode, a user-defined result listener" + + " must be provided in the ImageEmbedderOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The image embedder is in the image or video mode, a user-defined result listener" + + " shouldn't be provided in ImageEmbedderOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract boolean l2Normalize(); + + abstract boolean quantize(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setL2Normalize(false) + .setQuantize(false); + } + + /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java new file mode 100644 index 000000000..ee3f4abc9 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java @@ -0,0 +1,54 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link ImageEmbedder}. */ +@AutoValue +public abstract class ImageEmbedderResult implements TaskResult { + + /** + * Creates an {@link ImageEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. 
*/ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index 5e03d2a4c..5ed413f6a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -40,6 +40,37 @@ public class TextClassifierTest { private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index c0be4cffe..5821b36cc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -28,7 +28,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -603,7 +603,7 @@ public class GestureRecognizerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate<Landmark, Landmark>) + (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x()) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java index 
9e12d210f..c313d385d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -399,7 +399,7 @@ public class HandLandmarkerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate<Landmark, Landmark>) + (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x()) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 69820ce2d..dac11bf02 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -55,6 +54,37 @@ public class ImageClassifierTest { @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -105,7 +135,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() 
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -125,7 +155,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -141,7 +171,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .setScoreThreshold(0.02f) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -160,10 +190,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) - .build()) + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -183,11 +210,8 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setMaxResults(3) - .setCategoryDenylist(Arrays.asList("bagel")) - .build()) + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -207,7 +231,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -228,7 +252,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -251,7 +275,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = 
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -322,14 +346,14 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -353,7 +377,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -379,7 +403,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -388,7 +412,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -405,13 +429,14 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.VIDEO) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + ImageClassifierResult results = + imageClassifier.classifyForVideo(image, /* timestampMs= */ i); assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -424,7 +449,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -436,11 +461,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); + () 
-> imageClassifier.classifyAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -453,7 +478,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -466,7 +491,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, /*timestampMs=*/ i); + imageClassifier.classifyAsync(image, /* timestampMs= */ i); } } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..db303a439 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java new file mode 100644 index 000000000..8dec6f80b --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -0,0 +1,441 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageEmbedder}. */ +@RunWith(Suite.class) +@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class}) +public class ImageEmbedderTest { + private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_CROP_IMAGE = "burger_crop.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageEmbedderTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. + ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedder.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void embed_succeedsWithNoOptions() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithL2Normalization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithQuantization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776); + } + + @Test + public void embed_succeedsWithRegionOfInterest() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); + ImageEmbedderResult resultRoi = + imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoi.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f); + } + + @Test + public void embed_succeedsWithRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageEmbedderResult resultRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultRotated.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); + } + + @Test + public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageEmbedderResult resultRoiRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. 
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoiRotated.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageEmbedderTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(mode) + .setResultListener((result, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void embed_failsWithCallingWrongApiInImageMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + 
.setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void embed_succeedsWithImageMode() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithVideoMode() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + for (int i = 0; i < 3; ++i) { + ImageEmbedderResult result = + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + } + } + + @Test + public void embed_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void embed_succeedsWithLiveStreamMode() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + 
.setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageEmbedder.embedAsync(image, /* timestampMs= */ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndCorrectDimension( + ImageEmbedderResult result, boolean quantized) { + assertThat(result.embeddingResult().embeddings()).hasSize(1); + assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0); + assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature"); + if (quantized) { + assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024); + } else { + assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024); + } + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +} diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index 7955cc4dc..d82b6fe27 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -70,7 +70,8 @@ class AudioClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index a774d71e9..629e21882 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -71,7 +71,8 @@ class AudioEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=_EmbedderOptions) result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d931c26c7..9d275e167 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -68,7 +68,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", + "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py 
b/mediapipe/tasks/python/components/containers/category.py index cfdb83740..9b5419883 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,10 +16,10 @@ import dataclasses from typing import Any, Optional -from mediapipe.tasks.cc.components.containers.proto import category_pb2 +from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls -_CategoryProto = category_pb2.Category +_ClassificationProto = classification_pb2.Classification @dataclasses.dataclass @@ -45,23 +45,23 @@ class Category: category_name: Optional[str] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _CategoryProto: + def to_pb2(self) -> _ClassificationProto: """Generates a Category protobuf object.""" - return _CategoryProto( + return _ClassificationProto( index=self.index, score=self.score, - display_name=self.display_name, - category_name=self.category_name) + label=self.category_name, + display_name=self.display_name) @classmethod @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category': + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category': """Creates a `Category` object from the given protobuf object.""" return Category( index=pb2_obj.index, score=pb2_obj.score, display_name=pb2_obj.display_name, - category_name=pb2_obj.category_name) + category_name=pb2_obj.label) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index 6ffdabe51..000468041 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -49,11 +49,7 @@ class Classifications: """Generates a Classifications protobuf object.""" classification_list_proto = _ClassificationListProto() for category in self.categories: - classification_proto = _ClassificationProto( - index=category.index, - score=category.score, - label=category.category_name, - display_name=category.display_name) + classification_proto = category.to_pb2() classification_list_proto.classification.append(classification_proto) return _ClassificationsProto( classification_list=classification_list_proto, @@ -65,14 +61,9 @@ class Classifications: def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': """Creates a `Classifications` object from the given protobuf object.""" categories = [] - for entry in pb2_obj.classification_list.classification: + for classification in pb2_obj.classification_list.classification: categories.append( - category_module.Category( - index=entry.index, - score=entry.score, - display_name=entry.display_name, - category_name=entry.label)) - + category_module.Category.create_from_pb2(classification)) return Classifications( categories=categories, head_index=pb2_obj.head_index, diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 76e2f4f4a..fc0018ab1 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -31,6 +31,7 @@ py_library( py_library( name = "base_options", srcs = ["base_options.py"], + visibility = ["//mediapipe/tasks:users"], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git 
a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 122dc620f..b48fa2ccc 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py @@ -14,6 +14,7 @@ """Base options for MediaPipe Task APIs.""" import dataclasses +import os from typing import Any, Optional from mediapipe.tasks.cc.core.proto import base_options_pb2 @@ -49,10 +50,14 @@ class BaseOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _BaseOptionsProto: """Generates a BaseOptions protobuf object.""" + if self.model_asset_path is not None: + full_path = os.path.abspath(self.model_asset_path) + else: + full_path = None + return _BaseOptionsProto( model_asset=_ExternalFileProto( - file_name=self.model_asset_path, - file_content=self.model_asset_buffer)) + file_name=full_path, file_content=self.model_asset_buffer)) @classmethod @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index bb42da912..10b4b8a6e 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -23,6 +23,7 @@ py_library( srcs = [ "text_classifier.py", ], + visibility = ["//mediapipe/tasks:users"], deps = [ "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 92d547f20..9711e8b3a 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,6 +14,7 @@ """MediaPipe text classifier task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -48,7 +49,8 @@ class TextClassifierOptions: classifier_options: Options for the text classification task. """ base_options: _BaseOptions - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index f3e5eecbe..a9e560ac9 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -14,6 +14,7 @@ """MediaPipe text embedder task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -49,7 +50,8 @@ class TextEmbedderOptions: embedder_options: Options for the text embedder task. 
""" base_options: _BaseOptions - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=_EmbedderOptions) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 9b6fd8cab..227203a0d 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -181,9 +181,11 @@ class GestureRecognizerOptions: min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 canned_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) custom_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) result_callback: Optional[Callable[ [GestureRecognizerResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index 3367f1da7..a0cd99a83 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -14,6 +14,7 @@ """MediaPipe hand landmarker task.""" import dataclasses +import enum from typing import Callable, Mapping, Optional, List from mediapipe.framework.formats import classification_pb2 @@ -53,6 +54,31 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 + + @dataclasses.dataclass class HandLandmarkerResult: """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image. 
diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 763160e1e..6cbce7860 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -70,7 +70,8 @@ class ImageClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=_ClassifierOptions) result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index f299fa590..a58dca3ae 100644 --- a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -69,7 +69,8 @@ class ImageEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=_EmbedderOptions) result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 9ef911f75..62fc8bb7c 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -110,7 +110,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): - list of segmented masks. - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - if `output_type` is CONFIDENCE_MASK, float32 Image list of size - `cahnnels`. + `channels`. 
- batch is always 1 An example of such model can be found at: diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index b8777e785..20e717433 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -13,10 +13,16 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_files(srcs = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ]) # Audio @@ -28,31 +34,19 @@ mediapipe_ts_library( ) rollup_bundle( - name = "audio_cjs_bundle", - config_file = "rollup.config.cjs.mjs", + name = "audio_bundle", + config_file = "rollup.config.mjs", entry_point = "audio.ts", - format = "cjs", - output_dir = False, - deps = [ - ":audio_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", - ], -) - -rollup_bundle( - name = "audio_iife_bundle", - config_file = "rollup.config.iife.mjs", - entry_point = "audio.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -69,8 +63,9 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", - ":audio_cjs_bundle", - ":audio_iife_bundle", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", + ":audio_bundle", ], ) @@ -83,31 +78,19 @@ mediapipe_ts_library( ) rollup_bundle( - name = "text_cjs_bundle", - config_file = "rollup.config.cjs.mjs", + name = "text_bundle", + config_file = "rollup.config.mjs", entry_point = "text.ts", - format = "cjs", - output_dir = False, - deps = [ - ":text_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", - ], -) - -rollup_bundle( - name = "text_iife_bundle", - config_file = "rollup.config.iife.mjs", - entry_point = "text.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -124,8 +107,9 @@ pkg_npm( deps = [ "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", - ":text_cjs_bundle", - ":text_iife_bundle", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", + ":text_bundle", ], ) @@ -138,31 +122,19 @@ mediapipe_ts_library( ) rollup_bundle( - name = "vision_cjs_bundle", - config_file = "rollup.config.cjs.mjs", + name = "vision_bundle", + config_file = "rollup.config.mjs", entry_point = "vision.ts", - format = "cjs", - output_dir = False, - deps = [ - ":vision_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", - ], -) - -rollup_bundle( - name = "vision_iife_bundle", - config_file = "rollup.config.iife.mjs", - entry_point = "vision.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":vision_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", 
"@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -179,7 +151,8 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", - ":vision_cjs_bundle", - ":vision_iife_bundle", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", + ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 4a3b80594..2f4fb0315 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,4 +14,12 @@ * limitations under the License. */ -export * from '../../tasks/web/audio/index'; +import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl, FilesetResolver as FilesetResolverImpl} from '../../tasks/web/audio/index'; + +// Declare the variables locally so that Rollup in OSS includes them explcilty +// as exports. +const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; + +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 4f6e48b28..d08602521 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -9,5 +9,7 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 1bc4af309..6f785dd0d 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,7 +2,7 @@ # # This task takes audio data and outputs the classification result. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,24 +10,35 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", - srcs = [ - "audio_classifier.ts", - "audio_classifier_options.ts", - "audio_classifier_result.ts", - ], + srcs = ["audio_classifier.ts"], deps = [ + ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_classifier_types", + srcs = [ + "audio_classifier_options.d.ts", + "audio_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index e3700cd7a..265ba2b33 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -18,25 +18,24 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} 
from './audio_classifier_options'; import {AudioClassifierResult} from './audio_classifier_result'; +export * from './audio_classifier_options'; +export * from './audio_classifier_result'; + const MEDIAPIPE_GRAPH = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; -// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and -// cannot be changed -// TODO: Change this to `audio_in` to match the name in the CC -// implementation -const AUDIO_STREAM = 'input_audio'; +const AUDIO_STREAM = 'audio_in'; const SAMPLE_RATE_STREAM = 'sample_rate'; const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; @@ -44,68 +43,70 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ -export class AudioClassifier extends TaskRunner { +export class AudioClassifier extends AudioTaskRunner { private classificationResults: AudioClassifierResult[] = []; - private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); /** * Initializes the Wasm runtime and creates a new audio classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioClassifierOptions The options for the audio classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - audioClassifierOptions: AudioClassifierOptions): + static createFromOptions( + wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - AudioClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await classifier.setOptions(audioClassifierOptions); - return classifier; + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + audioClassifierOptions); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. 
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetPath The path to the model asset.
    */
-  static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+  static createFromModelPath(
+      wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<AudioClassifier> {
-    const response = await fetch(modelAssetPath.toString());
-    const graphData = await response.arrayBuffer();
-    return AudioClassifier.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+    return AudioTaskRunner.createInstance(
+        AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
   }
 
   /**
@@ -117,34 +118,19 @@ export class AudioClassifier extends TaskRunner {
    *
    * @param options The options for the audio classifier.
    */
-  async setOptions(options: AudioClassifierOptions): Promise<void> {
-    if (options.baseOptions) {
-      const baseOptionsProto = await convertBaseOptionsToProto(
-          options.baseOptions, this.options.getBaseOptions());
-      this.options.setBaseOptions(baseOptionsProto);
-    }
-
+  override async setOptions(options: AudioClassifierOptions): Promise<void> {
+    await super.setOptions(options);
     this.options.setClassifierOptions(convertClassifierOptionsToProto(
         options, this.options.getClassifierOptions()));
     this.refreshGraph();
   }
 
   /**
-   * Sets the sample rate for all calls to `classify()` that omit an explicit
-   * sample rate. `48000` is used as a default if this method is not called.
-   *
-   * @param sampleRate A sample rate (e.g. `44100`).
-   */
-  setDefaultSampleRate(sampleRate: number) {
-    this.defaultSampleRate = sampleRate;
-  }
-
-  /**
-   * Performs audio classification on the provided audio data and waits
+   * Performs audio classification on the provided audio clip and waits
    * synchronously for the response.
    *
-   * @param audioData An array of raw audio capture data, like
-   *     from a call to getChannelData on an AudioBuffer.
+   * @param audioData An array of raw audio capture data, like from a call to
+   *     `getChannelData()` on an AudioBuffer.
    * @param sampleRate The sample rate in Hz of the provided audio data. If not
    *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
    *     `48000` if no custom default was set.
@@ -152,18 +138,18 @@
    */
   classify(audioData: Float32Array, sampleRate?: number):
       AudioClassifierResult[] {
-    sampleRate = sampleRate ?? this.defaultSampleRate;
+    return this.processAudioClip(audioData, sampleRate);
+  }
 
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
-
-    const timestamp = performance.now();
-    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
-    this.addAudioToStream(audioData, timestamp);
+  /** Sends an audio packet to the graph and returns the classifications.
*/ + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioClassifierResult[] { + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); @@ -206,7 +192,7 @@ export class AudioClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoVectorListener( + this.graphRunner.attachProtoVectorListener( TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { this.addJsAudioClassificationResults(binaryProtos); }); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts new file mode 100644 index 000000000..dc3c494bf --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Audio Classifier Task */ +export declare interface AudioClassifierOptions extends ClassifierOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts rename to mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD new file mode 100644 index 000000000..0555bb639 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -0,0 +1,43 @@ +# This contains the MediaPipe Audio Embedder Task. +# +# This task takes audio input and performs embedding. 
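To make the reworked classifier API concrete, here is a minimal usage sketch. The npm package name, asset directory, and model path are assumptions for illustration only and are not defined by this change:

import {AudioClassifier, FilesetResolver} from '@mediapipe/tasks-audio';

async function classifyClip(audioData: Float32Array) {
  // FilesetResolver picks the SIMD or non-SIMD Wasm assets automatically.
  const fileset = await FilesetResolver.forAudioTasks('/wasm');
  const classifier = await AudioClassifier.createFromOptions(
      fileset, {baseOptions: {modelAssetPath: '/models/classifier.tflite'}});
  // `audioData` would typically come from AudioBuffer.getChannelData(0).
  return classifier.classify(audioData, /* sampleRate= */ 44100);
}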
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_embedder", + srcs = ["audio_embedder.ts"], + deps = [ + ":audio_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_embedder_types", + srcs = [ + "audio_embedder_options.d.ts", + "audio_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts new file mode 100644 index 000000000..445dd5172 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -0,0 +1,217 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../../../../tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {AudioEmbedderOptions} from './audio_embedder_options'; +import {AudioEmbedderResult} from './audio_embedder_result'; + +export * from './audio_embedder_options'; +export * from './audio_embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const AUDIO_STREAM = 'audio_in'; +const SAMPLE_RATE_STREAM = 'sample_rate'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out'; +const AUDIO_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + +/** Performs embedding extraction on audio. */ +export class AudioEmbedder extends AudioTaskRunner { + private embeddingResults: AudioEmbedderResult[] = []; + private readonly options = new AudioEmbedderGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new audio embedder from the + * provided options. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param audioEmbedderOptions The options for the audio embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + audioEmbedderOptions: AudioEmbedderOptions): Promise { + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + audioEmbedderOptions); + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * path to the model asset. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. 
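+   *
+   * A hypothetical call, for illustration only (the `fileset` variable and the
+   * model path below are assumptions, not part of this API):
+   *
+   *   const embedder = await AudioEmbedder.createFromModelPath(
+   *       fileset, '/models/audio_embedder.tflite');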
+   */
+  static createFromModelPath(
+      wasmFileset: WasmFileset,
+      modelAssetPath: string): Promise<AudioEmbedder> {
+    return AudioTaskRunner.createInstance(
+        AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the audio embedder.
+   *
+   * Calling `setOptions()` with a subset of options only affects those
+   * options. You can reset an option back to its default value by explicitly
+   * setting it to `undefined`.
+   *
+   * @param options The options for the audio embedder.
+   */
+  override async setOptions(options: AudioEmbedderOptions): Promise<void> {
+    await super.setOptions(options);
+    this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
+        options, this.options.getEmbedderOptions()));
+    this.refreshGraph();
+  }
+
+  /**
+   * Performs embedding extraction on the provided audio clip and waits
+   * synchronously for the response.
+   *
+   * @param audioData An array of raw audio capture data, like from a call to
+   *     `getChannelData()` on an AudioBuffer.
+   * @param sampleRate The sample rate in Hz of the provided audio data. If not
+   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
+   *     `48000` if no custom default was set.
+   * @return The embedding results of the audio clip.
+   */
+  embed(audioData: Float32Array, sampleRate?: number): AudioEmbedderResult[] {
+    return this.processAudioClip(audioData, sampleRate);
+  }
+
+  /**
+   * Utility function to compute cosine similarity[1] between two `Embedding`
+   * objects.
+   *
+   * [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+   *
+   * @throws if the embeddings are of different types (float vs. quantized),
+   *     have different sizes, or have an L2-norm of 0.
+   */
+  static cosineSimilarity(u: Embedding, v: Embedding): number {
+    return computeCosineSimilarity(u, v);
+  }
+
+  protected override process(
+      audioData: Float32Array, sampleRate: number,
+      timestampMs: number): AudioEmbedderResult[] {
+    this.graphRunner.addDoubleToStream(
+        sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.graphRunner.addAudioToStreamWithShape(
+        audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
+        AUDIO_STREAM, timestampMs);
+
+    this.embeddingResults = [];
+    this.finishProcessing();
+    return this.embeddingResults;
+  }
+
+  /** Updates the MediaPipe graph configuration.
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + graphConfig.addOutputStream(TIMESTAMPED_EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(AUDIO_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('AUDIO:' + AUDIO_STREAM); + embedderNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.addOutputStream( + 'TIMESTAMPED_EMBEDDINGS:' + TIMESTAMPED_EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + }); + + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts similarity index 63% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts rename to mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts index 93bd9927e..ac22728ab 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Audio Embedder Task */ +export declare interface AudioEmbedderOptions extends EmbedderOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts similarity index 83% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts rename to mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts index 51b2b3947..13abc28d9 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts @@ -14,4 +14,4 @@ * limitations under the License. 
*/ -export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +export {Embedding, EmbeddingResult as AudioEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD new file mode 100644 index 000000000..9ab6c7bee --- /dev/null +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -0,0 +1,14 @@ +# This package contains options shared by all MediaPipe Audio Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "audio_task_runner", + srcs = ["audio_task_runner.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + ], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts new file mode 100644 index 000000000..00cfe0253 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -0,0 +1,45 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Base class for all MediaPipe Audio Tasks. */ +export abstract class AudioTaskRunner extends TaskRunner { + private defaultSampleRate = 48000; + + /** + * Sets the sample rate for API calls that omit an explicit sample rate. + * `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** Sends an audio packet to the graph and awaits results. */ + protected abstract process( + audioData: Float32Array, sampleRate: number, timestampMs: number): T; + + /** Sends a single audio clip to the graph and awaits results. */ + protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { + return this.process( + audioData, sampleRate ?? this.defaultSampleRate, performance.now()); + } +} + + diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 114a8ceca..dbad8c617 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,7 +14,6 @@ * limitations under the License. 
*/ -// Audio Classifier -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_options'; -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_result'; export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index c887303d0..0f916bf88 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -15,10 +15,27 @@ */ /** - * Landmark represents a point in 3D space with x, y, z coordinates. If - * normalized is true, the landmark coordinates is normalized respect to the - * dimension of image, and the coordinates values are in the range of [0,1]. - * Otherwise, it represenet a point in world coordinates. + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. + * x and y are normalized to [0.0, 1.0] by the image width and height + * respectively. z represents the landmark depth, and the smaller the value the + * closer the landmark is to the camera. The magnitude of z uses roughly the + * same scale as x. + */ +export declare interface NormalizedLandmark { + /** The x coordinates of the normalized landmark. */ + x: number; + + /** The y coordinates of the normalized landmark. */ + y: number; + + /** The z coordinates of the normalized landmark. */ + z: number; +} + +/** + * Landmark represents a point in 3D space with x, y, z coordinates. The + * landmark coordinates are in meters. z represents the landmark depth, + * and the smaller the value the closer the world landmark is to the camera. */ export declare interface Landmark { /** The x coordinates of the landmark. */ @@ -29,7 +46,4 @@ export declare interface Landmark { /** The z coordinates of the landmark. */ z: number; - - /** Whether this landmark is normalized with respect to the image size. */ - normalized: boolean; } diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 1b56bf4c9..86e743928 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -17,7 +17,6 @@ mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ - "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:classification_result", ], diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index ac24a8db6..16d562262 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; // The OSS JS API does not support the builder pattern. 
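For the landmark change above: the old boolean `normalized` flag is split into two distinct interfaces, so image-space and world-space points are now separate types at call sites. A small sketch with made-up coordinate values, for illustration only:

const nose: NormalizedLandmark = {x: 0.52, y: 0.31, z: -0.04};  // image-relative, in [0, 1]
const noseWorld: Landmark = {x: 0.01, y: -0.12, z: 0.35};       // world coordinates, in meters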
// tslint:disable:jspb-use-builder-pattern diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD new file mode 100644 index 000000000..1c1ba69ca --- /dev/null +++ b/mediapipe/tasks/web/components/utils/BUILD @@ -0,0 +1,11 @@ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "cosine_similarity", + srcs = ["cosine_similarity.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + ], +) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts new file mode 100644 index 000000000..1f483b9b6 --- /dev/null +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -0,0 +1,63 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + *

Licensed under the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; + +/** + * Computes cosine similarity[1] between two `Embedding` objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), + * have different sizes, or have an L2-norm of 0. + */ +export function computeCosineSimilarity(u: Embedding, v: Embedding): number { + if (u.floatEmbedding && v.floatEmbedding) { + return compute(u.floatEmbedding, v.floatEmbedding); + } + if (u.quantizedEmbedding && v.quantizedEmbedding) { + return compute( + convertToBytes(u.quantizedEmbedding), + convertToBytes(v.quantizedEmbedding)); + } + throw new Error( + 'Cannot compute cosine similarity between quantized and float embeddings.'); +} + +function convertToBytes(data: Uint8Array): number[] { + return Array.from(data, v => v - 128); +} + +function compute(u: number[], v: number[]) { + if (u.length !== v.length) { + throw new Error( + `Cannot compute cosine similarity between embeddings of different sizes (${ + u.length} vs. ${v.length}).`); + } + let dotProduct = 0.0; + let normU = 0.0; + let normV = 0.0; + for (let i = 0; i < u.length; i++) { + dotProduct += u[i] * v[i]; + normU += u[i] * u[i]; + normV += v[i] * v[i]; + } + if (normU <= 0 || normV <= 0) { + throw new Error( + 'Cannot compute cosine similarity on embedding with 0 norm.'); + } + return dotProduct / Math.sqrt(normU * normV); +} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index e9ef85d46..de429690d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -7,23 +7,30 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_declaration( name = "core", srcs = [ - "base_options.d.ts", - "wasm_loader_options.d.ts", + "task_runner_options.d.ts", + "wasm_fileset.d.ts", ], ) mediapipe_ts_library( name = "task_runner", - srcs = [ - "task_runner.ts", - ], + srcs = ["task_runner.ts"], deps = [ + ":core", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) +mediapipe_ts_library( + name = "fileset_resolver", + srcs = ["fileset_resolver.ts"], + deps = [":core"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts index 3dec8d27e..08e7a7664 100644 --- a/mediapipe/tasks/web/core/classifier_options.d.ts +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -14,13 +14,8 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - -/** Options to configure the Mediapipe Classifier Task. */ +/** Options to configure a MediaPipe Classifier Task. 
*/ export declare interface ClassifierOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - /** * The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English. diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts index 78ddad1ae..8669acfcb 100644 --- a/mediapipe/tasks/web/core/embedder_options.d.ts +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -14,13 +14,8 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - -/** Options to configure the MediaPipe Embedder Task */ +/** Options to configure a MediaPipe Embedder Task */ export declare interface EmbedderOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - /** * Whether to normalize the returned feature vector with L2 norm. Use this * option only if the model does not already contain a native L2_NORMALIZATION diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts new file mode 100644 index 000000000..d4691243b --- /dev/null +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -0,0 +1,130 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource URL builder + +import {WasmFileset} from './wasm_fileset'; + +let supportsSimd: boolean|undefined; + +/** + * Simple WASM program to test compatibility with the M91 instruction set. + * Compiled from + * https://github.com/GoogleChromeLabs/wasm-feature-detect/blob/main/src/detectors/simd/module.wat + */ +const WASM_SIMD_CHECK = new Uint8Array([ + 0, 97, 115, 109, 1, 0, 0, 0, 1, 5, 1, 96, 0, 1, 123, 3, + 2, 1, 0, 10, 10, 1, 8, 0, 65, 0, 253, 15, 253, 98, 11 +]); + +async function isSimdSupported(): Promise { + if (supportsSimd === undefined) { + try { + await WebAssembly.instantiate(WASM_SIMD_CHECK); + supportsSimd = true; + } catch { + supportsSimd = false; + } + } + + return supportsSimd; +} + +async function createFileset( + taskName: string, basePath: string = '.'): Promise { + if (await isSimdSupported()) { + return { + wasmLoaderPath: + `${basePath}/${taskName}_wasm_internal.js`, + wasmBinaryPath: + `${basePath}/${taskName}_wasm_internal.wasm`, + }; + } else { + return { + wasmLoaderPath: + `${basePath}/${taskName}_wasm_nosimd_internal.js`, + wasmBinaryPath: + `${basePath}/${taskName}_wasm_nosimd_internal.wasm`, + }; + } +} + +// tslint:disable:class-as-namespace + +/** + * Resolves the files required for the MediaPipe Task APIs. + * + * This class verifies whether SIMD is supported in the current environment and + * loads the SIMD files only if support is detected. The returned filesets + * require that the Wasm files are published without renaming. 
If this is not + * possible, you can invoke the MediaPipe Tasks APIs using a manually created + * `WasmFileset`. + */ +export class FilesetResolver { + /** + * Returns whether SIMD is supported in the current environment. + * + * If your environment requires custom locations for the MediaPipe Wasm files, + * you can use `isSimdSupported()` to decide whether to load the SIMD-based + * assets. + * + * @return Whether SIMD support was detected in the current environment. + */ + static isSimdSupported(): Promise { + return isSimdSupported(); + } + + /** + * Creates a fileset for the MediaPipe Audio tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Audio + * tasks. + */ + static forAudioTasks(basePath?: string): Promise { + return createFileset('audio', basePath); + } + + /** + * Creates a fileset for the MediaPipe Text tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Text + * tasks. + */ + static forTextTasks(basePath?: string): Promise { + return createFileset('text', basePath); + } + + /** + * Creates a fileset for the MediaPipe Vision tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Vision + * tasks. + */ + static forVisionTasks(basePath?: string): Promise { + return createFileset('vision', basePath); + } +} + + diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c948930fc..d769139bc 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,27 +14,77 @@ * limitations under the License. */ +import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; +import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; +import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; +import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; -import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; -import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; + +import {WasmFileset} from './wasm_fileset'; + +// None of the MP Tasks ship bundle assets. +const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing -const WasmMediaPipeImageLib = - SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); +const GraphRunnerImageLibType = + SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class GraphRunnerImageLib extends GraphRunnerImageLibType {} /** Base class for all MediaPipe Tasks. 
 */
-export abstract class TaskRunner extends WasmMediaPipeImageLib {
+export abstract class TaskRunner<O extends TaskRunnerOptions> {
+  protected abstract baseOptions: BaseOptionsProto;
+  protected graphRunner: GraphRunnerImageLib;
   private processingErrors: Error[] = [];
 
-  constructor(wasmModule: WasmModule) {
-    super(wasmModule);
+  /**
+   * Creates a new instance of a Mediapipe Task. Determines if SIMD is
+   * supported and loads the relevant WASM binary.
+   * @return A fully instantiated instance of `T`.
+   */
+  protected static async createInstance<T extends TaskRunner<O>,
+                                        O extends TaskRunnerOptions>(
+      type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
+      fileset: WasmFileset, options: O): Promise<T> {
+    const fileLocator: FileLocator = {
+      locateFile() {
+        // The only file loaded with this mechanism is the Wasm binary
+        return fileset.wasmBinaryPath.toString();
+      }
+    };
+
+    // Initialize a canvas if requested. If OffscreenCanvas is available, we
+    // let the graph runner initialize it by passing `undefined`.
+    const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
+                                           document.createElement('canvas') :
+                                           undefined) :
+                                      null;
+    const instance = await createMediaPipeLib(
+        type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
+    await instance.setOptions(options);
+    return instance;
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas);
 
     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
-    this.setAutoRenderToScreen(false);
+    this.graphRunner.setAutoRenderToScreen(false);
 
     // Enables use of our model resource caching graph service.
-    this.registerModelResourcesGraphService();
+    this.graphRunner.registerModelResourcesGraphService();
+  }
+
+  /** Configures the shared options of a MediaPipe Task. */
+  async setOptions(options: O): Promise<void> {
+    if (options.baseOptions) {
+      this.baseOptions = await convertBaseOptionsToProto(
+          options.baseOptions, this.baseOptions);
+    }
   }
 
   /**
@@ -47,11 +97,11 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
    * @param isBinary This should be set to true if the graph is in
    *     binary format, and false if it is in human-readable text format.
    */
-  override setGraph(graphData: Uint8Array, isBinary: boolean): void {
-    this.attachErrorListener((code, message) => {
+  protected setGraph(graphData: Uint8Array, isBinary: boolean): void {
+    this.graphRunner.attachErrorListener((code, message) => {
       this.processingErrors.push(new Error(message));
     });
-    super.setGraph(graphData, isBinary);
+    this.graphRunner.setGraph(graphData, isBinary);
     this.handleErrors();
   }
 
@@ -60,8 +110,8 @@
    * far as possible, performing all processing until no more processing can be
    * done.
    */
-  override finishProcessing(): void {
-    super.finishProcessing();
+  protected finishProcessing(): void {
+    this.graphRunner.finishProcessing();
     this.handleErrors();
   }
diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts
similarity index 85%
rename from mediapipe/tasks/web/core/base_options.d.ts
rename to mediapipe/tasks/web/core/task_runner_options.d.ts
index 86635b8c7..aa0b4a028 100644
--- a/mediapipe/tasks/web/core/base_options.d.ts
+++ b/mediapipe/tasks/web/core/task_runner_options.d.ts
@@ -16,7 +16,7 @@
 
 // Placeholder for internal dependency on trusted resource url
 
-/** Options to configure MediaPipe Tasks in general.
*/ +/** Options to configure MediaPipe model loading and processing. */ export declare interface BaseOptions { /** * The model path to the model asset file. Only one of `modelAssetPath` or @@ -33,3 +33,9 @@ export declare interface BaseOptions { /** Overrides the default backend to use for the provided model. */ delegate?: 'cpu'|'gpu'|undefined; } + +/** Options to configure MediaPipe Tasks in general. */ +export declare interface TaskRunnerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; +} diff --git a/mediapipe/tasks/web/core/wasm_loader_options.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts similarity index 88% rename from mediapipe/tasks/web/core/wasm_loader_options.d.ts rename to mediapipe/tasks/web/core/wasm_fileset.d.ts index 74436583d..18227eab9 100644 --- a/mediapipe/tasks/web/core/wasm_loader_options.d.ts +++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts @@ -16,8 +16,8 @@ // Placeholder for internal dependency on trusted resource url -/** An object containing the locations of all Wasm assets */ -export declare interface WasmLoaderOptions { +/** An object containing the locations of the Wasm assets */ +export declare interface WasmFileset { /** The path to the Wasm loader script. */ wasmLoaderPath: string; /** The path to the Wasm binary. */ diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json index 1870f18a6..89c9a599e 100644 --- a/mediapipe/tasks/web/package.json +++ b/mediapipe/tasks/web/package.json @@ -2,20 +2,10 @@ "name": "@mediapipe/tasks-__NAME__", "version": "__VERSION__", "description": "__DESCRIPTION__", - "main": "__NAME___cjs_bundle.js", - "module": "__NAME___cjs_bundle.js", - "jsdeliver": "__NAME___iife_bundle.js", - "exports": { - ".": "./__NAME___cjs_bundle.js", - "./loader": "./wasm/__NAME___wasm_internal.js", - "./wasm": "./wasm/__NAME___wasm_internal.wasm" - }, + "main": "__NAME___bundle.js", "author": "mediapipe@google.com", "license": "Apache-2.0", "types": "__TYPES__", - "dependencies": { - "google-protobuf": "^3.21.2" - }, "homepage": "http://mediapipe.dev", "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ] } diff --git a/mediapipe/tasks/web/rollup.config.iife.mjs b/mediapipe/tasks/web/rollup.config.iife.mjs deleted file mode 100644 index 1320927aa..000000000 --- a/mediapipe/tasks/web/rollup.config.iife.mjs +++ /dev/null @@ -1,21 +0,0 @@ -import resolve from '@rollup/plugin-node-resolve'; -import commonjs from '@rollup/plugin-commonjs'; -import terser from '@rollup/plugin-terser'; -import replace from '@rollup/plugin-replace'; - -export default { - output: { - name: 'bundle', - sourcemap: false - }, - plugins: [ - // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 - replace({ - 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', - delimiters: ['', ''] - }), - resolve({browser: true}), - commonjs(), - terser() - ] -} diff --git a/mediapipe/tasks/web/rollup.config.cjs.mjs b/mediapipe/tasks/web/rollup.config.mjs similarity index 86% rename from mediapipe/tasks/web/rollup.config.cjs.mjs rename to mediapipe/tasks/web/rollup.config.mjs index 5f8ca1848..e633bf702 100644 --- a/mediapipe/tasks/web/rollup.config.cjs.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,6 +1,7 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; import replace from '@rollup/plugin-replace'; +import terser from 
'@rollup/plugin-terser';
 
 export default {
   plugins: [
@@ -10,6 +11,7 @@ export default {
       delimiters: ['', '']
     }),
     resolve(),
-    commonjs()
+    commonjs(),
+    terser()
   ]
 }
diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts
index f8a0b6457..0636714b8 100644
--- a/mediapipe/tasks/web/text.ts
+++ b/mediapipe/tasks/web/text.ts
@@ -14,4 +14,12 @@
  * limitations under the License.
  */
 
-export * from '../../tasks/web/text/index';
+import {FilesetResolver as FilesetResolverImpl, TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';
+
+// Declare the variables locally so that Rollup in OSS includes them explicitly
+// as exports.
+const FilesetResolver = FilesetResolverImpl;
+const TextClassifier = TextClassifierImpl;
+const TextEmbedder = TextEmbedderImpl;
+
+export {FilesetResolver, TextClassifier, TextEmbedder};
diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD
index a369d0af0..159db1a0d 100644
--- a/mediapipe/tasks/web/text/BUILD
+++ b/mediapipe/tasks/web/text/BUILD
@@ -8,6 +8,8 @@ mediapipe_ts_library(
     name = "text_lib",
     srcs = ["index.ts"],
     deps = [
+        "//mediapipe/tasks/web/core:fileset_resolver",
         "//mediapipe/tasks/web/text/text_classifier",
+        "//mediapipe/tasks/web/text/text_embedder",
     ],
 )
diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts
index dc511a426..a28e4dd1c 100644
--- a/mediapipe/tasks/web/text/index.ts
+++ b/mediapipe/tasks/web/text/index.ts
@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 
-// Text Classifier
-export * from '../../../tasks/web/text/text_classifier/text_classifier_options';
-export * from '../../../tasks/web/text/text_classifier/text_classifier_result';
 export * from '../../../tasks/web/text/text_classifier/text_classifier';
+export * from '../../../tasks/web/text/text_embedder/text_embedder';
+export * from '../../../tasks/web/core/fileset_resolver';
diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD
index 4ebdce18a..2a7de21d6 100644
--- a/mediapipe/tasks/web/text/text_classifier/BUILD
+++ b/mediapipe/tasks/web/text/text_classifier/BUILD
@@ -3,7 +3,7 @@
 # This task takes text input performs Natural Language classification (including
 # BERT-based text classification).
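The text entry points mirror the audio tasks; a minimal sketch follows (the package name and model path are assumptions for illustration):

import {FilesetResolver, TextClassifier} from '@mediapipe/tasks-text';

async function classifySentence(text: string) {
  const fileset = await FilesetResolver.forTextTasks('/wasm');
  const classifier = await TextClassifier.createFromOptions(
      fileset, {baseOptions: {modelAssetPath: '/models/bert_classifier.tflite'}});
  return classifier.classify(text);
}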
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,24 +11,35 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", - srcs = [ - "text_classifier.ts", - "text_classifier_options.ts", - "text_classifier_result.ts", - ], + srcs = ["text_classifier.ts"], deps = [ + ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "text_classifier_types", + srcs = [ + "text_classifier_options.d.ts", + "text_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index e1d0c9601..8810d4b42 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -17,18 +17,21 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierResult} from './text_classifier_result'; +export * from './text_classifier_options'; +export * from 
'./text_classifier_result'; + const INPUT_STREAM = 'text_in'; const CLASSIFICATIONS_STREAM = 'classifications_out'; const TEXT_CLASSIFIER_GRAPH = @@ -38,66 +41,62 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); /** * Initializes the Wasm runtime and creates a new text classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textClassifierOptions The options for the text classifier. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - TextClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await classifier.setOptions(textClassifierOptions); - return classifier; + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + textClassifierOptions); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
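+   *
+   * A hypothetical call, for illustration only (the `fileset` variable and the
+   * model path below are assumptions):
+   *
+   *   const classifier = await TextClassifier.createFromModelPath(
+   *       fileset, '/models/text_classifier.tflite');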
*/ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -109,18 +108,20 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. */ - async setOptions(options: TextClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: TextClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } /** * Performs Natural Language classification on the provided text and waits @@ -132,7 +133,7 @@ export class TextClassifier extends TaskRunner { classify(text: string): TextClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.classificationResult; @@ -156,10 +157,11 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts new file mode 100644 index 000000000..25592deb5 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Text Classifier Task */ +export declare interface TextClassifierOptions extends ClassifierOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 8e397ce6f..17d105258 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -3,7 +3,7 @@ # This task takes text input and performs embedding # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,22 +11,34 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_embedder", + srcs = ["text_embedder.ts"], + deps = [ + ":text_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "text_embedder_types", srcs = [ - "text_embedder.ts", "text_embedder_options.d.ts", "text_embedder_result.d.ts", ], deps = [ - "//mediapipe/framework:calculator_jspb_proto", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", - "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", - "//mediapipe/tasks/web/components/processors:base_options", - "//mediapipe/tasks/web/components/processors:embedder_options", - "//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 65df5df6a..62f9b06db 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -17,18 +17,22 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import 
{EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; import {TextEmbedderResult} from './text_embedder_result'; +export * from './text_embedder_options'; +export * from './text_embedder_result'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern @@ -41,66 +45,62 @@ const TEXT_EMBEDDER_CALCULATOR = /** * Performs embedding extraction on text. */ -export class TextEmbedder extends TaskRunner { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); /** * Initializes the Wasm runtime and creates a new text embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textEmbedderOptions The options for the text embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - TextEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await embedder.setOptions(textEmbedderOptions); - return embedder; + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + textEmbedderOptions); } /** * Initializes the Wasm runtime and creates a new text embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. 
 * @param modelAssetBuffer A binary representation of the TFLite model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -112,19 +112,20 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. */ - async setOptions(options: TextEmbedderOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: TextEmbedderOptions): Promise { + await super.setOptions(options); this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } /** * Performs embedding extraction on the provided text and waits synchronously * for the response. @@ -135,12 +136,25 @@ */ embed(text: string): TextEmbedderResult { // Get text embeddings by running our MediaPipe graph. - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.embeddingResult; } + /** + * Utility function to compute cosine similarity[1] between two `Embedding` + * objects. + * + * [1]: https://en.wikipedia.org/wiki/Cosine_similarity + * + * @throws if the embeddings are of different types (float vs. quantized), have + * different sizes, or have an L2-norm of 0. + */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Updates the MediaPipe graph configuration.
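The new static `cosineSimilarity()` makes embedder outputs directly comparable without reaching into the `components/utils` package. A sketch of `embed()` plus the helper, under the same assumed package, resolver, and model-path names as the classifier example above:

```ts
import {FilesetResolver, TextEmbedder} from '@mediapipe/tasks-text';

async function compareSentences(): Promise<void> {
  const wasmFileset = await FilesetResolver.forTextTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text/wasm');
  const embedder = await TextEmbedder.createFromModelPath(
      wasmFileset, 'text_embedder.tflite');  // hypothetical model

  // embed() is synchronous once the task exists.
  const first = embedder.embed('light rain expected').embeddings[0];
  const second = embedder.embed('drizzle in the forecast').embeddings[0];

  // Per the doc comment above, this throws if the embeddings differ in type
  // (float vs. quantized), differ in size, or have an L2-norm of 0.
  console.log(TextEmbedder.cosineSimilarity(first, second));
}
```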
 */ private refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); @@ -159,7 +173,7 @@ export class TextEmbedder extends TaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); }); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts index 9af263765..7689ee0c1 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {EmbedderOptions as TextEmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Text Embedder Task */ +export declare interface TextEmbedderOptions extends EmbedderOptions, + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 6ff8f725b..f1ced59af 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,4 +14,22 @@ * limitations under the License. */ -export * from '../../tasks/web/vision/index'; +import {FilesetResolver as FilesetResolverImpl, GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const FilesetResolver = FilesetResolverImpl; +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + FilesetResolver, + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3c45fbfa6..42bc0a494 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 7ab822b7c..b389a9b01 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,11 +1,24 @@ -# This package contains options shared by all MediaPipe Tasks for Web. +# This package contains options shared by all MediaPipe Vision Tasks for Web.
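The `const`-then-`export` indirection in `vision.ts` above is what keeps Rollup from dropping the symbols. The same pattern in isolation (module and symbol names here are illustrative only):

```ts
// A bare `export * from './impl'` can be flattened away by Rollup's
// tree shaking in the OSS bundle; binding each symbol to a local const
// first guarantees it survives as a stable named export.
import {Widget as WidgetImpl} from './impl';  // hypothetical module

const Widget = WidgetImpl;
export {Widget};
```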
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( - name = "running_mode", - srcs = ["running_mode.ts"], - deps = ["//mediapipe/tasks/cc/core/proto:base_options_jspb_proto"], +mediapipe_ts_declaration( + name = "vision_task_options", + srcs = ["vision_task_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + ], +) + +mediapipe_ts_library( + name = "vision_task_runner", + srcs = ["vision_task_runner.ts"], + deps = [ + ":vision_task_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], ) diff --git a/mediapipe/tasks/web/vision/core/running_mode.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts similarity index 60% rename from mediapipe/tasks/web/vision/core/running_mode.ts rename to mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 1e9b1b9a7..76c0177a0 100644 --- a/mediapipe/tasks/web/vision/core/running_mode.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -14,23 +14,22 @@ * limitations under the License. */ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** - * The running mode of a task. + * The two running modes of a vision task. * 1) The image mode for processing single image inputs. * 2) The video mode for processing decoded frames of a video. */ export type RunningMode = 'image'|'video'; -/** Configues the `useStreamMode` option . */ -export function configureRunningMode( - options: {runningMode?: RunningMode}, - proto?: BaseOptionsProto): BaseOptionsProto { - proto = proto ?? new BaseOptionsProto(); - if ('runningMode' in options) { - const useStreamMode = options.runningMode === 'video'; - proto.setUseStreamMode(useStreamMode); - } - return proto; +/** The options for configuring a MediaPipe vision task. */ +export declare interface VisionTaskOptions extends TaskRunnerOptions { + /** + * The running mode of the task. Default to the image mode. + * Vision tasks have two running modes: + * 1) The image mode for processing single image inputs. + * 2) The video mode for processing decoded frames of a video. + */ + runningMode?: RunningMode; } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts new file mode 100644 index 000000000..78b4859f2 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -0,0 +1,59 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; + +import {VisionTaskOptions} from './vision_task_options'; + +/** Base class for all MediaPipe Vision Tasks. */ +export abstract class VisionTaskRunner extends + TaskRunner { + /** Configures the shared options of a vision task. */ + override async setOptions(options: VisionTaskOptions): Promise { + await super.setOptions(options); + if ('runningMode' in options) { + const useStreamMode = + !!options.runningMode && options.runningMode !== 'image'; + this.baseOptions.setUseStreamMode(useStreamMode); + } + } + + /** Sends an image packet to the graph and awaits results. */ + protected abstract process(input: ImageSource, timestamp: number): T; + + /** Sends a single image to the graph and awaits results. */ + protected processImageData(image: ImageSource): T { + if (!!this.baseOptions?.getUseStreamMode()) { + throw new Error( + 'Task is not initialized with image mode. ' + + '\'runningMode\' must be set to \'image\'.'); + } + return this.process(image, performance.now()); + } + + /** Sends a single video frame to the graph and awaits results. */ + protected processVideoData(imageFrame: ImageSource, timestamp: number): T { + if (!this.baseOptions?.getUseStreamMode()) { + throw new Error( + 'Task is not initialized with video mode. ' + + '\'runningMode\' must be set to \'video\'.'); + } + return this.process(imageFrame, timestamp); + } +} + + diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index d67974a16..ddfd1a327 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -19,6 +19,7 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto", @@ -27,12 +28,11 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -47,5 +47,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 6c8072ff5..69a8118a6 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ 
b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb'; @@ -26,12 +27,11 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; @@ -64,9 +64,10 @@ FULL_IMAGE_RECT.setWidth(1); FULL_IMAGE_RECT.setHeight(1); /** Performs hand gesture recognition on images. */ -export class GestureRecognizer extends TaskRunner { +export class GestureRecognizer extends + VisionTaskRunner { private gestures: Category[][] = []; - private landmarks: Landmark[][] = []; + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -81,66 +82,58 @@ export class GestureRecognizer extends TaskRunner { /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param gestureRecognizerOptions The options for the gesture recognizer. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). 
*/ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const recognizer = await createMediaPipeLib( - GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await recognizer.setOptions(gestureRecognizerOptions); - return recognizer; + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + gestureRecognizerOptions); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return GestureRecognizer.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return GestureRecognizer.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); this.handLandmarksDetectorGraphOptions = @@ -156,10 +149,14 @@ export class GestureRecognizer extends TaskRunner { this.handGestureRecognizerGraphOptions); this.initDefaults(); + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. 
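End to end, the new gesture-recognizer surface looks like the sketch below, assuming `@mediapipe/tasks-vision`, `FilesetResolver.forVisionTasks()`, and a hypothetical model path:

```ts
import {FilesetResolver, GestureRecognizer} from '@mediapipe/tasks-vision';

async function recognizeFromWebcam(video: HTMLVideoElement): Promise<void> {
  const wasmFileset = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm');
  const recognizer = await GestureRecognizer.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: 'gesture_recognizer.task'},  // hypothetical
    runningMode: 'video',
    numHands: 2,
  });

  // In video mode, every frame carries its own timestamp.
  const loop = () => {
    const result = recognizer.recognizeForVideo(video, performance.now());
    console.log(result.gestures, result.handednesses);
    requestAnimationFrame(loop);
  };
  requestAnimationFrame(loop);
}
```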
- this.setAutoRenderToScreen(false); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -171,12 +168,8 @@ export class GestureRecognizer extends TaskRunner { * * @param options The options for the gesture recognizer. */ - async setOptions(options: GestureRecognizerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: GestureRecognizerOptions): Promise { + await super.setOptions(options); if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -232,21 +225,41 @@ export class GestureRecognizer extends TaskRunner { /** * Performs gesture recognition on the provided single image and waits - * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `image`. + * + * @param image A single image to process. * @return The detected gestures. */ - recognize(imageSource: ImageSource, timestamp: number = performance.now()): + recognize(image: ImageSource): GestureRecognizerResult { + return this.processImageData(image); + } + + /** + * Performs gesture recognition on the provided video frame and waits + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The detected gestures. + */ + recognizeForVideo(videoFrame: ImageSource, timestamp: number): + GestureRecognizerResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the gesture recognition and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): GestureRecognizerResult { this.gestures = []; this.landmarks = []; this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -294,13 +307,12 @@ export class GestureRecognizer extends TaskRunner { for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, - z: handLandmarkProto.getZ() ?? 0, - normalized: true + z: handLandmarkProto.getZ() ?? 0 }); } this.landmarks.push(landmarks); @@ -321,8 +333,7 @@ export class GestureRecognizer extends TaskRunner { worldLandmarks.push({ x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, - z: handWorldLandmarkProto.getZ() ?? 0, - normalized: false + z: handWorldLandmarkProto.getZ() ?? 
0 }); } this.worldLandmarks.push(worldLandmarks); @@ -355,18 +366,22 @@ export class GestureRecognizer extends TaskRunner { graphConfig.addNode(recognizerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts index 45601a74c..dd8fc9548 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -14,14 +14,11 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Gesture Recognizer Task */ -export declare interface GestureRecognizerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface GestureRecognizerOptions extends VisionTaskOptions { /** * The maximum number of hands that can be detected by the GestureRecognizer. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts index 7c295c9e9..e570270b2 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -15,14 +15,14 @@ */ import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; /** * Represents the gesture recognition results generated by `GestureRecognizer`. */ export declare interface GestureRecognizerResult { /** Hand landmarks of detected hands. */ - landmarks: Landmark[][]; + landmarks: NormalizedLandmark[][]; /** Hand landmarks in world coordinates of detected hands.
*/ worldLandmarks: Landmark[][]; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 25c70e0a5..fc3e6ef1f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -19,21 +19,22 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) mediapipe_ts_declaration( name = "hand_landmarker_types", srcs = [ + "hand_landmark.d.ts", "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], @@ -41,5 +42,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts new file mode 100644 index 000000000..ca2543f78 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** The 21 hand landmarks. 
*/ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index af10305b2..9a0823f23 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -19,15 +19,15 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; @@ -58,8 +58,8 @@ FULL_IMAGE_RECT.setWidth(1); FULL_IMAGE_RECT.setHeight(1); /** Performs hand landmarks detection on images. */ -export class HandLandmarker extends TaskRunner { - private landmarks: Landmark[][] = []; +export class HandLandmarker extends VisionTaskRunner { + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -71,65 +71,57 @@ export class HandLandmarker extends TaskRunner { /** * Initializes the Wasm runtime and creates a new `HandLandmarker` from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param handLandmarkerOptions The options for the HandLandmarker. 
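The new `HandLandmark` const enum gives the 21-entry landmark lists stable, named indices. A small lookup sketch using the declaration files this diff creates:

```ts
import {HandLandmark} from './hand_landmark';
import {HandLandmarkerResult} from './hand_landmarker_result';

// Index a detected hand's landmarks by semantic name instead of
// remembering that, e.g., the index fingertip is entry 8.
function indexFingerTipY(result: HandLandmarkerResult, hand = 0): number {
  return result.landmarks[hand][HandLandmark.INDEX_FINGER_TIP].y;
}
```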
* Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + static createFromOptions( + wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const landmarker = await createMediaPipeLib( - HandLandmarker, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await landmarker.setOptions(handLandmarkerOptions); - return landmarker; + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + handLandmarkerOptions); } /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return HandLandmarker.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return HandLandmarker.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new HandLandmarkerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarksDetectorGraphOptions = new HandLandmarksDetectorGraphOptions(); this.options.setHandLandmarksDetectorGraphOptions( @@ -138,10 +130,14 @@ export class HandLandmarker extends TaskRunner { this.options.setHandDetectorGraphOptions(this.handDetectorGraphOptions); this.initDefaults(); + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. 
- this.setAutoRenderToScreen(false); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -153,12 +149,8 @@ export class HandLandmarker extends TaskRunner { * * @param options The options for the hand landmarker. */ - async setOptions(options: HandLandmarkerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: HandLandmarkerOptions): Promise { + await super.setOptions(options); // Configure hand detector options. if ('numHands' in options) { @@ -185,20 +177,40 @@ export class HandLandmarker extends TaskRunner { /** * Performs hand landmarks detection on the provided single image and waits - * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `image`. + * + * @param image An image to process. * @return The detected hand landmarks. */ - detect(imageSource: ImageSource, timestamp: number = performance.now()): + detect(image: ImageSource): HandLandmarkerResult { + return this.processImageData(image); + } + + /** + * Performs hand landmarks detection on the provided video frame and waits + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The detected hand landmarks. + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): + HandLandmarkerResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the hand landmarker graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): HandLandmarkerResult { this.landmarks = []; this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -244,13 +256,12 @@ export class HandLandmarker extends TaskRunner { for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, z: handLandmarkProto.getZ() ?? 0, - normalized: true }); } this.landmarks.push(landmarks); @@ -258,7 +269,7 @@ export class HandLandmarker extends TaskRunner { } /** - * Converts raw data into a landmark, and adds it to our worldLandmarks + * Converts raw data into a world landmark, and adds it to our worldLandmarks * list. 
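As with the gesture recognizer, the timestamp-defaulting `detect()` is split into an image-mode `detect()` and a video-mode `detectForVideo()`. A single-image sketch under the same assumed package and resolver names:

```ts
import {FilesetResolver, HandLandmarker} from '@mediapipe/tasks-vision';

async function detectOnce(image: HTMLImageElement): Promise<void> {
  const wasmFileset = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm');
  const landmarker = await HandLandmarker.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: 'hand_landmarker.task'},  // hypothetical
    runningMode: 'image',  // the default; detect() throws in video mode
  });

  const {landmarks, worldLandmarks, handednesses} = landmarker.detect(image);
  console.log(`${landmarks.length} hand(s)`, worldLandmarks, handednesses);
}
```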
 */ private adddJsWorldLandmarks(data: Uint8Array[]): void { @@ -272,7 +283,6 @@ export class HandLandmarker extends TaskRunner { x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, z: handWorldLandmarkProto.getZ() ?? 0, - normalized: false }); } this.worldLandmarks.push(worldLandmarks); @@ -303,15 +313,18 @@ export class HandLandmarker extends TaskRunner { graphConfig.addNode(landmarkerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts index 53ad9440a..fe79b7089 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts @@ -14,13 +14,10 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe HandLandmarker Task */ -export declare interface HandLandmarkerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface HandLandmarkerOptions extends VisionTaskOptions { /** * The maximum number of hands that can be detected by the HandLandmarker. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 044bdfbe7..89f867d69 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts @@ -15,14 +15,14 @@ */ import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; /** * Represents the hand landmarks detection results generated by `HandLandmarker`. */ export declare interface HandLandmarkerResult { /** Hand landmarks of detected hands. */ - landmarks: Landmark[][]; + landmarks: NormalizedLandmark[][]; /** Hand landmarks in world coordinates of detected hands.
*/ worldLandmarks: Landmark[][]; diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index 8506f3574..ebe64ecf4 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -16,16 +16,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -39,5 +39,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 5d60e4a21..40e8b5099 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -17,13 +17,13 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -42,67 +42,70 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs classification on images. 
 */ -export class ImageClassifier extends TaskRunner { +export class ImageClassifier extends VisionTaskRunner { private classificationResult: ImageClassifierResult = {classifications: []}; private readonly options = new ImageClassifierGraphOptions(); /** * Initializes the Wasm runtime and creates a new image classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param imageClassifierOptions The options for the image classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - imageClassifierOptions: ImageClassifierOptions): + static createFromOptions( + wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - ImageClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await classifier.setOptions(imageClassifierOptions); - return classifier; + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + imageClassifierOptions); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset.
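The image classifier follows the same image/video split; a sketch with the same assumed names:

```ts
import {FilesetResolver, ImageClassifier} from '@mediapipe/tasks-vision';

async function classifyImage(image: HTMLImageElement): Promise<void> {
  const wasmFileset = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm');
  const classifier = await ImageClassifier.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: 'image_classifier.tflite'},  // hypothetical
    maxResults: 5,
  });

  // Image mode is the default, so classify() is valid here;
  // classifyForVideo() would throw until runningMode is set to 'video'.
  console.log(classifier.classify(image).classifications);
}
```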
*/ - static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + static createFromModelPath( + wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); } /** @@ -114,32 +117,45 @@ export class ImageClassifier extends TaskRunner { * * @param options The options for the image classifier. */ - async setOptions(options: ImageClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: ImageClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } /** - * Performs image classification on the provided image and waits synchronously - * for the response. + * Performs image classification on the provided single image and waits + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `image`. * - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The classification result of the image */ - classify(imageSource: ImageSource, timestamp?: number): + classify(image: ImageSource): ImageClassifierResult { + return this.processImageData(image); + } + + /** + * Performs image classification on the provided video frame and waits + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The classification result of the image + */ + classifyForVideo(videoFrame: ImageSource, timestamp: number): + ImageClassifierResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the image classification graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): ImageClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? 
performance.now()); this.finishProcessing(); return this.classificationResult; @@ -165,10 +181,11 @@ export class ImageClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts index a5f5c2386..e99dd2b69 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Image Classifier Task. */ +export declare interface ImageClassifierOptions extends ClassifierOptions, + VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index d12a05ad9..2f012dc5e 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -2,7 +2,7 @@ # # This task performs embedding extraction on images. 
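For reference, a minimal sketch of how the reworked ImageClassifier surface above is driven from application code. The `wasmFileset` argument and the model URL are placeholders, and the `runningMode` literals are assumed to follow the image/video split introduced in this change:

// Sketch only: names, URLs, and option values are illustrative.
async function classifyBoth(
    wasmFileset: WasmFileset, video: HTMLVideoElement): Promise<void> {
  const classifier = await ImageClassifier.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: 'https://example.com/classifier.tflite'},
    runningMode: 'image',
    maxResults: 3,
  });

  // Image mode: the timestamp parameter is gone from classify().
  const stillResult = classifier.classify(video);

  // Video mode: callers now pass a monotonically increasing timestamp.
  await classifier.setOptions({runningMode: 'video'});
  const frameResult = classifier.classifyForVideo(video, performance.now());
  console.log(stillResult, frameResult);
}
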
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,24 +10,36 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_embedder", - srcs = [ - "image_embedder.ts", - "image_embedder_options.ts", - "image_embedder_result.ts", - ], + srcs = ["image_embedder.ts"], deps = [ + ":image_embedder_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/tasks/web/vision/core:running_mode", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:vision_task_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "image_embedder_types", + srcs = [ + "image_embedder_options.d.ts", + "image_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 4184e763c..f8b0204ee 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -17,14 +17,15 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {configureRunningMode} from '../../../../tasks/web/vision/core/running_mode'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {computeCosineSimilarity} from 
'../../../../tasks/web/components/utils/cosine_similarity';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

 import {ImageEmbedderOptions} from './image_embedder_options';
@@ -38,69 +39,75 @@ const EMBEDDINGS_STREAM = 'embeddings_out';
 const TEXT_EMBEDDER_CALCULATOR =
     'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph';

+export * from './image_embedder_options';
+export * from './image_embedder_result';
 export {ImageSource};  // Used in the public API

 /** Performs embedding extraction on images. */
-export class ImageEmbedder extends TaskRunner {
+export class ImageEmbedder extends VisionTaskRunner {
   private readonly options = new ImageEmbedderGraphOptions();
   private embeddings: ImageEmbedderResult = {embeddings: []};

   /**
    * Initializes the Wasm runtime and creates a new image embedder from the
    * provided options.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param imageEmbedderOptions The options for the image embedder. Note that
    *     either a path to the TFLite model or the model itself needs to be
    *     provided (via `baseOptions`).
    */
-  static async createFromOptions(
-      wasmLoaderOptions: WasmLoaderOptions,
+  static createFromOptions(
+      wasmFileset: WasmFileset,
       imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> {
-    // Create a file locator based on the loader options
-    const fileLocator: FileLocator = {
-      locateFile() {
-        // The only file we load is the Wasm binary
-        return wasmLoaderOptions.wasmBinaryPath.toString();
-      }
-    };
-
-    const embedder = await createMediaPipeLib(
-        ImageEmbedder, wasmLoaderOptions.wasmLoaderPath,
-        /* assetLoaderScript= */ undefined,
-        /* glCanvas= */ undefined, fileLocator);
-    await embedder.setOptions(imageEmbedderOptions);
-    return embedder;
+    return VisionTaskRunner.createInstance(
+        ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
+        imageEmbedderOptions);
   }

   /**
    * Initializes the Wasm runtime and creates a new image embedder based on the
    * provided model asset buffer.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetBuffer A binary representation of the TFLite model.
    */
   static createFromModelBuffer(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> {
-    return ImageEmbedder.createFromOptions(
-        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createInstance(
+        ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetBuffer}});
   }

   /**
    * Initializes the Wasm runtime and creates a new image embedder based on the
    * path to the model asset.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetPath The path to the TFLite model.
    */
-  static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+  static createFromModelPath(
+      wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ImageEmbedder> {
-    const response = await fetch(modelAssetPath.toString());
-    const graphData = await response.arrayBuffer();
-    return ImageEmbedder.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+    return VisionTaskRunner.createInstance(
+        ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
   }

   /**
@@ -112,45 +119,29 @@ export class ImageEmbedder extends TaskRunner {
    *
    * @param options The options for the image embedder.
    */
-  async setOptions(options: ImageEmbedderOptions): Promise<void> {
-    let baseOptionsProto = this.options.getBaseOptions();
-    if (options.baseOptions) {
-      baseOptionsProto = await convertBaseOptionsToProto(
-          options.baseOptions, baseOptionsProto);
-    }
-    baseOptionsProto = configureRunningMode(options, baseOptionsProto);
-    this.options.setBaseOptions(baseOptionsProto);
-
+  override async setOptions(options: ImageEmbedderOptions): Promise<void> {
+    await super.setOptions(options);
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-
     this.refreshGraph();
   }

   /**
-   * Performs embedding extraction on the provided image and waits synchronously
-   * for the response.
-   *
-   * Only use this method when the `useStreamMode` option is not set or
-   * expliclity set to `false`.
+   * Performs embedding extraction on the provided single image and waits
+   * synchronously for the response. Only use this method when the
+   * ImageEmbedder is created with running mode `image`.
    *
    * @param image The image to process.
    * @return The classification result of the image
    */
   embed(image: ImageSource): ImageEmbedderResult {
-    if (!!this.options.getBaseOptions()?.getUseStreamMode()) {
-      throw new Error(
-          'Task is not initialized with image mode. ' +
-          '\'runningMode\' must be set to \'image\'.');
-    }
-    return this.performEmbeddingExtraction(image, performance.now());
+    return this.processImageData(image);
   }

   /**
    * Performs embedding extraction on the provided video frame and waits
-   * synchronously for the response.
-   *
-   * Only use this method when the `useStreamMode` option is set to `true`.
+   * synchronously for the response. Only use this method when the
+   * ImageEmbedder is created with running mode `video`.
    *
    * @param imageFrame The image frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
@@ -158,19 +149,27 @@
    */
   embedForVideo(imageFrame: ImageSource, timestamp: number):
       ImageEmbedderResult {
-    if (!this.options.getBaseOptions()?.getUseStreamMode()) {
-      throw new Error(
-          'Task is not initialized with video mode. ' +
-          '\'runningMode\' must be set to \'video\' or \'live_stream\'.');
-    }
-    return this.performEmbeddingExtraction(imageFrame, timestamp);
+    return this.processVideoData(imageFrame, timestamp);
   }

-  /** Runs the embedding extractio and blocks on the response.
 */
-  private performEmbeddingExtraction(image: ImageSource, timestamp: number):
+  /**
+   * Utility function to compute cosine similarity[1] between two `Embedding`
+   * objects.
+   *
+   * [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+   *
+   * @throws if the embeddings are of different types (float vs. quantized),
+   *     have different sizes, or have an L2-norm of 0.
+   */
+  static cosineSimilarity(u: Embedding, v: Embedding): number {
+    return computeCosineSimilarity(u, v);
+  }
+
+  /** Runs the embedding extraction and blocks on the response. */
+  protected process(image: ImageSource, timestamp: number):
       ImageEmbedderResult {
     // Get embeddings by running our MediaPipe graph.
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
         image, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return this.embeddings;
@@ -202,7 +201,7 @@ export class ImageEmbedder extends TaskRunner {

     graphConfig.addNode(embedderNode);

-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       this.addJsImageEmdedding(binaryProto);
     });

diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts
new file mode 100644
index 000000000..8a04be5e1
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
+import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
+
+/** Options for configuring a MediaPipe Image Embedder task. */
+export declare interface ImageEmbedderOptions extends EmbedderOptions,
+                                                      VisionTaskOptions {}
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts
deleted file mode 100644
index 4d795d0d8..000000000
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
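A companion sketch for the embedder changes: two images are embedded and compared with the new static `cosineSimilarity` helper. The fileset and model path are placeholders, and `l2Normalize` is a pre-existing `EmbedderOptions` field assumed unchanged by this diff:

// Sketch only: names and paths are illustrative.
async function compareImages(
    wasmFileset: WasmFileset, a: HTMLImageElement,
    b: HTMLImageElement): Promise<number> {
  const embedder = await ImageEmbedder.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: '/models/embedder.tflite'},
    l2Normalize: true,
  });
  const embeddingA = embedder.embed(a).embeddings[0];
  const embeddingB = embedder.embed(b).embeddings[0];
  // Throws if the embeddings differ in type (float vs. quantized) or size,
  // or if either has an L2-norm of 0.
  return ImageEmbedder.cosineSimilarity(embeddingA, embeddingB);
}
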
- */ - -import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; -import {RunningMode} from '../../../../tasks/web/vision/core/running_mode'; - -/** The options for configuring a MediaPipe image embedder task. */ -export declare interface ImageEmbedderOptions extends EmbedderOptions { - /** - * The running mode of the task. Default to the image mode. - * Image embedder has three running modes: - * 1) The image mode for embedding image on single image inputs. - * 2) The video mode for embedding image on the decoded frames of a video. - * 3) The live stream mode for embedding image on the live stream of input - * data, such as from camera. - */ - runningMode?: RunningMode; -} diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts rename to mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 0ea844fc9..0337a0f2f 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,19 +14,9 @@ * limitations under the License. */ -// Image Classifier export * from '../../../tasks/web/vision/image_classifier/image_classifier'; - -// Image Embedder -export * from '../../../tasks/web/vision/image_embedder/image_embedder_options'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder_result'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; - -// Gesture Recognizer export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; - -// Hand Landmarker export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; - -// Object Detector export * from '../../../tasks/web/vision/object_detector/object_detector'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index a74dc9211..198585258 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -17,12 +17,12 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -35,5 +35,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e17a42020..e2cfe0575 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -17,11 +17,11 
@@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
-import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
-import {TaskRunner} from '../../../../tasks/web/core/task_runner';
-import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
-import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

 import {ObjectDetectorOptions} from './object_detector_options';
@@ -41,66 +41,70 @@ export {ImageSource};  // Used in the public API

 // tslint:disable:jspb-use-builder-pattern

 /** Performs object detection on images. */
-export class ObjectDetector extends TaskRunner {
+export class ObjectDetector extends VisionTaskRunner {
   private detections: Detection[] = [];
   private readonly options = new ObjectDetectorOptionsProto();

   /**
    * Initializes the Wasm runtime and creates a new object detector from the
    * provided options.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param objectDetectorOptions The options for the Object Detector. Note that
    *     either a path to the model asset or a model buffer needs to be
    *     provided (via `baseOptions`).
    */
-  static async createFromOptions(
-      wasmLoaderOptions: WasmLoaderOptions,
+  static createFromOptions(
+      wasmFileset: WasmFileset,
       objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
-    // Create a file locator based on the loader options
-    const fileLocator: FileLocator = {
-      locateFile() {
-        // The only file we load is the Wasm binary
-        return wasmLoaderOptions.wasmBinaryPath.toString();
-      }
-    };
-
-    const detector = await createMediaPipeLib(
-        ObjectDetector, wasmLoaderOptions.wasmLoaderPath,
-        /* assetLoaderScript= */ undefined,
-        /* glCanvas= */ undefined, fileLocator);
-    await detector.setOptions(objectDetectorOptions);
-    return detector;
+    return VisionTaskRunner.createInstance(
+        ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
+        objectDetectorOptions);
   }

   /**
    * Initializes the Wasm runtime and creates a new object detector based on the
    * provided model asset buffer.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetBuffer A binary representation of the model.
    */
   static createFromModelBuffer(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
-    return ObjectDetector.createFromOptions(
-        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createInstance(
+        ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetBuffer}});
   }

   /**
    * Initializes the Wasm runtime and creates a new object detector based on the
    * path to the model asset.
-   * @param wasmLoaderOptions A configuration object that provides the location
-   *     of the Wasm binary and its loader.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
    * @param modelAssetPath The path to the model asset.
    */
   static async createFromModelPath(
-      wasmLoaderOptions: WasmLoaderOptions,
+      wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ObjectDetector> {
-    const response = await fetch(modelAssetPath.toString());
-    const graphData = await response.arrayBuffer();
-    return ObjectDetector.createFromModelBuffer(
-        wasmLoaderOptions, new Uint8Array(graphData));
+    return VisionTaskRunner.createInstance(
+        ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(wasmModule, glCanvas);
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
   }

   /**
@@ -112,12 +116,8 @@ export class ObjectDetector extends TaskRunner {
    *
    * @param options The options for the object detector.
    */
-  async setOptions(options: ObjectDetectorOptions): Promise<void> {
-    if (options.baseOptions) {
-      const baseOptionsProto = await convertBaseOptionsToProto(
-          options.baseOptions, this.options.getBaseOptions());
-      this.options.setBaseOptions(baseOptionsProto);
-    }
+  override async setOptions(options: ObjectDetectorOptions): Promise<void> {
+    await super.setOptions(options);

     // Note that we have to support both JSPB and ProtobufJS, hence we
     // have to expliclity clear the values instead of setting them to
@@ -157,16 +157,35 @@ export class ObjectDetector extends TaskRunner {

   /**
    * Performs object detection on the provided single image and waits
-   * synchronously for the response.
-   * @param imageSource An image source to process.
-   * @param timestamp The timestamp of the current frame, in ms. If not
-   *     provided, defaults to `performance.now()`.
+   * synchronously for the response. Only use this method when the
+   * ObjectDetector is created with running mode `image`.
+   *
+   * @param image An image to process.
    * @return The list of detected objects
    */
-  detect(imageSource: ImageSource, timestamp?: number): Detection[] {
+  detect(image: ImageSource): Detection[] {
+    return this.processImageData(image);
+  }
+
+  /**
+   * Performs object detection on the provided video frame and waits
+   * synchronously for the response. Only use this method when the
+   * ObjectDetector is created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+ * @return The list of detected objects + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): Detection[] { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the object detector graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): + Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return [...this.detections]; @@ -223,9 +242,10 @@ export class ObjectDetector extends TaskRunner { graphConfig.addNode(detectorNode); - this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { - this.addJsObjectDetections(binaryProto); - }); + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, binaryProto => { + this.addJsObjectDetections(binaryProto); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index eec12cf17..7564e7760 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -14,39 +14,9 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - - /** - * The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. - */ - displayNamesLocale?: string|undefined; - - /** The maximum number of top-scored detection results to return. */ - maxResults?: number|undefined; - - /** - * Overrides the value provided in the model metadata. Results below this - * value are rejected. - */ - scoreThreshold?: number|undefined; - - /** - * Allowlist of category names. If non-empty, detection results whose category - * name is not in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryDenylist`. - */ - categoryAllowlist?: string[]|undefined; - - /** - * Denylist of category names. If non-empty, detection results whose category - * name is in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryAllowlist`. 
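To make the image/video split concrete for the detector as well, a short sketch of a per-frame loop. The fileset, model path, and `requestAnimationFrame` scheduling are illustrative only; `scoreThreshold` comes from the shared `ClassifierOptions` interface that `ObjectDetectorOptions` now extends (see the options diff that follows):

// Sketch only: names and paths are illustrative.
async function runDetectionLoop(
    wasmFileset: WasmFileset, video: HTMLVideoElement): Promise<void> {
  const detector = await ObjectDetector.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: '/models/detector.tflite'},
    runningMode: 'video',
    scoreThreshold: 0.5,
  });

  const onFrame = () => {
    // Video mode requires a caller-supplied, monotonically increasing
    // timestamp in milliseconds.
    const detections = detector.detectForVideo(video, performance.now());
    console.log(`${detections.length} objects in frame`);
    requestAnimationFrame(onFrame);
  };
  requestAnimationFrame(onFrame);
}
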
-   */
-  categoryDenylist?: string[]|undefined;
-}
+export interface ObjectDetectorOptions extends VisionTaskOptions,
+                                               ClassifierOptions {}
diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD
index ab3390e0a..55c1df59f 100644
--- a/mediapipe/util/BUILD
+++ b/mediapipe/util/BUILD
@@ -228,6 +228,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/framework/port:logging",
+        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:function_ref",
     ],
 )
@@ -367,3 +368,21 @@ cc_test(
         "//mediapipe/framework/port:gtest_main",
     ],
 )
+
+cc_library(
+    name = "image_test_utils",
+    testonly = 1,
+    srcs = ["image_test_utils.cc"],
+    hdrs = ["image_test_utils.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:image_frame",
+        "//mediapipe/framework/formats:image_frame_opencv",
+        "//mediapipe/framework/port:opencv_core",
+        "//mediapipe/framework/port:opencv_imgcodecs",
+        "//mediapipe/framework/port:opencv_imgproc",
+    ],
+)
diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc
new file mode 100644
index 000000000..815666985
--- /dev/null
+++ b/mediapipe/util/image_test_utils.cc
@@ -0,0 +1,57 @@
+#include "mediapipe/util/image_test_utils.h"
+
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/timestamp.h"
+
+namespace mediapipe {
+
+cv::Mat GetRgb(const std::string& path) {
+  cv::Mat bgr = cv::imread(path);
+  cv::Mat rgb;
+  cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB);
+  return rgb;
+}
+
+cv::Mat GetRgba(const std::string& path) {
+  cv::Mat bgr = cv::imread(path);
+  cv::Mat rgb;
+  cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA);
+  return rgb;
+}
+
+cv::Mat GetGray(const std::string& path) {
+  cv::Mat bgr = cv::imread(path);
+  cv::Mat gray;
+  cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
+  return gray;
+}
+
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
+  if (image_channels == 4) {
+    return ImageFormat::SRGBA;
+  } else if (image_channels == 3) {
+    return ImageFormat::SRGB;
+  } else if (image_channels == 1) {
+    return ImageFormat::GRAY8;
+  }
+  LOG(FATAL) << "Unsupported input image channels: " << image_channels;
+}
+
+Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
+  ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
+                         input.rows, input.step, input.data, [](uint8*) {});
+  return MakePacket<ImageFrame>(std::move(input_image))
+      .At(Timestamp(timestamp));
+}
+
+Packet MakeImagePacket(cv::Mat input, int timestamp) {
+  mediapipe::Image input_image(std::make_shared<ImageFrame>(
+      GetImageFormat(input.channels()), input.cols, input.rows, input.step,
+      input.data, [](uint8*) {}));
+  return MakePacket<Image>(std::move(input_image)).At(Timestamp(timestamp));
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/util/image_test_utils.h b/mediapipe/util/image_test_utils.h
new file mode 100644
index 000000000..6df9644d2
--- /dev/null
+++ b/mediapipe/util/image_test_utils.h
@@ -0,0 +1,32 @@
+#ifndef MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
+#define MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
+
+#include <string>
+
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+
+namespace mediapipe {
+
+// Reads the image file into cv::Mat with RGB channels.
+cv::Mat GetRgb(const std::string& path);
+
+// Reads the image file into cv::Mat with RGBA channels.
+cv::Mat GetRgba(const std::string& path);
+
+// Reads the image file into cv::Mat with Gray channel.
+cv::Mat GetGray(const std::string& path);
+
+// Converts the image channels into corresponding ImageFormat.
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels);
+
+// Converts the cv::Mat into ImageFrame packet.
+Packet MakeImageFramePacket(cv::Mat input, int timestamp = 0);
+
+// Converts the cv::Mat into Image packet.
+Packet MakeImagePacket(cv::Mat input, int timestamp = 0);
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
diff --git a/mediapipe/util/packet_test_util.h b/mediapipe/util/packet_test_util.h
index 106d7f8d4..61e9322e1 100644
--- a/mediapipe/util/packet_test_util.h
+++ b/mediapipe/util/packet_test_util.h
@@ -32,30 +32,29 @@ namespace mediapipe {

 namespace internal {

 template <typename PayloadType>
-class PacketMatcher : public ::testing::MatcherInterface<Packet> {
+class PacketMatcher : public testing::MatcherInterface<Packet> {
  public:
  template <typename InnerMatcher>
  explicit PacketMatcher(InnerMatcher inner_matcher)
      : inner_matcher_(
-            ::testing::SafeMatcherCast<const PayloadType&>(inner_matcher)) {}
+            testing::SafeMatcherCast<const PayloadType&>(inner_matcher)) {}

  // Returns true iff the packet contains value of PayloadType satisfying
  // the inner matcher.
-  bool MatchAndExplain(
-      const Packet& packet,
-      ::testing::MatchResultListener* listener) const override {
+  bool MatchAndExplain(const Packet& packet,
+                       testing::MatchResultListener* listener) const override {
    if (!packet.ValidateAsType<PayloadType>().ok()) {
      *listener << packet.DebugString() << " does not contain expected type "
                << ExpectedTypeName();
      return false;
    }
-    ::testing::StringMatchResultListener match_listener;
+    testing::StringMatchResultListener match_listener;
    const PayloadType& payload = packet.Get<PayloadType>();
    const bool matches =
        inner_matcher_.MatchAndExplain(payload, &match_listener);
    const std::string explanation = match_listener.str();
    *listener << packet.DebugString() << " containing value "
-              << ::testing::PrintToString(payload);
+              << testing::PrintToString(payload);
    if (!explanation.empty()) {
      *listener << ", which " << explanation;
    }
@@ -78,9 +77,28 @@ class PacketMatcher : public ::testing::MatcherInterface<Packet> {
    return ::mediapipe::Demangle(typeid(PayloadType).name());
  }

-  const ::testing::Matcher<const PayloadType&> inner_matcher_;
+  const testing::Matcher<const PayloadType&> inner_matcher_;
 };

+inline std::string SourceString(Timestamp t) {
+  return (t.IsSpecialValue())
+             ?
 t.DebugString()
+             : absl::StrCat("Timestamp(", t.DebugString(), ")");
+}
+
+template <typename PayloadType>
+std::string SourceString(Packet packet) {
+  std::ostringstream oss;
+  if (packet.IsEmpty()) {
+    oss << "Packet()";
+  } else {
+    oss << "MakePacket<" << MediaPipeTypeStringOrDemangled<PayloadType>()
+        << ">(" << packet.Get<PayloadType>() << ")";
+  }
+  oss << ".At(" << SourceString(packet.Timestamp()) << ")";
+  return oss.str();
+}
+
 }  // namespace internal

 // Creates matcher validating that the packet contains value of expected type
 //
@@ -91,9 +109,9 @@ class PacketMatcher : public ::testing::MatcherInterface<Packet> {
 //
 //   EXPECT_THAT(MakePacket<int>(42), PacketContains<int>(Eq(42)))
 template <typename PayloadType, typename InnerMatcher>
-inline ::testing::Matcher<Packet> PacketContains(
+inline testing::Matcher<Packet> PacketContains(
     InnerMatcher inner_matcher) {
-  return ::testing::MakeMatcher(
+  return testing::MakeMatcher(
       new internal::PacketMatcher<PayloadType>(inner_matcher));
 }

@@ -110,7 +128,7 @@ inline ::testing::Matcher<Packet> PacketContains(
 //                 Eq(42)))
 template <typename PayloadType, typename TimestampMatcher,
           typename ContentMatcher>
-inline ::testing::Matcher<Packet> PacketContainsTimestampAndPayload(
+inline testing::Matcher<Packet> PacketContainsTimestampAndPayload(
     TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) {
   return testing::AllOf(
       testing::Property("Packet::Timestamp", &Packet::Timestamp,
@@ -118,6 +136,46 @@ inline ::testing::Matcher<Packet> PacketContainsTimestampAndPayload(
       PacketContains<PayloadType>(content_matcher));
 }

+template <typename PayloadType>
+class PacketEqMatcher : public testing::MatcherInterface<Packet> {
+ public:
+  PacketEqMatcher(Packet packet) : packet_(packet) {}
+  void DescribeTo(::std::ostream* os) const override {
+    *os << "The expected packet: "
+        << internal::SourceString<PayloadType>(packet_);
+  }
+  bool MatchAndExplain(Packet value,
+                       testing::MatchResultListener* listener) const override {
+    bool unequal = (value.Timestamp() != packet_.Timestamp() ||
+                    value.IsEmpty() != packet_.IsEmpty() ||
+                    (!value.IsEmpty() &&
+                     value.Get<PayloadType>() != packet_.Get<PayloadType>()));
+    if (unequal && listener->IsInterested()) {
+      *listener << "The actual packet: "
+                << internal::SourceString<PayloadType>(value);
+    }
+    return !unequal;
+  }
+  const Packet packet_;
+};
+
+template <typename PayloadType>
+testing::Matcher<Packet> PacketEq(Packet packet) {
+  return MakeMatcher(new PacketEqMatcher<PayloadType>(packet));
+}
+
+template <typename PayloadType>
+std::vector<testing::Matcher<Packet>> PacketMatchers(
+    std::vector<Packet> packets) {
+  std::vector<testing::Matcher<Packet>> result;
+  for (const auto& packet : packets) {
+    result.push_back(PacketEq<PayloadType>(packet));
+  }
+  return result;
+}
+
+}  // namespace mediapipe
+
+namespace mediapipe {
+using mediapipe::PacketContains;
+using mediapipe::PacketContainsTimestampAndPayload;

 }  // namespace mediapipe

 #endif  // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_
diff --git a/mediapipe/util/resource_cache.h b/mediapipe/util/resource_cache.h
index 4cd869f6a..2b3ccbc7d 100644
--- a/mediapipe/util/resource_cache.h
+++ b/mediapipe/util/resource_cache.h
@@ -17,6 +17,7 @@

 #include <unordered_map>

+#include "absl/container/flat_hash_map.h"
 #include "absl/functional/function_ref.h"
 #include "mediapipe/framework/port/logging.h"

@@ -26,7 +27,8 @@ namespace mediapipe {
 // resource (e.g., image dimension for an image pool) is described by the `Key`
 // type. The `Value` type must include an unset value, with implicit conversion
 // to bool reflecting set/unset state.
-template <typename Key, typename Value>
+template <typename Key, typename Value,
+          typename KeyHash = typename absl::flat_hash_map<Key, int>::hasher>
 class ResourceCache {
  public:
   Value Lookup(
@@ -36,15 +38,14 @@ class ResourceCache {
     Entry* entry;
     if (map_it == map_.end()) {
       std::tie(map_it, std::ignore) =
-          map_.emplace(std::piecewise_construct, std::forward_as_tuple(key),
-                       std::forward_as_tuple(key));
-      entry = &map_it->second;
+          map_.try_emplace(key, std::make_unique<Entry>(key));
+      entry = map_it->second.get();
       CHECK_EQ(entry->request_count, 0);
       entry->request_count = 1;
       entry_list_.Append(entry);
       if (entry->prev != nullptr) CHECK_GE(entry->prev->request_count, 1);
     } else {
-      entry = &map_it->second;
+      entry = map_it->second.get();
       ++entry->request_count;
       Entry* larger = entry->prev;
       while (larger != nullptr &&
@@ -171,7 +172,7 @@ class ResourceCache {
     size_t size_ = 0;
   };

-  std::unordered_map<Key, Entry> map_;
+  absl::flat_hash_map<Key, std::unique_ptr<Entry>, KeyHash> map_;
   EntryList entry_list_;
   int total_request_count_ = 0;
 };
diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc
index 40a474599..42b0e3889 100644
--- a/mediapipe/util/sequence/media_sequence_test.cc
+++ b/mediapipe/util/sequence/media_sequence_test.cc
@@ -802,7 +802,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImages) {
   tensorflow::SequenceExample sequence;
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {}));
   std::string encoded_image(bytes.begin(), bytes.end());
   AddImageEncoded(encoded_image, &sequence);
   AddImageEncoded(encoded_image, &sequence);
@@ -843,7 +843,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFlowEncoded) {
   tensorflow::SequenceExample sequence;
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {}));
   std::string encoded_flow(bytes.begin(), bytes.end());

   AddForwardFlowEncoded(encoded_flow, &sequence);
diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h
index dfbc8d659..5eeaa230f 100644
--- a/mediapipe/util/tflite/tflite_gpu_runner.h
+++ b/mediapipe/util/tflite/tflite_gpu_runner.h
@@ -21,6 +21,7 @@
 #include "absl/status/status.h"
 #include "mediapipe/framework/port.h"
+#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/statusor.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
@@ -89,7 +90,8 @@ class TFLiteGPURunner {
     serialized_binary_cache_ = std::move(cache);
   }

-  std::vector<uint8_t> GetSerializedBinaryCache() {
+  absl::StatusOr<std::vector<uint8_t>> GetSerializedBinaryCache() {
+    RET_CHECK(cl_environment_) << "CL environment is not initialized.";
     return cl_environment_->GetSerializedBinaryCache();
   }

diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD
index 319e99d5b..6bca24446 100644
--- a/mediapipe/util/tracking/BUILD
+++ b/mediapipe/util/tracking/BUILD
@@ -134,7 +134,6 @@ proto_library(
 mediapipe_cc_proto_library(
     name = "tone_models_cc_proto",
     srcs = ["tone_models.proto"],
-    visibility = ["//visibility:public"],
     deps = [":tone_models_proto"],
 )

@@ -142,7 +141,6 @@ mediapipe_cc_proto_library(
     name = "tone_estimation_cc_proto",
     srcs = ["tone_estimation.proto"],
     cc_deps = [":tone_models_cc_proto"],
-    visibility = ["//visibility:public"],
     deps = [":tone_estimation_proto"],
 )

@@ -153,21 +151,18 @@ mediapipe_cc_proto_library(
     cc_deps = [
         ":tone_estimation_cc_proto",
         ":tone_models_cc_proto",
     ],
-    visibility = ["//visibility:public"],
     deps =
[":region_flow_computation_proto"], ) mediapipe_cc_proto_library( name = "motion_saliency_cc_proto", srcs = ["motion_saliency.proto"], - visibility = ["//visibility:public"], deps = [":motion_saliency_proto"], ) mediapipe_cc_proto_library( name = "motion_estimation_cc_proto", srcs = ["motion_estimation.proto"], - visibility = ["//visibility:public"], deps = [":motion_estimation_proto"], ) @@ -179,7 +174,6 @@ mediapipe_cc_proto_library( ":motion_saliency_cc_proto", ":region_flow_computation_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_proto"], ) @@ -187,14 +181,12 @@ mediapipe_cc_proto_library( name = "region_flow_cc_proto", srcs = ["region_flow.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":region_flow_proto"], ) mediapipe_cc_proto_library( name = "motion_models_cc_proto", srcs = ["motion_models.proto"], - visibility = ["//visibility:public"], deps = [":motion_models_proto"], ) @@ -202,21 +194,18 @@ mediapipe_cc_proto_library( name = "camera_motion_cc_proto", srcs = ["camera_motion.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":camera_motion_proto"], ) mediapipe_cc_proto_library( name = "push_pull_filtering_cc_proto", srcs = ["push_pull_filtering.proto"], - visibility = ["//visibility:public"], deps = [":push_pull_filtering_proto"], ) mediapipe_cc_proto_library( name = "frame_selection_solution_evaluator_cc_proto", srcs = ["frame_selection_solution_evaluator.proto"], - visibility = ["//visibility:public"], deps = [":frame_selection_solution_evaluator_proto"], ) @@ -228,7 +217,6 @@ mediapipe_cc_proto_library( ":frame_selection_solution_evaluator_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":frame_selection_proto"], ) @@ -239,7 +227,6 @@ mediapipe_cc_proto_library( ":motion_models_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_proto"], ) @@ -247,7 +234,6 @@ mediapipe_cc_proto_library( name = "tracking_cc_proto", srcs = ["tracking.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tracking_proto"], ) @@ -255,14 +241,12 @@ mediapipe_cc_proto_library( name = "box_tracker_cc_proto", srcs = ["box_tracker.proto"], cc_deps = [":tracking_cc_proto"], - visibility = ["//visibility:public"], deps = [":box_tracker_proto"], ) mediapipe_cc_proto_library( name = "tracked_detection_manager_config_cc_proto", srcs = ["tracked_detection_manager_config.proto"], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_config_proto"], ) @@ -273,7 +257,6 @@ mediapipe_cc_proto_library( ":box_tracker_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_proto"], ) @@ -458,7 +441,6 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", ], ) @@ -739,7 +721,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", diff --git a/mediapipe/util/tracking/motion_analysis.cc b/mediapipe/util/tracking/motion_analysis.cc index 0b7678889..5b6a970cf 100644 --- 
a/mediapipe/util/tracking/motion_analysis.cc
+++ b/mediapipe/util/tracking/motion_analysis.cc
@@ -791,7 +791,7 @@ void MotionAnalysis::VisualizeBlurAnalysisRegions(cv::Mat* input_view) {
   region_flow_computation_->ComputeBlurMask(*input_view, &corner_values, &mask);

   cv::Mat mask_3c;
-  cv::cvtColor(mask, mask_3c, CV_GRAY2RGB);
+  cv::cvtColor(mask, mask_3c, cv::COLOR_GRAY2RGB);
   cv::addWeighted(*input_view, 0.5, mask_3c, 0.5, -128, *input_view);
 }

diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc
index cfd5c23c2..708c868b5 100644
--- a/mediapipe/util/tracking/region_flow_computation.cc
+++ b/mediapipe/util/tracking/region_flow_computation.cc
@@ -30,6 +30,7 @@
 #include "absl/container/node_hash_set.h"
 #include "absl/memory/memory.h"
 #include "mediapipe/framework/port/logging.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_features2d_inc.h"
 #include "mediapipe/framework/port/opencv_imgproc_inc.h"
 #include "mediapipe/framework/port/opencv_video_inc.h"
@@ -935,12 +936,13 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source,
     // Area based method best for downsampling.
     // For color images to temporary buffer.
     cv::Mat& resized = source.channels() == 1 ? dest_frame : *curr_color_image_;
-    cv::resize(source, resized, resized.size(), 0, 0, CV_INTER_AREA);
+    cv::resize(source, resized, resized.size(), 0, 0, cv::INTER_AREA);
     source_ptr = &resized;
     // Resize feature extraction mask if needed.
     if (!source_mask.empty()) {
       dest_mask.create(resized.rows, resized.cols, CV_8UC1);
-      cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, CV_INTER_NN);
+      cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0,
+                 cv::INTER_NEAREST);
     }
   } else if (!source_mask.empty()) {
     source_mask.copyTo(dest_mask);
@@ -954,7 +956,7 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source,
     const int dimension = visual_options.tiny_image_dimension();
     data->tiny_image.create(dimension, dimension, type);
     cv::resize(*source_ptr, data->tiny_image, data->tiny_image.size(), 0, 0,
-               CV_INTER_AREA);
+               cv::INTER_AREA);
   }

   if (source_ptr->channels() == 1 &&
@@ -2286,7 +2288,7 @@ void RegionFlowComputation::ExtractFeatures(
   // Initialize mask from frame's feature extraction mask, by downsampling and
   // negating the latter mask.
   if (!data->mask.empty()) {
-    cv::resize(data->mask, mask, mask.size(), 0, 0, CV_INTER_NN);
+    cv::resize(data->mask, mask, mask.size(), 0, 0, cv::INTER_NEAREST);
     for (int y = 0; y < mask.rows; ++y) {
       uint8* mask_ptr = mask.ptr<uint8>(y);
       for (int x = 0; x < mask.cols; ++x) {
@@ -2590,12 +2592,6 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr,
   cv::_InputArray input_frame2(data2.pyramid);
 #endif

-  // Using old c-interface for OpenCV's 2.2 tracker.
- CvTermCriteria criteria; - criteria.type = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER; - criteria.max_iter = options_.tracking_options().tracking_iterations(); - criteria.epsilon = 0.02f; - feature_track_error_.resize(num_features); feature_status_.resize(num_features); if (use_cv_tracking_) { diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 0ac6dc2a5..435a8e200 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -28,7 +28,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" -#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD index dab6be50f..5c12947af 100644 --- a/mediapipe/web/graph_runner/BUILD +++ b/mediapipe/web/graph_runner/BUILD @@ -3,32 +3,24 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = [ - ":internal", "//mediapipe/tasks:internal", ]) -package_group( - name = "internal", - packages = [ - "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...", - ], -) - mediapipe_ts_library( - name = "wasm_mediapipe_lib_ts", + name = "graph_runner_ts", srcs = [ - ":wasm_mediapipe_lib.ts", + ":graph_runner.ts", ], allow_unoptimized_namespaces = True, ) mediapipe_ts_library( - name = "wasm_mediapipe_image_lib_ts", + name = "graph_runner_image_lib_ts", srcs = [ - ":wasm_mediapipe_image_lib.ts", + ":graph_runner_image_lib.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) mediapipe_ts_library( @@ -37,5 +29,5 @@ mediapipe_ts_library( ":register_model_resources_graph_service.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/graph_runner.ts similarity index 86% rename from mediapipe/web/graph_runner/wasm_mediapipe_lib.ts rename to mediapipe/web/graph_runner/graph_runner.ts index 9ecf094ca..a9bb979af 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -15,9 +15,6 @@ export declare interface FileLocator { locateFile: (filename: string) => string; } -/** Listener to be passed in by user for handling output audio data. */ -export type AudioOutputListener = (output: Float32Array) => void; - /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. 
@@ -32,19 +29,14 @@ export declare interface WasmModule {
   _bindTextureToCanvas: () => boolean;
   _changeBinaryGraph: (size: number, dataPtr: number) => void;
   _changeTextGraph: (size: number, dataPtr: number) => void;
-  _configureAudio:
-      (channels: number, samples: number, sampleRate: number) => void;
   _free: (ptr: number) => void;
   _malloc: (size: number) => number;
-  _processAudio: (dataPtr: number, timestamp: number) => void;
   _processFrame: (width: number, height: number, timestamp: number) => void;
   _setAutoRenderToScreen: (enabled: boolean) => void;
   _waitUntilIdle: () => void;

   // Exposed so that clients of this lib can access this field
   dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
-  // Wasm module will call us back at this function when given audio data.
-  onAudioOutput?: AudioOutputListener;

   // Wasm Module multistream entrypoints.  Require
   // gl_graph_runner_internal_multi_input as a build dependency.
@@ -100,11 +92,14 @@ export declare interface WasmModule {
   _attachProtoVectorListener:
       (streamNamePtr: number, makeDeepCopy?: boolean) => void;

-  // Requires dependency ":gl_graph_runner_audio_out", and will register an
-  // audio output listening function which can be tapped into dynamically during
-  // graph running via onAudioOutput. This call must be made before graph is
-  // initialized, but after wasmModule is instantiated.
-  _attachAudioOutputListener: () => void;
+  // Require dependency ":gl_graph_runner_audio_out"
+  _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
+
+  // Require dependency ":gl_graph_runner_audio"
+  _addAudioToInputStream: (dataPtr: number, numChannels: number,
+      numSamples: number, streamNamePtr: number, timestamp: number) => void;
+  _configureAudio: (channels: number, samples: number, sampleRate: number,
+      streamNamePtr: number, headerNamePtr: number) => void;

   // TODO: Refactor to just use a few numbers (perhaps refactor away
   // from gl_graph_runner_internal.cc entirely to use something a little more
@@ -129,7 +124,7 @@ declare global {
 declare function importScripts(...urls: Array<string|URL>): void;

 /**
- * Valid types of image sources which we can run our WasmMediaPipeLib over.
+ * Valid types of image sources which we can run our GraphRunner over.
 */
 export type ImageSource =
     HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap;

@@ -138,9 +133,11 @@ export type ImageSource =
 /** A listener that will be invoked with an absl::StatusCode and message. */
 export type ErrorListener = (code: number, message: string) => void;

-// Internal type of constructors used for initializing WasmMediaPipeLib and
-// subclasses.
-type WasmMediaPipeConstructor<LibType> =
+/**
+ * Internal type of constructors used for initializing GraphRunner and
+ * subclasses.
+ */
+export type WasmMediaPipeConstructor<LibType> =
     (new (
         module: WasmModule,
        canvas?: HTMLCanvasElement|OffscreenCanvas|null) => LibType);
@@ -151,7 +148,7 @@ type WasmMediaPipeConstructor<LibType> =
 * into canvas, or else return the output WebGLTexture. Takes a WebAssembly
 * Module (must be instantiated to self.Module).
 */
-export class WasmMediaPipeLib {
+export class GraphRunner {
   // TODO: These should be protected/private, but are left exposed for
   // now so that we can use proper TS mixins with this class as a base. This
   // should be somewhat fixed when we create our .d.ts files.
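Because `WasmMediaPipeConstructor` is now exported and generic over the instance type, downstream code can get a correctly typed instance back from `createMediaPipeLib`. A minimal sketch, assuming a placeholder loader URL and an illustrative subclass name:

// Sketch only: LoggingRunner and the script URL are not part of this change.
class LoggingRunner extends GraphRunner {
  logWhenIdle(): void {
    this.wasmModule._waitUntilIdle();  // wasmModule is still exposed (see TODO)
    console.log('graph is idle');
  }
}

async function makeRunner(): Promise<LoggingRunner> {
  // The LibType parameter makes this resolve to Promise<LoggingRunner>.
  return createMediaPipeLib(
      LoggingRunner, '/wasm/runner_loader.js',
      /* assetLoaderScript= */ undefined, /* glCanvas= */ undefined);
}
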
@@ -181,10 +178,14 @@ export class WasmMediaPipeLib { if (glCanvas !== undefined) { this.wasmModule.canvas = glCanvas; - } else { + } else if (typeof OffscreenCanvas !== 'undefined') { // If no canvas is provided, assume Chrome/Firefox and just make an // OffscreenCanvas for GPU processing. this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } else { + console.warn('OffscreenCanvas not detected and GraphRunner constructor ' + + 'glCanvas parameter is undefined. Creating backup canvas.'); + this.wasmModule.canvas = document.createElement('canvas'); } } @@ -235,19 +236,38 @@ export class WasmMediaPipeLib { } /** - * Configures the current graph to handle audio in a certain way. Must be - * called before the graph is set/started in order to use processAudio. + * Configures the current graph to handle audio processing in a certain way + * for all its audio input streams. Additionally can configure audio headers + * (both input side packets as well as input stream headers), but these + * configurations only take effect if called before the graph is set/started. * @param numChannels The number of channels of audio input. Only 1 * is supported for now. * @param numSamples The number of samples that are taken in each * audio capture. * @param sampleRate The rate, in Hz, of the sampling. + * @param streamName The optional name of the input stream to additionally + * configure with audio information. This configuration only occurs before + * the graph is set/started. If unset, a default stream name will be used. + * @param headerName The optional name of the header input side packet to + * additionally configure with audio information. This configuration only + * occurs before the graph is set/started. If unset, a default header name + * will be used. */ - configureAudio(numChannels: number, numSamples: number, sampleRate: number) { - this.wasmModule._configureAudio(numChannels, numSamples, sampleRate); - if (this.wasmModule._attachAudioOutputListener) { - this.wasmModule._attachAudioOutputListener(); + configureAudio(numChannels: number, numSamples: number, sampleRate: number, + streamName?: string, headerName?: string) { + if (!this.wasmModule._configureAudio) { + console.warn( + 'Attempting to use configureAudio without support for input audio. 
' +
+          'Is build dep ":gl_graph_runner_audio" missing?');
+    }
+    streamName = streamName || 'input_audio';
+    this.wrapStringPtr(streamName, (streamNamePtr: number) => {
+      headerName = headerName || 'audio_header';
+      this.wrapStringPtr(headerName, (headerNamePtr: number) => {
+        this.wasmModule._configureAudio(numChannels, numSamples, sampleRate,
+            streamNamePtr, headerNamePtr);
+      });
+    });
   }

   /**
@@ -305,6 +325,10 @@ export class WasmMediaPipeLib {
     if ((imageSource as HTMLVideoElement).videoWidth) {
       width = (imageSource as HTMLVideoElement).videoWidth;
       height = (imageSource as HTMLVideoElement).videoHeight;
+    } else if ((imageSource as HTMLImageElement).naturalWidth) {
+      // TODO: Ensure this works with SVG images
+      width = (imageSource as HTMLImageElement).naturalWidth;
+      height = (imageSource as HTMLImageElement).naturalHeight;
     } else {
       width = imageSource.width;
       height = imageSource.height;
@@ -406,7 +430,7 @@ export class WasmMediaPipeLib {
    */
  setVectorListener<T>(
      outputStreamName: string, callbackFcn: (data: T[]) => void) {
-    const buffer: T[] = [];
+    let buffer: T[] = [];
    this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {};
    this.wasmModule.vectorListeners[outputStreamName] =
        (data: unknown, index: number, length: number) => {
@@ -419,6 +443,7 @@
          // the underlying data elements once we leave the scope of the
          // listener.
          callbackFcn(buffer);
+         buffer = [];
        }
      };
  }

  /**
@@ -436,9 +461,36 @@
   * processed.
   * @param audioData An array of raw audio capture data, like
   *     from a call to getChannelData on an AudioBuffer.
+  * @param streamName The name of the MediaPipe graph stream to add the audio
+  *     data to.
   * @param timestamp The timestamp of the current frame, in ms.
   */
-  addAudioToStream(audioData: Float32Array, timestamp: number) {
+  addAudioToStream(
+      audioData: Float32Array, streamName: string, timestamp: number) {
+    // numChannels and numSamples being 0 will cause defaults to be used,
+    // which will reflect values from last call to configureAudio.
+    this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp);
+  }
+
+  /**
+   * Takes the raw data from a JS audio capture array, and sends it to C++ to be
+   * processed, shaping the audioData array into an audio matrix according to
+   * the numChannels and numSamples parameters.
+   * @param audioData An array of raw audio capture data, like
+   *     from a call to getChannelData on an AudioBuffer.
+   * @param numChannels The number of audio channels this data represents. If 0
+   *     is passed, then the value will be taken from the last call to
+   *     configureAudio.
+   * @param numSamples The number of audio samples captured in this data packet.
+   *     If 0 is passed, then the value will be taken from the last call to
+   *     configureAudio.
+   * @param streamName The name of the MediaPipe graph stream to add the audio
+   *     data to.
+   * @param timestamp The timestamp of the current frame, in ms.
+   */
+  addAudioToStreamWithShape(
+      audioData: Float32Array, numChannels: number, numSamples: number,
+      streamName: string, timestamp: number) {
     // 4 bytes for each F32
     const size = audioData.length * 4;
     if (this.audioSize !== size) {
       this.audioSize = size;
     }
     this.wasmModule.HEAPF32.set(audioData, this.audioPtr!
@@ -942,17 +998,45 @@ export class WasmMediaPipeLib {
   }

   /**
-   * Sets a listener to be called back with audio output packet data, as a
-   * Float32Array, when graph has finished processing it.
-   * @param audioOutputListener The caller's listener function.
+   * Attaches an audio packet listener to the specified output_stream, to be
+   * given a Float32Array as output.
+   * @param outputStreamName The name of the graph output stream to grab audio
+   *     data from.
+   * @param callbackFcn The function that will be called back with the data, as
+   *     it is received. Note that the data is only guaranteed to exist for the
+   *     duration of the callback, and the callback will be called inline, so it
+   *     should not perform overly complicated (or any async) behavior. If the
+   *     audio data needs to be able to outlive the call, you may set the
+   *     optional makeDeepCopy parameter to true, or can manually deep-copy the
+   *     data yourself.
+   * @param makeDeepCopy Optional convenience parameter which, if set to true,
+   *     will override the default memory management behavior and make a deep
+   *     copy of the underlying data, rather than just returning a view into the
+   *     C++-managed memory. At the cost of a data copy, this allows the
+   *     returned data to outlive the callback lifetime (and it will be cleaned
+   *     up automatically by JS garbage collection whenever the user is finished
+   *     with it).
    */
-  setOnAudioOutput(audioOutputListener: AudioOutputListener) {
-    this.wasmModule.onAudioOutput = audioOutputListener;
-    if (!this.wasmModule._attachAudioOutputListener) {
+  attachAudioListener(outputStreamName: string,
+      callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void {
+    if (!this.wasmModule._attachAudioListener) {
       console.warn(
-          'Attempting to use AudioOutputListener without support for ' +
+          'Attempting to use attachAudioListener without support for ' +
           'output audio. Is build dep ":gl_graph_runner_audio_out" missing?');
     }
+
+    // Set up our TS listener to receive any packets for this stream, and
+    // additionally reformat our Uint8Array into a Float32Array for the user.
+    this.setListener(outputStreamName, (data: Uint8Array) => {
+      const floatArray = new Float32Array(data.buffer);  // Should be very fast
+      callbackFcn(floatArray);
+    });
+
+    // Tell our graph to listen for audio packets on this stream.
+    this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
+      this.wasmModule._attachAudioListener(
+          outputStreamNamePtr, makeDeepCopy || false);
+    });
   }

   /**
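A short, assumed usage of the new listener API (the 'audio_out' stream name is hypothetical; the two calls illustrate the alternatives described in the comment above, and since setListener keeps one listener per stream they are mutually exclusive):

// Inline use: the Float32Array is a view into wasm memory and is only
// valid for the duration of the callback.
runner.attachAudioListener('audio_out', (data: Float32Array) => {
  console.log(`received ${data.length} samples`);
});
// Alternatively, with makeDeepCopy = true the array may be stashed for
// later processing, at the cost of one copy per packet.
const frames: Float32Array[] = [];
runner.attachAudioListener('audio_out', (data) => frames.push(data), true);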
@@ -988,7 +1072,7 @@ async function runScript(scriptUrl: string) {
 /**
  * Global function to initialize Wasm blob and load runtime assets for a
  * specialized MediaPipe library. This allows us to create a requested
- * subclass inheriting from WasmMediaPipeLib.
+ * subclass inheriting from GraphRunner.
  * @param constructorFcn The name of the class to instantiate via "new".
  * @param wasmLoaderScript Url for the wasm-runner script; produced by the build
  *     process.
@@ -1001,8 +1085,8 @@ async function runScript(scriptUrl: string) {
  */
 export async function createMediaPipeLib<LibType>(
     constructorFcn: WasmMediaPipeConstructor<LibType>,
-    wasmLoaderScript?: string,
-    assetLoaderScript?: string,
+    wasmLoaderScript?: string|null,
+    assetLoaderScript?: string|null,
     glCanvas?: HTMLCanvasElement|OffscreenCanvas|null,
     fileLocator?: FileLocator): Promise<LibType> {
   const scripts = [];
@@ -1042,12 +1126,12 @@ export async function createMediaPipeLib(
  * @return promise A promise which will resolve when initialization has
  *     completed successfully.
  */
-export async function createWasmMediaPipeLib(
+export async function createGraphRunner(
     wasmLoaderScript?: string,
     assetLoaderScript?: string,
     glCanvas?: HTMLCanvasElement|OffscreenCanvas|null,
-    fileLocator?: FileLocator): Promise<WasmMediaPipeLib> {
+    fileLocator?: FileLocator): Promise<GraphRunner> {
   return createMediaPipeLib(
-      WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas,
+      GraphRunner, wasmLoaderScript, assetLoaderScript, glCanvas,
       fileLocator);
 }
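A minimal, assumed bootstrap against the renamed entry point (the script URLs are placeholders for outputs of the wasm build; createGraphRunner simply forwards GraphRunner to createMediaPipeLib as shown above):

// Hypothetical asset paths; substitute the outputs of your wasm build.
const runner = await createGraphRunner(
    '/assets/audio_wasm_internal.js',  // wasmLoaderScript (placeholder path)
    '/assets/audio_wasm_assets.js');   // assetLoaderScript (placeholder path)
runner.configureAudio(1, 512, 16000);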
diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts
similarity index 83%
rename from mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts
rename to mediapipe/web/graph_runner/graph_runner_image_lib.ts
index 3b45e8230..7a4ea09e2 100644
--- a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts
+++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts
@@ -1,12 +1,12 @@
-import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib';
+import {ImageSource, GraphRunner} from './graph_runner';

 /**
- * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
+ * We extend from a GraphRunner constructor. This ensures our mixin has
  * access to the wasmModule, among other things. The `any` type is required for
  * mixin constructors.
  */
 // tslint:disable-next-line:no-any
-type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;
+type LibConstructor = new (...args: any[]) => GraphRunner;

 /**
  * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
@@ -19,10 +19,10 @@ export declare interface WasmImageModule {
 }

 /**
- * An implementation of WasmMediaPipeLib that supports binding GPU image data as
+ * An implementation of GraphRunner that supports binding GPU image data as
  * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for
  * effective multiple inheritance. Example usage:
- * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);`
+ * `const GraphRunnerImageLib = SupportImage(GraphRunner);`
  */
 // tslint:disable-next-line:enforce-name-casing
 export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts
index e85d63b06..9f2791d80 100644
--- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts
+++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts
@@ -1,12 +1,12 @@
-import {WasmMediaPipeLib} from './wasm_mediapipe_lib';
+import {GraphRunner} from './graph_runner';

 /**
- * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
+ * We extend from a GraphRunner constructor. This ensures our mixin has
  * access to the wasmModule, among other things. The `any` type is required for
  * mixin constructors.
  */
 // tslint:disable-next-line:no-any
-type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;
+type LibConstructor = new (...args: any[]) => GraphRunner;

 /**
  * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
@@ -17,11 +17,11 @@ export declare interface WasmModuleRegisterModelResources {
 }

 /**
- * An implementation of WasmMediaPipeLib that supports registering model
+ * An implementation of GraphRunner that supports registering model
  * resources to a cache, in the form of a GraphService C++-side. We implement as
  * a proper TS mixin, to allow for effective multiple inheritance. Sample usage:
- * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
- *     WasmMediaPipeLib);`
+ * `const GraphRunnerWithModelResourcesLib =
+ *     SupportModelResourcesGraphService(GraphRunner);`
  */
 // tslint:disable:enforce-name-casing
 export function SupportModelResourcesGraphService<TBase extends LibConstructor>(
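Because both helpers are TS mixins over the renamed GraphRunner base, they should compose directly; a sketch using the names introduced in this diff (the combined class name, variable names, and script paths are hypothetical):

// Compose both mixins onto GraphRunner; application order is arbitrary.
const GraphRunnerWithImageAndModelResources =
    SupportModelResourcesGraphService(SupportImage(GraphRunner));
// The composed constructor is then passed to createMediaPipeLib as usual.
const runner = await createMediaPipeLib(
    GraphRunnerWithImageAndModelResources,
    '/assets/vision_wasm_internal.js',   // placeholder path
    '/assets/vision_wasm_assets.js');    // placeholder path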
"c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668550487965052"], ) http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_index", - sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"], + sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 6bfde21ba..504f8567a 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,36 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"], - ) - - http_file( - name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"], - ) - - http_file( - name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"], + sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"], + sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", + sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", + sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_internal_js", + sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54", - urls = 
diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl
index 6bfde21ba..504f8567a 100644
--- a/third_party/wasm_files.bzl
+++ b/third_party/wasm_files.bzl
@@ -12,36 +12,72 @@ def wasm_files():

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
-        sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"],
-    )
-
-    http_file(
-        name = "com_google_mediapipe_wasm_text_wasm_internal_js",
-        sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"],
-    )
-
-    http_file(
-        name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
-        sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"],
+        sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
-        sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"],
+        sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
+        sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
+        sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_text_wasm_internal_js",
+        sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
-        sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"],
+        sha256 = "84e5f5ac70f7718baeaa09a89b155abbea67386e7d50663301b3af7ef0941e74",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1669173782728605"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
+        sha256 = "36f247673124e32535f217265b96508c1badee8fe2458c11c1efa95b6bec5daa",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1669173785027190"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
+        sha256 = "cc74d90a8aaf6d006ec24048cc80c33f96baeeb0075a6c6739f30d41da54e450",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1669173787903754"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
+        sha256 = "c3451423186766b08008e07ef6d52f628fcc0aca75beedd9bb4d87d380f29edd",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1669173790070986"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
-        sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"],
+        sha256 = "d1e8ad748913e3f190bfd3f72e0e8a4a308f78b918d54c79cec60a2cf30a49f0",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1669173792993881"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
+        sha256 = "e5f1b5e8264ff9a90371653cb0fdbf9ce3b30b712acbd72068af18ebca2293ac",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1669173794969702"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
+        sha256 = "24351fe580e88f2065b1978b8b3c0f3ad7b90f1c95805aafa07971ce422b5854",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1669173797596874"],
     )