diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 74398be42..ecd878115 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -936,6 +936,7 @@ cc_test( "//mediapipe/framework/tool:simulation_clock", "//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:sink", + "//mediapipe/util:packet_test_util", "@com_google_absl//absl/time", ], ) diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index d209b1dbb..5b08f3af5 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -18,7 +18,6 @@ #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/util/header_util.h" @@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS"; // FlowLimiterCalculator provides limited support for multiple input streams. // The first input stream is treated as the main input stream and successive // input streams are treated as auxiliary input streams. The auxiliary input -// streams are limited to timestamps passed on the main input stream. +// streams are limited to timestamps allowed by the "ALLOW" stream. // class FlowLimiterCalculator : public CalculatorBase { public: @@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase { cc->InputSidePackets().Tag(kMaxInFlightTag).Get()); } input_queues_.resize(cc->Inputs().NumEntries("")); + allowed_[Timestamp::Unset()] = true; RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); return absl::OkStatus(); } - // Returns true if an additional frame can be released for processing. - // The "ALLOW" output stream indicates this condition at each input frame. - bool ProcessingAllowed() { - return frames_in_flight_.size() < options_.max_in_flight(); - } - - // Outputs a packet indicating whether a frame was sent or dropped. - void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { - if (cc->Outputs().HasTag(kAllowTag)) { - cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket(allow).At(ts)); - } - } - - // Sets the timestamp bound or closes an output stream. - void SetNextTimestampBound(Timestamp bound, OutputStream* stream) { - if (bound > Timestamp::Max()) { - stream->Close(); - } else { - stream->SetNextTimestampBound(bound); - } - } - - // Returns true if a certain timestamp is being processed. - bool IsInFlight(Timestamp timestamp) { - return std::find(frames_in_flight_.begin(), frames_in_flight_.end(), - timestamp) != frames_in_flight_.end(); - } - - // Releases input packets up to the latest settled input timestamp. - void ProcessAuxiliaryInputs(CalculatorContext* cc) { - Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound(); - for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) { - // Release settled frames from each input queue. - while (!input_queues_[i].empty() && - input_queues_[i].front().Timestamp() < settled_bound) { - Packet packet = input_queues_[i].front(); - input_queues_[i].pop_front(); - if (IsInFlight(packet.Timestamp())) { - cc->Outputs().Get("", i).AddPacket(packet); - } - } - - // Propagate each input timestamp bound. - if (!input_queues_[i].empty()) { - Timestamp bound = input_queues_[i].front().Timestamp(); - SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); - } else { - Timestamp bound = - cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream(); - SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); - } - } - } - // Releases input packets allowed by the max_in_flight constraint. absl::Status Process(CalculatorContext* cc) final { options_ = tool::RetrieveOptions(options_, cc->Inputs()); @@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase { } ProcessAuxiliaryInputs(cc); + + // Discard old ALLOW ranges. + Timestamp input_bound = InputTimestampBound(cc); + auto first_range = std::prev(allowed_.upper_bound(input_bound)); + allowed_.erase(allowed_.begin(), first_range); return absl::OkStatus(); } + int LedgerSize() { + int result = frames_in_flight_.size() + allowed_.size(); + for (const auto& queue : input_queues_) { + result += queue.size(); + } + return result; + } + + private: + // Returns true if an additional frame can be released for processing. + // The "ALLOW" output stream indicates this condition at each input frame. + bool ProcessingAllowed() { + return frames_in_flight_.size() < options_.max_in_flight(); + } + + // Outputs a packet indicating whether a frame was sent or dropped. + void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) { + if (cc->Outputs().HasTag(kAllowTag)) { + cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket(allow).At(ts)); + } + allowed_[ts] = allow; + } + + // Returns true if a timestamp falls within a range of allowed timestamps. + bool IsAllowed(Timestamp timestamp) { + auto it = allowed_.upper_bound(timestamp); + return std::prev(it)->second; + } + + // Sets the timestamp bound or closes an output stream. + void SetNextTimestampBound(Timestamp bound, OutputStream* stream) { + if (bound > Timestamp::Max()) { + stream->Close(); + } else { + stream->SetNextTimestampBound(bound); + } + } + + // Returns the lowest unprocessed input Timestamp. + Timestamp InputTimestampBound(CalculatorContext* cc) { + Timestamp result = Timestamp::Done(); + for (int i = 0; i < input_queues_.size(); ++i) { + auto& queue = input_queues_[i]; + auto& stream = cc->Inputs().Get("", i); + Timestamp bound = queue.empty() + ? stream.Value().Timestamp().NextAllowedInStream() + : queue.front().Timestamp(); + result = std::min(result, bound); + } + return result; + } + + // Releases input packets up to the latest settled input timestamp. + void ProcessAuxiliaryInputs(CalculatorContext* cc) { + Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound(); + for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) { + // Release settled frames from each input queue. + while (!input_queues_[i].empty() && + input_queues_[i].front().Timestamp() < settled_bound) { + Packet packet = input_queues_[i].front(); + input_queues_[i].pop_front(); + if (IsAllowed(packet.Timestamp())) { + cc->Outputs().Get("", i).AddPacket(packet); + } + } + + // Propagate each input timestamp bound. + if (!input_queues_[i].empty()) { + Timestamp bound = input_queues_[i].front().Timestamp(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); + } else { + Timestamp bound = + cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream(); + SetNextTimestampBound(bound, &cc->Outputs().Get("", i)); + } + } + } + private: FlowLimiterCalculatorOptions options_; std::vector> input_queues_; std::deque frames_in_flight_; + std::map allowed_; }; REGISTER_CALCULATOR(FlowLimiterCalculator); diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 962b1c81a..8a8cc9656 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "absl/time/clock.h" @@ -32,6 +33,7 @@ #include "mediapipe/framework/tool/simulation_clock.h" #include "mediapipe/framework/tool/simulation_clock_executor.h" #include "mediapipe/framework/tool/sink.h" +#include "mediapipe/util/packet_test_util.h" namespace mediapipe { @@ -77,6 +79,77 @@ std::vector PacketValues(const std::vector& packets) { return result; } +template +std::vector MakePackets(std::vector> contents) { + std::vector result; + for (auto& entry : contents) { + result.push_back(MakePacket(entry.second).At(entry.first)); + } + return result; +} + +std::string SourceString(Timestamp t) { + return (t.IsSpecialValue()) + ? t.DebugString() + : absl::StrCat("Timestamp(", t.DebugString(), ")"); +} + +template +class PacketsEqMatcher + : public ::testing::MatcherInterface { + public: + PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} + void DescribeTo(::std::ostream* os) const override { + *os << "The expected packet contents: \n"; + Print(packets_, os); + } + bool MatchAndExplain( + const PacketContainer& value, + ::testing::MatchResultListener* listener) const override { + if (!Equals(packets_, value)) { + if (listener->IsInterested()) { + *listener << "The actual packet contents: \n"; + Print(value, listener->stream()); + } + return false; + } + return true; + } + + private: + bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { + if (c1.size() != c2.size()) { + return false; + } + for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { + Packet p1 = *i1, p2 = *i2; + if (p1.Timestamp() != p2.Timestamp() || + p1.Get() != p2.Get()) { + return false; + } + } + return true; + } + void Print(const PacketContainer& packets, ::std::ostream* os) const { + for (auto it = packets.begin(); it != packets.end(); ++it) { + const Packet& packet = *it; + *os << (it == packets.begin() ? "{" : "") << "{" + << SourceString(packet.Timestamp()) << ", " + << packet.Get() << "}" + << (std::next(it) == packets.end() ? "}" : ", "); + } + } + + const PacketContainer packets_; +}; + +template +::testing::Matcher PackestEq( + const PacketContainer& packets) { + return MakeMatcher( + new PacketsEqMatcher(packets)); +} + // A Calculator::Process callback function. typedef std::function @@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { input_packets_[17], input_packets_[19], input_packets_[20], }; EXPECT_EQ(out_1_packets_, expected_output); - // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. + // The timestamps released by FlowLimiterCalculator for in_1_sampled, + // plus input_packets_[21]. std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[14], input_packets_[17], input_packets_[19], - input_packets_[20], + input_packets_[20], input_packets_[21], }; EXPECT_EQ(out_2_packets, expected_output_2); } @@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { + auto BoolPackestEq = PackestEq, bool>; + auto IntPackestEq = PackestEq, int>; + // Configure the test. SetUpInputData(); SetUpSimulationClock(); @@ -699,11 +776,10 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { } )pb"); - auto limiter_options = ParseTextProtoOrDie(R"pb( - max_in_flight: 1 - max_in_queue: 0 - in_flight_timeout: 100000 # 100 ms - )pb"); + auto limiter_options = ParseTextProtoOrDie( + R"pb( + max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000 # 100 ms + )pb"); std::map side_packets = { {"limiter_options", MakePacket(limiter_options)}, @@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_EQ(out_1_packets_, expected_output); + EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output)); // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. std::vector expected_output_2 = { input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[15], input_packets_[17], input_packets_[19], }; - EXPECT_EQ(out_2_packets, expected_output_2); + EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2)); + + // Validate the ALLOW stream output. + std::vector expected_allow = MakePackets( // + {{Timestamp(0), true}, {Timestamp(10000), false}, + {Timestamp(20000), true}, {Timestamp(30000), false}, + {Timestamp(40000), true}, {Timestamp(50000), false}, + {Timestamp(60000), false}, {Timestamp(70000), false}, + {Timestamp(80000), false}, {Timestamp(90000), false}, + {Timestamp(100000), false}, {Timestamp(110000), false}, + {Timestamp(120000), false}, {Timestamp(130000), false}, + {Timestamp(140000), false}, {Timestamp(150000), true}, + {Timestamp(160000), false}, {Timestamp(170000), true}, + {Timestamp(180000), false}, {Timestamp(190000), true}, + {Timestamp(200000), false}}); + EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow)); +} + +// Shows how FlowLimiterCalculator releases auxiliary input packets. +// In this test, auxiliary input packets arrive at twice the primary rate. +TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) { + auto BoolPackestEq = PackestEq, bool>; + auto IntPackestEq = PackestEq, int>; + + // Configure the test. + SetUpInputData(); + SetUpSimulationClock(); + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(R"pb( + input_stream: 'in_1' + input_stream: 'in_2' + node { + calculator: 'FlowLimiterCalculator' + input_side_packet: 'OPTIONS:limiter_options' + input_stream: 'in_1' + input_stream: 'in_2' + input_stream: 'FINISHED:out_1' + input_stream_info: { tag_index: 'FINISHED' back_edge: true } + output_stream: 'in_1_sampled' + output_stream: 'in_2_sampled' + output_stream: 'ALLOW:allow' + } + node { + calculator: 'SleepCalculator' + input_side_packet: 'WARMUP_TIME:warmup_time' + input_side_packet: 'SLEEP_TIME:sleep_time' + input_side_packet: 'CLOCK:clock' + input_stream: 'PACKET:in_1_sampled' + output_stream: 'PACKET:out_1' + } + )pb"); + + auto limiter_options = ParseTextProtoOrDie( + R"pb( + max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000 # 1s + )pb"); + std::map side_packets = { + {"limiter_options", + MakePacket(limiter_options)}, + {"warmup_time", MakePacket(22000)}, + {"sleep_time", MakePacket(22000)}, + {"clock", MakePacket(clock_)}, + }; + + // Start the graph. + MP_ASSERT_OK(graph_.Initialize(graph_config)); + MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { + out_1_packets_.push_back(p); + return absl::OkStatus(); + })); + std::vector out_2_packets; + MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) { + out_2_packets.push_back(p); + return absl::OkStatus(); + })); + MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { + allow_packets_.push_back(p); + return absl::OkStatus(); + })); + simulation_clock_->ThreadStart(); + MP_ASSERT_OK(graph_.StartRun(side_packets)); + + // Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2. + clock_->Sleep(absl::Microseconds(10000)); + for (int i = 1; i < 10; ++i) { + if (i % 2 == 0) { + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i])); + } + MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i])); + clock_->Sleep(absl::Microseconds(10000)); + } + + // Finish the graph run. + MP_EXPECT_OK(graph_.CloseAllPacketSources()); + clock_->Sleep(absl::Microseconds(40000)); + MP_EXPECT_OK(graph_.WaitUntilDone()); + simulation_clock_->ThreadFinish(); + + // Validate the output. + // Input packets 4 and 8 are dropped due to max_in_flight. + std::vector expected_output = { + input_packets_[2], + input_packets_[6], + }; + EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output)); + // Packets following input packets 2 and 6, and not input packets 4 and 8. + std::vector expected_output_2 = { + input_packets_[1], input_packets_[2], input_packets_[3], + input_packets_[6], input_packets_[7], + }; + EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2)); + + // Validate the ALLOW stream output. + std::vector expected_allow = + MakePackets({{Timestamp(20000), 1}, + {Timestamp(40000), 0}, + {Timestamp(60000), 1}, + {Timestamp(80000), 0}}); + EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow)); } } // anonymous namespace diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index 8358ea01a..70775949d 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -14,45 +14,43 @@ #include "mediapipe/framework/deps/status_builder.h" +#include +#include + #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" namespace mediapipe { -StatusBuilder::StatusBuilder(const StatusBuilder& sb) { - status_ = sb.status_; - file_ = sb.file_; - line_ = sb.line_; - no_logging_ = sb.no_logging_; - stream_ = sb.stream_ - ? absl::make_unique(sb.stream_->str()) - : nullptr; - join_style_ = sb.join_style_; -} +StatusBuilder::StatusBuilder(const StatusBuilder& sb) + : impl_(sb.impl_ ? std::make_unique(*sb.impl_) : nullptr) {} StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { - status_ = sb.status_; - file_ = sb.file_; - line_ = sb.line_; - no_logging_ = sb.no_logging_; - stream_ = sb.stream_ - ? absl::make_unique(sb.stream_->str()) - : nullptr; - join_style_ = sb.join_style_; + if (!sb.impl_) { + impl_ = nullptr; + return *this; + } + if (impl_) { + *impl_ = *sb.impl_; + return *this; + } + impl_ = std::make_unique(*sb.impl_); + return *this; } StatusBuilder& StatusBuilder::SetAppend() & { - if (status_.ok()) return *this; - join_style_ = MessageJoinStyle::kAppend; + if (!impl_) return *this; + impl_->join_style = Impl::MessageJoinStyle::kAppend; return *this; } StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); } StatusBuilder& StatusBuilder::SetPrepend() & { - if (status_.ok()) return *this; - join_style_ = MessageJoinStyle::kPrepend; + if (!impl_) return *this; + impl_->join_style = Impl::MessageJoinStyle::kPrepend; return *this; } @@ -61,7 +59,8 @@ StatusBuilder&& StatusBuilder::SetPrepend() && { } StatusBuilder& StatusBuilder::SetNoLogging() & { - no_logging_ = true; + if (!impl_) return *this; + impl_->no_logging = true; return *this; } @@ -70,34 +69,72 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && { } StatusBuilder::operator Status() const& { - if (!stream_ || stream_->str().empty() || no_logging_) { - return status_; - } return StatusBuilder(*this).JoinMessageToStatus(); } -StatusBuilder::operator Status() && { - if (!stream_ || stream_->str().empty() || no_logging_) { - return status_; - } - return JoinMessageToStatus(); -} +StatusBuilder::operator Status() && { return JoinMessageToStatus(); } absl::Status StatusBuilder::JoinMessageToStatus() { - if (!stream_) { + if (!impl_) { return absl::OkStatus(); } - std::string message; - if (join_style_ == MessageJoinStyle::kAnnotate) { - if (!status_.ok()) { - message = absl::StrCat(status_.message(), "; ", stream_->str()); - } - } else { - message = join_style_ == MessageJoinStyle::kPrepend - ? absl::StrCat(stream_->str(), status_.message()) - : absl::StrCat(status_.message(), stream_->str()); + return impl_->JoinMessageToStatus(); +} + +absl::Status StatusBuilder::Impl::JoinMessageToStatus() { + if (stream.str().empty() || no_logging) { + return status; } - return Status(status_.code(), message); + return absl::Status(status.code(), [this]() { + switch (join_style) { + case MessageJoinStyle::kAnnotate: + return absl::StrCat(status.message(), "; ", stream.str()); + case MessageJoinStyle::kAppend: + return absl::StrCat(status.message(), stream.str()); + case MessageJoinStyle::kPrepend: + return absl::StrCat(stream.str(), status.message()); + } + }()); +} + +StatusBuilder::Impl::Impl(const absl::Status& status, const char* file, + int line) + : status(status), line(line), file(file), stream() {} + +StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line) + : status(std::move(status)), line(line), file(file), stream() {} + +StatusBuilder::Impl::Impl(const absl::Status& status, + mediapipe::source_location location) + : status(status), + line(location.line()), + file(location.file_name()), + stream() {} + +StatusBuilder::Impl::Impl(absl::Status&& status, + mediapipe::source_location location) + : status(std::move(status)), + line(location.line()), + file(location.file_name()), + stream() {} + +StatusBuilder::Impl::Impl(const Impl& other) + : status(other.status), + line(other.line), + file(other.file), + no_logging(other.no_logging), + stream(other.stream.str()), + join_style(other.join_style) {} + +StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) { + status = other.status; + line = other.line; + file = other.file; + no_logging = other.no_logging; + stream = std::ostringstream(other.stream.str()); + join_style = other.join_style; + + return *this; } } // namespace mediapipe diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index c9111c603..d2e40d575 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -22,7 +22,6 @@ #include "absl/base/attributes.h" #include "absl/memory/memory.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/deps/source_location.h" #include "mediapipe/framework/deps/status.h" @@ -42,34 +41,37 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // occurs. A typical user will call this with `MEDIAPIPE_LOC`. StatusBuilder(const absl::Status& original_status, mediapipe::source_location location) - : status_(original_status), - line_(location.line()), - file_(location.file_name()), - stream_(InitStream(status_)) {} + : impl_(original_status.ok() + ? nullptr + : std::make_unique(original_status, location)) {} StatusBuilder(absl::Status&& original_status, mediapipe::source_location location) - : status_(std::move(original_status)), - line_(location.line()), - file_(location.file_name()), - stream_(InitStream(status_)) {} + : impl_(original_status.ok() + ? nullptr + : std::make_unique(std::move(original_status), + location)) {} // Creates a `StatusBuilder` from a mediapipe status code. If logging is // enabled, it will use `location` as the location from which the log message // occurs. A typical user will call this with `MEDIAPIPE_LOC`. StatusBuilder(absl::StatusCode code, mediapipe::source_location location) - : status_(code, ""), - line_(location.line()), - file_(location.file_name()), - stream_(InitStream(status_)) {} + : impl_(code == absl::StatusCode::kOk + ? nullptr + : std::make_unique(absl::Status(code, ""), location)) {} StatusBuilder(const absl::Status& original_status, const char* file, int line) - : status_(original_status), - line_(line), - file_(file), - stream_(InitStream(status_)) {} + : impl_(original_status.ok() + ? nullptr + : std::make_unique(original_status, file, line)) {} - bool ok() const { return status_.ok(); } + StatusBuilder(absl::Status&& original_status, const char* file, int line) + : impl_(original_status.ok() + ? nullptr + : std::make_unique(std::move(original_status), file, + line)) {} + + bool ok() const { return !impl_; } StatusBuilder& SetAppend() &; StatusBuilder&& SetAppend() &&; @@ -82,8 +84,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { template StatusBuilder& operator<<(const T& msg) & { - if (!stream_) return *this; - *stream_ << msg; + if (!impl_) return *this; + impl_->stream << msg; return *this; } @@ -98,35 +100,42 @@ class ABSL_MUST_USE_RESULT StatusBuilder { absl::Status JoinMessageToStatus(); private: - // Specifies how to join the error message in the original status and any - // additional message that has been streamed into the builder. - enum class MessageJoinStyle { - kAnnotate, - kAppend, - kPrepend, + struct Impl { + // Specifies how to join the error message in the original status and any + // additional message that has been streamed into the builder. + enum class MessageJoinStyle { + kAnnotate, + kAppend, + kPrepend, + }; + + Impl(const absl::Status& status, const char* file, int line); + Impl(absl::Status&& status, const char* file, int line); + Impl(const absl::Status& status, mediapipe::source_location location); + Impl(absl::Status&& status, mediapipe::source_location location); + Impl(const Impl&); + Impl& operator=(const Impl&); + + absl::Status JoinMessageToStatus(); + + // The status that the result will be based on. + absl::Status status; + // The line to record if this file is logged. + int line; + // Not-owned: The file to record if this status is logged. + const char* file; + // Logging disabled if true. + bool no_logging = false; + // The additional messages added with `<<`. This is nullptr when status_ is + // ok. + std::ostringstream stream; + // Specifies how to join the message in `status_` and `stream_`. + MessageJoinStyle join_style = MessageJoinStyle::kAnnotate; }; - // Conditionally creates an ostringstream if the status is not ok. - static std::unique_ptr InitStream( - const absl::Status status) { - if (status.ok()) { - return nullptr; - } - return absl::make_unique(); - } - - // The status that the result will be based on. - absl::Status status_; - // The line to record if this file is logged. - int line_; - // Not-owned: The file to record if this status is logged. - const char* file_; - bool no_logging_ = false; - // The additional messages added with `<<`. This is nullptr when status_ is - // ok. - std::unique_ptr stream_; - // Specifies how to join the message in `status_` and `stream_`. - MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate; + // Internal store of data for the class. An invariant of the class is that + // this is null when the original status is okay, and not-null otherwise. + std::unique_ptr impl_; }; inline StatusBuilder AlreadyExistsErrorBuilder( diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc index f517bb909..560acd3c6 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -33,6 +33,21 @@ TEST(StatusBuilder, OkStatusRvalue) { ASSERT_EQ(status, absl::OkStatus()); } +TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) { + absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234) + << "annotated message1 " + << "annotated message2"; + ASSERT_EQ(status, absl::OkStatus()); +} + +TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) { + const auto original_status = absl::OkStatus(); + absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) + << "annotated message1 " + << "annotated message2"; + ASSERT_EQ(status, absl::OkStatus()); +} + TEST(StatusBuilder, AnnotateMode) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, "original message"), @@ -45,6 +60,30 @@ TEST(StatusBuilder, AnnotateMode) { "original message; annotated message1 annotated message2"); } +TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) { + absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, + "original message"), + "hello.cc", 1234) + << "annotated message1 " + << "annotated message2"; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_EQ(status.message(), + "original message; annotated message1 annotated message2"); +} + +TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) { + const auto original_status = + absl::Status(absl::StatusCode::kNotFound, "original message"); + absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) + << "annotated message1 " + << "annotated message2"; + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_EQ(status.message(), + "original message; annotated message1 annotated message2"); +} + TEST(StatusBuilder, PrependModeLvalue) { StatusBuilder builder( absl::Status(absl::StatusCode::kInvalidArgument, "original message"), diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index f83d4059a..5d0fbd066 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model class Classifier(custom_model.CustomModel): """An abstract base class that represents a TensorFlow classifier.""" - def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool, - full_train: bool): + def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool): """Initilizes a classifier with its specifications. Args: model_spec: Specification for the model. label_names: A list of label names for the classes. shuffle: Whether the dataset should be shuffled. - full_train: If true, train the model end-to-end including the backbone - and the classification layers on top. Otherwise, only train the top - classification layers. """ super(Classifier, self).__init__(model_spec, shuffle) self._label_names = label_names - self._full_train = full_train self._num_classes = len(label_names) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: diff --git a/mediapipe/model_maker/python/core/tasks/classifier_test.py b/mediapipe/model_maker/python/core/tasks/classifier_test.py index 52a3b97db..6bf3b7a2e 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier_test.py +++ b/mediapipe/model_maker/python/core/tasks/classifier_test.py @@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase): super(ClassifierTest, self).setUp() label_names = ['cat', 'dog'] self.model = MockClassifier( - model_spec=None, - label_names=label_names, - shuffle=False, - full_train=False) + model_spec=None, label_names=label_names, shuffle=False) self.model.model = test_util.build_model(input_shape=[4], num_classes=2) def _check_nonempty_file(self, filepath): diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 61e7c7152..569138df7 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier): hparams: The hyperparameters for training image classifier. """ super().__init__( - model_spec=model_spec, - label_names=label_names, - shuffle=hparams.shuffle, - full_train=hparams.do_fine_tuning) + model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle) self._hparams = hparams self._preprocess = image_preprocessing.Preprocessor( input_shape=self._model_spec.input_image_shape, diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index dc2058554..2423370e6 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -93,6 +93,13 @@ cc_library( "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", ], + ] + select({ + # TODO: Build text_classifier_graph on Windows. + "//mediapipe:windows": [], + "//conditions:default": [ + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + ], + }), ) py_library( diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index 421ac44d3..b36fa306a 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) { // TODO: Should take "const Eigen::Ref&" // as the input argument. Investigate why bazel non-optimized mode // triggers a memory allocation bug in Eigen::internal::aligned_free(). - [](const Eigen::MatrixXf& matrix) { + [](const Eigen::MatrixXf& matrix, bool transpose) { // MakePacket copies the data. + if (transpose) { + return MakePacket(matrix.transpose()); + } return MakePacket(matrix); }, R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray. @@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) { Args: matrix: A 2d numpy float ndarray. + transpose: A boolean to indicate if the input matrix needs to be transposed. + Default to False. Returns: A MediaPipe Matrix Packet. @@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) { np.array([[.1, .2, .3], [.4, .5, .6]]) matrix = mp.packet_getter.get_matrix(packet) )doc", + py::arg("matrix"), py::arg("transpose") = false, py::return_value_policy::move); } // NOLINT(readability/fn_size) diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 4717ea50e..cdb3998c8 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -31,13 +31,13 @@ namespace core { // Base options for MediaPipe C++ Tasks. struct BaseOptions { - // The model asset file contents as as string. + // The model asset file contents as a string. std::unique_ptr model_asset_buffer; // The path to the model asset to open and mmap in memory. std::string model_asset_path = ""; - // The delegate to run MediaPipe. If the delegate is not set, default + // The delegate to run MediaPipe. If the delegate is not set, the default // delegate CPU is used. enum Delegate { CPU = 0, diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index 7ab4847dd..47d95100b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { hand_gesture_subgraph[Output>( kHandGesturesTag)]; - return {{.gesture = hand_gestures, - .handedness = handedness, - .hand_landmarks = hand_landmarks, - .hand_world_landmarks = hand_world_landmarks, - .image = hand_landmarker_graph[Output(kImageTag)]}}; + return GestureRecognizerOutputs{ + /*gesture=*/hand_gestures, + /*handedness=*/handedness, + /*hand_landmarks=*/hand_landmarks, + /*hand_world_landmarks=*/hand_world_landmarks, + /*image=*/hand_landmarker_graph[Output(kImageTag)]}; } }; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 5a5baa50e..564184c64 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) { } // Populate normalized non rotated face bounding box - return {.left = bounding_box_left, - .top = bounding_box_top, - .right = bounding_box_right, - .bottom = bounding_box_bottom}; + return Rect{/*left=*/bounding_box_left, + /*top=*/bounding_box_top, + /*right=*/bounding_box_right, + /*bottom=*/bounding_box_bottom}; } // Uses IoU and distance of some corresponding hand landmarks to detect diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java index d28946736..7f2903503 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java @@ -48,7 +48,7 @@ public abstract class BaseOptions { public abstract Builder setModelAssetBuffer(ByteBuffer value); /** - * Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default + * Sets device delegate to run the MediaPipe pipeline. If the delegate is not set, the default * delegate CPU is used. */ public abstract Builder setDelegate(Delegate delegate); diff --git a/mediapipe/tasks/python/components/containers/audio_data.py b/mediapipe/tasks/python/components/containers/audio_data.py index 21b606079..56399dea8 100644 --- a/mediapipe/tasks/python/components/containers/audio_data.py +++ b/mediapipe/tasks/python/components/containers/audio_data.py @@ -68,7 +68,11 @@ class AudioData(object): ValueError: If the input array has an incorrect shape or if `offset` + `size` exceeds the length of the `src` array. """ - if src.shape[1] != self._audio_format.num_channels: + if len(src.shape) == 1: + if self._audio_format.num_channels != 1: + raise ValueError(f"Input audio is mono, but the audio data is expected " + f"to have {self._audio_format.num_channels} channels.") + elif src.shape[1] != self._audio_format.num_channels: raise ValueError(f"Input audio contains an invalid number of channels. " f"Expect {self._audio_format.num_channels}.") @@ -93,6 +97,28 @@ class AudioData(object): self._buffer = np.roll(self._buffer, -shift, axis=0) self._buffer[-shift:, :] = src[offset:offset + size].copy() + @classmethod + def create_from_array(cls, + src: np.ndarray, + sample_rate: Optional[float] = None) -> "AudioData": + """Creates an `AudioData` object from a NumPy array. + + Args: + src: A NumPy source array contains the input audio. + sample_rate: the optional audio sample rate. + + Returns: + An `AudioData` object that contains a copy of the NumPy source array as + the data. + """ + obj = cls( + buffer_length=src.shape[0], + audio_format=AudioFormat( + num_channels=1 if len(src.shape) == 1 else src.shape[1], + sample_rate=sample_rate)) + obj.load_from_array(src) + return obj + @property def audio_format(self) -> AudioFormat: """Gets the audio format of the audio.""" diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py index 6854791b4..23ee4abe1 100644 --- a/mediapipe/tasks/python/test/test_utils.py +++ b/mediapipe/tasks/python/test/test_utils.py @@ -157,7 +157,7 @@ def _normalize_number_fields(pb): descriptor.FieldDescriptor.TYPE_ENUM): normalized_values = [int(x) for x in values] elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: - normalized_values = [round(x, 5) for x in values] + normalized_values = [round(x, 4) for x in values] elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE: normalized_values = [round(float(x), 6) for x in values] diff --git a/mediapipe/tasks/python/test/text/BUILD b/mediapipe/tasks/python/test/text/BUILD new file mode 100644 index 000000000..d7176b0a5 --- /dev/null +++ b/mediapipe/tasks/python/test/text/BUILD @@ -0,0 +1,36 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.py"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/text:text_classifier", + ], +) diff --git a/mediapipe/tasks/python/test/text/__init__.py b/mediapipe/tasks/python/test/text/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/test/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py new file mode 100644 index 000000000..c93def48e --- /dev/null +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -0,0 +1,244 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for text classifier.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.python.components.containers import category +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.text import text_classifier + +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_Category = category.Category +_ClassificationEntry = classifications_module.ClassificationEntry +_Classifications = classifications_module.Classifications +_TextClassifierResult = classifications_module.ClassificationResult +_TextClassifier = text_classifier.TextClassifier +_TextClassifierOptions = text_classifier.TextClassifierOptions + +_BERT_MODEL_FILE = 'bert_text_classifier.tflite' +_REGEX_MODEL_FILE = 'test_model_text_classifier_with_regex_tokenizer.tflite' +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' + +_NEGATIVE_TEXT = 'What a waste of my time.' +_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.' + 'Strongly recommend it!') + +_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=0, + score=0.999479, + display_name='', + category_name='negative'), + _Category( + index=1, + score=0.00052154, + display_name='', + category_name='positive') + ], + timestamp_ms=0) + ], + head_index=0, + head_name='probability') +]) +_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=1, + score=0.999466, + display_name='', + category_name='positive'), + _Category( + index=0, + score=0.000533596, + display_name='', + category_name='negative') + ], + timestamp_ms=0) + ], + head_index=0, + head_name='probability') +]) +_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=0, + score=0.81313, + display_name='', + category_name='Negative'), + _Category( + index=1, + score=0.1868704, + display_name='', + category_name='Positive') + ], + timestamp_ms=0) + ], + head_index=0, + head_name='probability') +]) +_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=1, + score=0.5134273, + display_name='', + category_name='Positive'), + _Category( + index=0, + score=0.486573, + display_name='', + category_name='Negative') + ], + timestamp_ms=0) + ], + head_index=0, + head_name='probability') +]) + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class ImageClassifierTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _TextClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _TextClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _TextClassifierOptions(base_options=base_options) + with _TextClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _TextClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _TextClassifierOptions(base_options=base_options) + _TextClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _TextClassifierOptions(base_options=base_options) + classifier = _TextClassifier.create_from_options(options) + self.assertIsInstance(classifier, _TextClassifier) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT, + _BERT_NEGATIVE_RESULTS), (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, + _NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS), + (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _POSITIVE_TEXT, + _BERT_POSITIVE_RESULTS), (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, + _POSITIVE_TEXT, _BERT_POSITIVE_RESULTS), + (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _NEGATIVE_TEXT, + _REGEX_NEGATIVE_RESULTS), (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, + _NEGATIVE_TEXT, _REGEX_NEGATIVE_RESULTS), + (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _POSITIVE_TEXT, + _REGEX_POSITIVE_RESULTS), (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, + _POSITIVE_TEXT, _REGEX_POSITIVE_RESULTS)) + def test_classify(self, model_file_type, model_name, text, + expected_classification_result): + # Creates classifier. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, model_name)) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _TextClassifierOptions(base_options=base_options) + classifier = _TextClassifier.create_from_options(options) + + # Performs text classification on the input. + text_result = classifier.classify(text) + # Comparing results. + test_utils.assert_proto_equals(self, text_result.to_pb2(), + expected_classification_result.to_pb2()) + # Closes the classifier explicitly when the classifier is not used in + # a context. + classifier.close() + + @parameterized.parameters((ModelFileType.FILE_NAME, _BERT_MODEL_FILE, + _NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS), + (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, + _NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS)) + def test_classify_in_context(self, model_file_type, model_name, text, + expected_classification_result): + # Creates classifier. + model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, model_name)) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _TextClassifierOptions(base_options=base_options) + + with _TextClassifier.create_from_options(options) as classifier: + # Performs text classification on the input. + text_result = classifier.classify(text) + # Comparing results. + test_utils.assert_proto_equals(self, text_result.to_pb2(), + expected_classification_result.to_pb2()) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD new file mode 100644 index 000000000..fd5d701b4 --- /dev/null +++ b/mediapipe/tasks/python/text/BUILD @@ -0,0 +1,38 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "text_classifier", + srcs = [ + "text_classifier.py", + ], + deps = [ + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/text/core:base_text_task_api", + ], +) diff --git a/mediapipe/tasks/python/text/__init__.py b/mediapipe/tasks/python/text/__init__.py new file mode 100644 index 000000000..65c1214af --- /dev/null +++ b/mediapipe/tasks/python/text/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/tasks/python/text/core/BUILD b/mediapipe/tasks/python/text/core/BUILD new file mode 100644 index 000000000..072a0c7d8 --- /dev/null +++ b/mediapipe/tasks/python/text/core/BUILD @@ -0,0 +1,31 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "base_text_task_api", + srcs = [ + "base_text_task_api.py", + ], + deps = [ + "//mediapipe/framework:calculator_py_pb2", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/text/core/__init__.py b/mediapipe/tasks/python/text/core/__init__.py new file mode 100644 index 000000000..6a8405189 --- /dev/null +++ b/mediapipe/tasks/python/text/core/__init__.py @@ -0,0 +1,16 @@ +"""Copyright 2022 The MediaPipe Authors. + +All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/mediapipe/tasks/python/text/core/base_text_task_api.py b/mediapipe/tasks/python/text/core/base_text_task_api.py new file mode 100644 index 000000000..b22bfff00 --- /dev/null +++ b/mediapipe/tasks/python/text/core/base_text_task_api.py @@ -0,0 +1,55 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe text task base api.""" + +from mediapipe.framework import calculator_pb2 +from mediapipe.python._framework_bindings import task_runner +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_TaskRunner = task_runner.TaskRunner + + +class BaseTextTaskApi(object): + """The base class of the user-facing mediapipe text task api classes.""" + + def __init__(self, + graph_config: calculator_pb2.CalculatorGraphConfig) -> None: + """Initializes the `BaseVisionTaskApi` object. + + Args: + graph_config: The mediapipe text task graph config proto. + """ + self._runner = _TaskRunner.create(graph_config) + + def close(self) -> None: + """Shuts down the mediapipe text task instance. + + Raises: + RuntimeError: If the mediapipe text task failed to close. + """ + self._runner.close() + + @doc_controls.do_not_generate_docs + def __enter__(self): + """Returns `self` upon entering the runtime context.""" + return self + + @doc_controls.do_not_generate_docs + def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback): + """Shuts down the mediapipe text task instance on exit of the context manager. + + Raises: + RuntimeError: If the mediapipe text task failed to close. + """ + self.close() diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py new file mode 100644 index 000000000..1e230ee20 --- /dev/null +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -0,0 +1,140 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe text classifier task.""" + +import dataclasses + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.containers import classifications +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.text.core import base_text_task_api + +TextClassifierResult = classifications.ClassificationResult +_BaseOptions = base_options_module.BaseOptions +_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_TaskInfo = task_info_module.TaskInfo + +_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' +_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_TEXT_IN_STREAM_NAME = 'text_in' +_TEXT_TAG = 'TEXT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph' + + +@dataclasses.dataclass +class TextClassifierOptions: + """Options for the text classifier task. + + Attributes: + base_options: Base options for the text classifier task. + classifier_options: Options for the text classification task. + """ + base_options: _BaseOptions + classifier_options: _ClassifierOptions = _ClassifierOptions() + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _TextClassifierGraphOptionsProto: + """Generates an TextClassifierOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + classifier_options_proto = self.classifier_options.to_pb2() + + return _TextClassifierGraphOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto) + + +class TextClassifier(base_text_task_api.BaseTextTaskApi): + """Class that performs classification on text.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'TextClassifier': + """Creates an `TextClassifier` object from a TensorFlow Lite model and the default `TextClassifierOptions`. + + Args: + model_path: Path to the model. + + Returns: + `TextClassifier` object that's created from the model file and the + default `TextClassifierOptions`. + + Raises: + ValueError: If failed to create `TextClassifier` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = TextClassifierOptions(base_options=base_options) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: TextClassifierOptions) -> 'TextClassifier': + """Creates the `TextClassifier` object from text classifier options. + + Args: + options: Options for the text classifier task. + + Returns: + `TextClassifier` object that's created from `options`. + + Raises: + ValueError: If failed to create `TextClassifier` object from + `TextClassifierOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])], + output_streams=[ + ':'.join([ + _CLASSIFICATION_RESULT_TAG, + _CLASSIFICATION_RESULT_OUT_STREAM_NAME + ]) + ], + task_options=options) + return cls(task_info.generate_graph_config()) + + def classify(self, text: str) -> TextClassifierResult: + """Performs classification on the input `text`. + + Args: + text: The input text. + + Returns: + A `TextClassifierResult` object that contains a list of text + classifications. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If text classification failed to run. + """ + output_packets = self._runner.process( + {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}) + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + + return TextClassifierResult([ + classifications.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD new file mode 100644 index 000000000..bc3048df1 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -0,0 +1,33 @@ +# This contains the MediaPipe Audio Classifier Task. +# +# This task takes audio data and outputs the classification result. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_classifier", + srcs = [ + "audio_classifier.ts", + "audio_classifier_options.ts", + "audio_classifier_result.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts new file mode 100644 index 000000000..fd79487a4 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -0,0 +1,215 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; +import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {AudioClassifierOptions} from './audio_classifier_options'; +import {Classifications} from './audio_classifier_result'; + +const MEDIAPIPE_GRAPH = + 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; + +// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and +// cannot be changed +// TODO: Change this to `audio_in` to match the name in the CC +// implementation +const AUDIO_STREAM = 'input_audio'; +const SAMPLE_RATE_STREAM = 'sample_rate'; +const CLASSIFICATION_RESULT_STREAM = 'classification_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs audio classification. */ +export class AudioClassifier extends TaskRunner { + private classifications: Classifications[] = []; + private defaultSampleRate = 48000; + private readonly options = new AudioClassifierGraphOptions(); + + /** + * Initializes the Wasm runtime and creates a new audio classifier from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param audioClassifierOptions The options for the audio classifier. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + audioClassifierOptions: AudioClassifierOptions): + Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const classifier = await createMediaPipeLib( + AudioClassifier, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await classifier.setOptions(audioClassifierOptions); + return classifier; + } + + /** + * Initializes the Wasm runtime and creates a new audio classifier based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return AudioClassifier.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new audio classifier based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return AudioClassifier.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the audio classifier. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the audio classifier. + */ + async setOptions(options: AudioClassifierOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setClassifierOptions(convertClassifierOptionsToProto( + options, this.options.getClassifierOptions())); + this.refreshGraph(); + } + + /** + * Sets the sample rate for all calls to `classify()` that omit an explicit + * sample rate. `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** + * Performs audio classification on the provided audio data and waits + * synchronously for the response. + * + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param sampleRate The sample rate in Hz of the provided audio data. If not + * set, defaults to the sample rate set via `setDefaultSampleRate()` or + * `48000` if no custom default was set. + * @return The classification result of the audio datas + */ + classify(audioData: Float32Array, sampleRate?: number): Classifications[] { + sampleRate = sampleRate ?? this.defaultSampleRate; + + // Configures the number of samples in the WASM layer. We re-configure the + // number of samples and the sample rate for every frame, but ignore other + // side effects of this function (such as sending the input side packet and + // the input stream header). + this.configureAudio( + /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); + + const timestamp = performance.now(); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); + this.addAudioToStream(audioData, timestamp); + + this.classifications = []; + this.finishProcessing(); + return [...this.classifications]; + } + + /** + * Internal function for converting raw data into a classification, and + * adding it to our classfications list. + **/ + private addJsAudioClassification(binaryProto: Uint8Array): void { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + this.classifications.push( + ...convertFromClassificationResultProto(classificationResult)); + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioClassifierGraphOptions.ext, this.options); + + // Perform audio classification. Pre-processing and results post-processing + // are built-in. + const classifierNode = new CalculatorGraphConfig.Node(); + classifierNode.setCalculator(MEDIAPIPE_GRAPH); + classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); + classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + classifierNode.addOutputStream( + 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.setOptions(calculatorOptions); + + graphConfig.addNode(classifierNode); + + this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { + this.addJsAudioClassification(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts new file mode 100644 index 000000000..93bd9927e --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts new file mode 100644 index 000000000..0a51dee04 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Category} from '../../../../tasks/web/components/containers/category'; +export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; diff --git a/mediapipe/tasks/web/components/processors/classifier_options.ts b/mediapipe/tasks/web/components/processors/classifier_options.ts index 8e01dd373..5b8ae796e 100644 --- a/mediapipe/tasks/web/components/processors/classifier_options.ts +++ b/mediapipe/tasks/web/components/processors/classifier_options.ts @@ -29,31 +29,31 @@ export function convertClassifierOptionsToProto( baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto { const classifierOptions = baseOptions ? baseOptions.clone() : new ClassifierOptionsProto(); - if (options.displayNamesLocale) { + if (options.displayNamesLocale !== undefined) { classifierOptions.setDisplayNamesLocale(options.displayNamesLocale); } else if (options.displayNamesLocale === undefined) { classifierOptions.clearDisplayNamesLocale(); } - if (options.maxResults) { + if (options.maxResults !== undefined) { classifierOptions.setMaxResults(options.maxResults); } else if ('maxResults' in options) { // Check for undefined classifierOptions.clearMaxResults(); } - if (options.scoreThreshold) { + if (options.scoreThreshold !== undefined) { classifierOptions.setScoreThreshold(options.scoreThreshold); } else if ('scoreThreshold' in options) { // Check for undefined classifierOptions.clearScoreThreshold(); } - if (options.categoryAllowlist) { + if (options.categoryAllowlist !== undefined) { classifierOptions.setCategoryAllowlistList(options.categoryAllowlist); } else if ('categoryAllowlist' in options) { // Check for undefined classifierOptions.clearCategoryAllowlistList(); } - if (options.categoryDenylist) { + if (options.categoryDenylist !== undefined) { classifierOptions.setCategoryDenylistList(options.categoryDenylist); } else if ('categoryDenylist' in options) { // Check for undefined classifierOptions.clearCategoryDenylistList(); diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index a5547ad6e..4fb57d6c3 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -12,6 +12,18 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "task_runner", + srcs = [ + "task_runner.ts", + ], + deps = [ + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) + mediapipe_ts_library( name = "classifier_options", srcs = [ diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts new file mode 100644 index 000000000..c948930fc --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -0,0 +1,83 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; +import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; +import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; + +// tslint:disable-next-line:enforce-name-casing +const WasmMediaPipeImageLib = + SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); + +/** Base class for all MediaPipe Tasks. */ +export abstract class TaskRunner extends WasmMediaPipeImageLib { + private processingErrors: Error[] = []; + + constructor(wasmModule: WasmModule) { + super(wasmModule); + + // Disables the automatic render-to-screen code, which allows for pure + // CPU processing. + this.setAutoRenderToScreen(false); + + // Enables use of our model resource caching graph service. + this.registerModelResourcesGraphService(); + } + + /** + * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run + * over the video stream. Will replace the previously running MediaPipe graph, + * if there is one. + * @param graphData The raw MediaPipe graph data, either in binary + * protobuffer format (.binarypb), or else in raw text format (.pbtxt or + * .textproto). + * @param isBinary This should be set to true if the graph is in + * binary format, and false if it is in human-readable text format. + */ + override setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.attachErrorListener((code, message) => { + this.processingErrors.push(new Error(message)); + }); + super.setGraph(graphData, isBinary); + this.handleErrors(); + } + + /** + * Forces all queued-up packets to be pushed through the MediaPipe graph as + * far as possible, performing all processing until no more processing can be + * done. + */ + override finishProcessing(): void { + super.finishProcessing(); + this.handleErrors(); + } + + /** Throws the error from the error listener if an error was raised. */ + private handleErrors() { + const errorCount = this.processingErrors.length; + if (errorCount === 1) { + // Re-throw error to get a more meaningful stacktrace + throw new Error(this.processingErrors[0].message); + } else if (errorCount > 1) { + throw new Error( + 'Encountered multiple errors: ' + + this.processingErrors.map(e => e.message).join(', ')); + } + this.processingErrors = []; + } +} + + diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD new file mode 100644 index 000000000..e984a9554 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -0,0 +1,34 @@ +# This contains the MediaPipe Text Classifier Task. +# +# This task takes text input performs Natural Language classification (including +# BERT-based text classification). + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "text_classifier", + srcs = [ + "text_classifier.ts", + "text_classifier_options.d.ts", + "text_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts new file mode 100644 index 000000000..ff36bb9e0 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -0,0 +1,180 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {TextClassifierOptions} from './text_classifier_options'; +import {Classifications} from './text_classifier_result'; + +const INPUT_STREAM = 'text_in'; +const CLASSIFICATION_RESULT_STREAM = 'classification_result_out'; +const TEXT_CLASSIFIER_GRAPH = + 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs Natural Language classification. */ +export class TextClassifier extends TaskRunner { + private classifications: Classifications[] = []; + private readonly options = new TextClassifierGraphOptions(); + + /** + * Initializes the Wasm runtime and creates a new text classifier from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param textClassifierOptions The options for the text classifier. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + textClassifierOptions: TextClassifierOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const classifier = await createMediaPipeLib( + TextClassifier, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await classifier.setOptions(textClassifierOptions); + return classifier; + } + + /** + * Initializes the Wasm runtime and creates a new text classifier based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return TextClassifier.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new text classifier based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return TextClassifier.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the text classifier. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the text classifier. + */ + async setOptions(options: TextClassifierOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setClassifierOptions(convertClassifierOptionsToProto( + options, this.options.getClassifierOptions())); + this.refreshGraph(); + } + + + /** + * Performs Natural Language classification on the provided text and waits + * synchronously for the response. + * + * @param text The text to process. + * @return The classification result of the text + */ + classify(text: string): Classifications[] { + // Get classification classes by running our MediaPipe graph. + this.classifications = []; + this.addStringToStream( + text, INPUT_STREAM, /* timestamp= */ performance.now()); + this.finishProcessing(); + return [...this.classifications]; + } + + // Internal function for converting raw data into a classification, and + // adding it to our classifications list. + private addJsTextClassification(binaryProto: Uint8Array): void { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + console.log(classificationResult.toObject()); + this.classifications.push( + ...convertFromClassificationResultProto(classificationResult)); + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + TextClassifierGraphOptions.ext, this.options); + + const classifierNode = new CalculatorGraphConfig.Node(); + classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH); + classifierNode.addInputStream('TEXT:' + INPUT_STREAM); + classifierNode.addOutputStream( + 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.setOptions(calculatorOptions); + + graphConfig.addNode(classifierNode); + + this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { + this.addJsTextClassification(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts new file mode 100644 index 000000000..51b2b3947 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts new file mode 100644 index 000000000..0a51dee04 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Category} from '../../../../tasks/web/components/containers/category'; +export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD new file mode 100644 index 000000000..8988c4794 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -0,0 +1,40 @@ +# This contains the MediaPipe Gesture Recognizer Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more gesture categories, using Gesture Recognizer. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "gesture_recognizer", + srcs = [ + "gesture_recognizer.ts", + "gesture_recognizer_options.d.ts", + "gesture_recognizer_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts new file mode 100644 index 000000000..ad8db1477 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -0,0 +1,374 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationList} from '../../../../framework/formats/classification_pb'; +import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; +import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; +import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb'; +import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; +import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; +import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {GestureRecognizerOptions} from './gesture_recognizer_options'; +import {GestureRecognitionResult} from './gesture_recognizer_result'; + +export {ImageSource}; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const HAND_GESTURES_STREAM = 'hand_gestures'; +const LANDMARKS_STREAM = 'hand_landmarks'; +const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks'; +const HANDEDNESS_STREAM = 'handedness'; +const GESTURE_RECOGNIZER_GRAPH = + 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; + +const DEFAULT_NUM_HANDS = 1; +const DEFAULT_SCORE_THRESHOLD = 0.5; +const DEFAULT_CATEGORY_INDEX = -1; + +const FULL_IMAGE_RECT = new NormalizedRect(); +FULL_IMAGE_RECT.setXCenter(0.5); +FULL_IMAGE_RECT.setYCenter(0.5); +FULL_IMAGE_RECT.setWidth(1); +FULL_IMAGE_RECT.setHeight(1); + +/** Performs hand gesture recognition on images. */ +export class GestureRecognizer extends TaskRunner { + private gestures: Category[][] = []; + private landmarks: Landmark[][] = []; + private worldLandmarks: Landmark[][] = []; + private handednesses: Category[][] = []; + + private readonly options: GestureRecognizerGraphOptions; + private readonly handLandmarkerGraphOptions: HandLandmarkerGraphOptions; + private readonly handLandmarksDetectorGraphOptions: + HandLandmarksDetectorGraphOptions; + private readonly handDetectorGraphOptions: HandDetectorGraphOptions; + private readonly handGestureRecognizerGraphOptions: + HandGestureRecognizerGraphOptions; + + /** + * Initializes the Wasm runtime and creates a new gesture recognizer from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param gestureRecognizerOptions The options for the gesture recognizer. + * Note that either a path to the model asset or a model buffer needs to + * be provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + gestureRecognizerOptions: GestureRecognizerOptions): + Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load via this mechanism is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const recognizer = await createMediaPipeLib( + GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await recognizer.setOptions(gestureRecognizerOptions); + return recognizer; + } + + /** + * Initializes the Wasm runtime and creates a new gesture recognizer based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return GestureRecognizer.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new gesture recognizer based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return GestureRecognizer.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + constructor(wasmModule: WasmModule) { + super(wasmModule); + + this.options = new GestureRecognizerGraphOptions(); + this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); + this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); + this.handLandmarksDetectorGraphOptions = + new HandLandmarksDetectorGraphOptions(); + this.handLandmarkerGraphOptions.setHandLandmarksDetectorGraphOptions( + this.handLandmarksDetectorGraphOptions); + this.handDetectorGraphOptions = new HandDetectorGraphOptions(); + this.handLandmarkerGraphOptions.setHandDetectorGraphOptions( + this.handDetectorGraphOptions); + this.handGestureRecognizerGraphOptions = + new HandGestureRecognizerGraphOptions(); + this.options.setHandGestureRecognizerGraphOptions( + this.handGestureRecognizerGraphOptions); + + this.initDefaults(); + + // Disables the automatic render-to-screen code, which allows for pure + // CPU processing. + this.setAutoRenderToScreen(false); + } + + /** + * Sets new options for the gesture recognizer. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the gesture recognizer. + */ + async setOptions(options: GestureRecognizerOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + if ('numHands' in options) { + this.handDetectorGraphOptions.setNumHands( + options.numHands ?? DEFAULT_NUM_HANDS); + } + if ('minHandDetectionConfidence' in options) { + this.handDetectorGraphOptions.setMinDetectionConfidence( + options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + if ('minHandPresenceConfidence' in options) { + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + if ('minTrackingConfidence' in options) { + this.handLandmarkerGraphOptions.setMinTrackingConfidence( + options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + if (options.cannedGesturesClassifierOptions) { + // Note that we have to support both JSPB and ProtobufJS and cannot + // use JSPB's getMutableX() APIs. + const graphOptions = new GestureClassifierGraphOptions(); + graphOptions.setClassifierOptions(convertClassifierOptionsToProto( + options.cannedGesturesClassifierOptions, + this.handGestureRecognizerGraphOptions + .getCannedGestureClassifierGraphOptions() + ?.getClassifierOptions())); + this.handGestureRecognizerGraphOptions + .setCannedGestureClassifierGraphOptions(graphOptions); + } else if (options.cannedGesturesClassifierOptions === undefined) { + this.handGestureRecognizerGraphOptions + .getCannedGestureClassifierGraphOptions() + ?.clearClassifierOptions(); + } + + if (options.customGesturesClassifierOptions) { + const graphOptions = new GestureClassifierGraphOptions(); + graphOptions.setClassifierOptions(convertClassifierOptionsToProto( + options.customGesturesClassifierOptions, + this.handGestureRecognizerGraphOptions + .getCustomGestureClassifierGraphOptions() + ?.getClassifierOptions())); + this.handGestureRecognizerGraphOptions + .setCustomGestureClassifierGraphOptions(graphOptions); + } else if (options.customGesturesClassifierOptions === undefined) { + this.handGestureRecognizerGraphOptions + .getCustomGestureClassifierGraphOptions() + ?.clearClassifierOptions(); + } + + this.refreshGraph(); + } + + /** + * Performs gesture recognition on the provided single image and waits + * synchronously for the response. + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. + * @return The detected gestures. + */ + recognize(imageSource: ImageSource, timestamp: number = performance.now()): + GestureRecognitionResult { + this.gestures = []; + this.landmarks = []; + this.worldLandmarks = []; + this.handednesses = []; + + this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); + this.addProtoToStream( + FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', + NORM_RECT_STREAM, timestamp); + this.finishProcessing(); + + return { + gestures: this.gestures, + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } + + /** Sets the default values for the graph. */ + private initDefaults(): void { + this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); + this.handDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.handLandmarkerGraphOptions.setMinTrackingConfidence( + DEFAULT_SCORE_THRESHOLD); + } + + /** Converts the proto data to a Category[][] structure. */ + private toJsCategories(data: Uint8Array[]): Category[][] { + const result: Category[][] = []; + for (const binaryProto of data) { + const inputList = ClassificationList.deserializeBinary(binaryProto); + const outputList: Category[] = []; + for (const classification of inputList.getClassificationList()) { + outputList.push({ + score: classification.getScore() ?? 0, + index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + categoryName: classification.getLabel() ?? '', + displayName: classification.getDisplayName() ?? '', + }); + } + result.push(outputList); + } + return result; + } + + /** Converts raw data into a landmark, and adds it to our landmarks list. */ + private addJsLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handLandmarksProto = + NormalizedLandmarkList.deserializeBinary(binaryProto); + const landmarks: Landmark[] = []; + for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { + landmarks.push({ + x: handLandmarkProto.getX() ?? 0, + y: handLandmarkProto.getY() ?? 0, + z: handLandmarkProto.getZ() ?? 0, + normalized: true + }); + } + this.landmarks.push(landmarks); + } + } + + /** + * Converts raw data into a landmark, and adds it to our worldLandmarks + * list. + */ + private adddJsWorldLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handWorldLandmarksProto = + LandmarkList.deserializeBinary(binaryProto); + const worldLandmarks: Landmark[] = []; + for (const handWorldLandmarkProto of + handWorldLandmarksProto.getLandmarkList()) { + worldLandmarks.push({ + x: handWorldLandmarkProto.getX() ?? 0, + y: handWorldLandmarkProto.getY() ?? 0, + z: handWorldLandmarkProto.getZ() ?? 0, + normalized: false + }); + } + this.worldLandmarks.push(worldLandmarks); + } + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(HAND_GESTURES_STREAM); + graphConfig.addOutputStream(LANDMARKS_STREAM); + graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM); + graphConfig.addOutputStream(HANDEDNESS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + GestureRecognizerGraphOptions.ext, this.options); + + const recognizerNode = new CalculatorGraphConfig.Node(); + recognizerNode.setCalculator(GESTURE_RECOGNIZER_GRAPH); + recognizerNode.addInputStream('IMAGE:' + IMAGE_STREAM); + recognizerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + recognizerNode.addOutputStream('HAND_GESTURES:' + HAND_GESTURES_STREAM); + recognizerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM); + recognizerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM); + recognizerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM); + recognizerNode.setOptions(calculatorOptions); + + graphConfig.addNode(recognizerNode); + + this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts new file mode 100644 index 000000000..16169a93f --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -0,0 +1,65 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +/** Options to configure the MediaPipe Gesture Recognizer Task */ +export interface GestureRecognizerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The maximum number of hands can be detected by the GestureRecognizer. + * Defaults to 1. + */ + numHands?: number|undefined; + + /** + * The minimum confidence score for the hand detection to be considered + * successful. Defaults to 0.5. + */ + minHandDetectionConfidence?: number|undefined; + + /** + * The minimum confidence score of hand presence score in the hand landmark + * detection. Defaults to 0.5. + */ + minHandPresenceConfidence?: number|undefined; + + /** + * The minimum confidence score for the hand tracking to be considered + * successful. Defaults to 0.5. + */ + minTrackingConfidence?: number|undefined; + + /** + * Sets the optional `ClassifierOptions` controling the canned gestures + * classifier, such as score threshold, allow list and deny list of gestures. + * The categories for canned gesture + * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", + * "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"] + */ + // TODO: Note this option is subject to change + cannedGesturesClassifierOptions?: ClassifierOptions|undefined; + + /** + * Options for configuring the custom gestures classifier, such as score + * threshold, allow list and deny list of gestures. + */ + // TODO b/251816640): Note this option is subject to change. + customGesturesClassifierOptions?: ClassifierOptions|undefined; +} diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts new file mode 100644 index 000000000..cccdfaf68 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -0,0 +1,35 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark} from '../../../../tasks/web/components/containers/landmark'; + +/** + * Represents the gesture recognition results generated by `GestureRecognizer`. + */ +export interface GestureRecognitionResult { + /** Hand landmarks of detected hands. */ + landmarks: Landmark[][]; + + /** Hand landmarks in world coordniates of detected hands. */ + worldLandmarks: Landmark[][]; + + /** Handedness of detected hands. */ + handednesses: Category[][]; + + /** Recognized hand gestures of detected hands */ + gestures: Category[][]; +} diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD new file mode 100644 index 000000000..6937dc4f3 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -0,0 +1,33 @@ +# This contains the MediaPipe Image Classifier Task. +# +# This task takes video or image frames and outputs the classification result. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_classifier", + srcs = [ + "image_classifier.ts", + "image_classifier_options.ts", + "image_classifier_result.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts new file mode 100644 index 000000000..39674e85c --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -0,0 +1,186 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {ImageClassifierOptions} from './image_classifier_options'; +import {Classifications} from './image_classifier_result'; + +const IMAGE_CLASSIFIER_GRAPH = + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; +const INPUT_STREAM = 'input_image'; +const CLASSIFICATION_RESULT_STREAM = 'classification_result'; + +export {ImageSource}; // Used in the public API + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs classification on images. */ +export class ImageClassifier extends TaskRunner { + private classifications: Classifications[] = []; + private readonly options = new ImageClassifierGraphOptions(); + + /** + * Initializes the Wasm runtime and creates a new image classifier from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param imageClassifierOptions The options for the image classifier. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + imageClassifierOptions: ImageClassifierOptions): + Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const classifier = await createMediaPipeLib( + ImageClassifier, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await classifier.setOptions(imageClassifierOptions); + return classifier; + } + + /** + * Initializes the Wasm runtime and creates a new image classifier based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return ImageClassifier.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new image classifier based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return ImageClassifier.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the image classifier. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the image classifier. + */ + async setOptions(options: ImageClassifierOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setClassifierOptions(convertClassifierOptionsToProto( + options, this.options.getClassifierOptions())); + this.refreshGraph(); + } + + /** + * Performs image classification on the provided image and waits synchronously + * for the response. + * + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. + * @return The classification result of the image + */ + classify(imageSource: ImageSource, timestamp?: number): Classifications[] { + // Get classification classes by running our MediaPipe graph. + this.classifications = []; + this.addGpuBufferAsImageToStream( + imageSource, INPUT_STREAM, timestamp ?? performance.now()); + this.finishProcessing(); + return [...this.classifications]; + } + + /** + * Internal function for converting raw data into a classification, and + * adding it to our classfications list. + **/ + private addJsImageClassification(binaryProto: Uint8Array): void { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + this.classifications.push( + ...convertFromClassificationResultProto(classificationResult)); + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + ImageClassifierGraphOptions.ext, this.options); + + // Perform image classification. Pre-processing and results post-processing + // are built-in. + const classifierNode = new CalculatorGraphConfig.Node(); + classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH); + classifierNode.addInputStream('IMAGE:' + INPUT_STREAM); + classifierNode.addOutputStream( + 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.setOptions(calculatorOptions); + + graphConfig.addNode(classifierNode); + + this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { + this.addJsImageClassification(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts new file mode 100644 index 000000000..a5f5c2386 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts new file mode 100644 index 000000000..0a51dee04 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Category} from '../../../../tasks/web/components/containers/category'; +export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD new file mode 100644 index 000000000..888537bd1 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -0,0 +1,30 @@ +# This contains the MediaPipe Object Detector Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more object categories, using Object Detector. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "object_detector", + srcs = [ + "object_detector.ts", + "object_detector_options.d.ts", + "object_detector_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts new file mode 100644 index 000000000..c3bb21baa --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -0,0 +1,233 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {ObjectDetectorOptions} from './object_detector_options'; +import {Detection} from './object_detector_result'; + +const INPUT_STREAM = 'input_frame_gpu'; +const DETECTIONS_STREAM = 'detections'; +const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + +const DEFAULT_CATEGORY_INDEX = -1; + +export {ImageSource}; // Used in the public API + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs object detection on images. */ +export class ObjectDetector extends TaskRunner { + private detections: Detection[] = []; + private readonly options = new ObjectDetectorOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new object detector from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param objectDetectorOptions The options for the Object Detector. Note that + * either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + objectDetectorOptions: ObjectDetectorOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const detector = await createMediaPipeLib( + ObjectDetector, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await detector.setOptions(objectDetectorOptions); + return detector; + } + + /** + * Initializes the Wasm runtime and creates a new object detector based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return ObjectDetector.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new object detector based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return ObjectDetector.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the object detector. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the object detector. + */ + async setOptions(options: ObjectDetectorOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + // Note that we have to support both JSPB and ProtobufJS, hence we + // have to expliclity clear the values instead of setting them to + // `undefined`. + if (options.displayNamesLocale !== undefined) { + this.options.setDisplayNamesLocale(options.displayNamesLocale); + } else if ('displayNamesLocale' in options) { // Check for undefined + this.options.clearDisplayNamesLocale(); + } + + if (options.maxResults !== undefined) { + this.options.setMaxResults(options.maxResults); + } else if ('maxResults' in options) { // Check for undefined + this.options.clearMaxResults(); + } + + if (options.scoreThreshold !== undefined) { + this.options.setScoreThreshold(options.scoreThreshold); + } else if ('scoreThreshold' in options) { // Check for undefined + this.options.clearScoreThreshold(); + } + + if (options.categoryAllowlist !== undefined) { + this.options.setCategoryAllowlistList(options.categoryAllowlist); + } else if ('categoryAllowlist' in options) { // Check for undefined + this.options.clearCategoryAllowlistList(); + } + + if (options.categoryDenylist !== undefined) { + this.options.setCategoryDenylistList(options.categoryDenylist); + } else if ('categoryDenylist' in options) { // Check for undefined + this.options.clearCategoryDenylistList(); + } + + this.refreshGraph(); + } + + /** + * Performs object detection on the provided single image and waits + * synchronously for the response. + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. + * @return The list of detected objects + */ + detect(imageSource: ImageSource, timestamp?: number): Detection[] { + // Get detections by running our MediaPipe graph. + this.detections = []; + this.addGpuBufferAsImageToStream( + imageSource, INPUT_STREAM, timestamp ?? performance.now()); + this.finishProcessing(); + return [...this.detections]; + } + + /** Converts raw data into a Detection, and adds it to our detection list. */ + private addJsObjectDetections(data: Uint8Array[]): void { + for (const binaryProto of data) { + const detectionProto = DetectionProto.deserializeBinary(binaryProto); + const scores = detectionProto.getScoreList(); + const indexes = detectionProto.getLabelIdList(); + const labels = detectionProto.getLabelList(); + const displayNames = detectionProto.getDisplayNameList(); + + const detection: Detection = {categories: []}; + for (let i = 0; i < scores.length; i++) { + detection.categories.push({ + score: scores[i], + index: indexes[i] ?? DEFAULT_CATEGORY_INDEX, + categoryName: labels[i] ?? '', + displayName: displayNames[i] ?? '', + }); + } + + const boundingBox = detectionProto.getLocationData()?.getBoundingBox(); + if (boundingBox) { + detection.boundingBox = { + originX: boundingBox.getXmin() ?? 0, + originY: boundingBox.getYmin() ?? 0, + width: boundingBox.getWidth() ?? 0, + height: boundingBox.getHeight() ?? 0 + }; + } + + this.detections.push(detection); + } + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(DETECTIONS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + ObjectDetectorOptionsProto.ext, this.options); + + const detectorNode = new CalculatorGraphConfig.Node(); + detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH); + detectorNode.addInputStream('IMAGE:' + INPUT_STREAM); + detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM); + detectorNode.setOptions(calculatorOptions); + + graphConfig.addNode(detectorNode); + + this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { + this.addJsObjectDetections(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts new file mode 100644 index 000000000..eec12cf17 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -0,0 +1,52 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; + +/** Options to configure the MediaPipe Object Detector Task */ +export interface ObjectDetectorOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ + displayNamesLocale?: string|undefined; + + /** The maximum number of top-scored detection results to return. */ + maxResults?: number|undefined; + + /** + * Overrides the value provided in the model metadata. Results below this + * value are rejected. + */ + scoreThreshold?: number|undefined; + + /** + * Allowlist of category names. If non-empty, detection results whose category + * name is not in this set will be filtered out. Duplicate or unknown category + * names are ignored. Mutually exclusive with `categoryDenylist`. + */ + categoryAllowlist?: string[]|undefined; + + /** + * Denylist of category names. If non-empty, detection results whose category + * name is in this set will be filtered out. Duplicate or unknown category + * names are ignored. Mutually exclusive with `categoryAllowlist`. + */ + categoryDenylist?: string[]|undefined; +} diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts new file mode 100644 index 000000000..7b2621134 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts @@ -0,0 +1,38 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Category} from '../../../../tasks/web/components/containers/category'; + +/** An integer bounding box, axis aligned. */ +export interface BoundingBox { + /** The X coordinate of the top-left corner, in pixels. */ + originX: number; + /** The Y coordinate of the top-left corner, in pixels. */ + originY: number; + /** The width of the bounding box, in pixels. */ + width: number; + /** The height of the bounding box, in pixels. */ + height: number; +} + +/** Represents one object detected by the `ObjectDetector`. */ +export interface Detection { + /** A list of `Category` objects. */ + categories: Category[]; + + /** The bounding box of the detected objects. */ + boundingBox?: BoundingBox; +} diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD new file mode 100644 index 000000000..dab6be50f --- /dev/null +++ b/mediapipe/web/graph_runner/BUILD @@ -0,0 +1,41 @@ +# The TypeScript graph runner used by all MediaPipe Web tasks. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = [ + ":internal", + "//mediapipe/tasks:internal", +]) + +package_group( + name = "internal", + packages = [ + "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...", + ], +) + +mediapipe_ts_library( + name = "wasm_mediapipe_lib_ts", + srcs = [ + ":wasm_mediapipe_lib.ts", + ], + allow_unoptimized_namespaces = True, +) + +mediapipe_ts_library( + name = "wasm_mediapipe_image_lib_ts", + srcs = [ + ":wasm_mediapipe_image_lib.ts", + ], + allow_unoptimized_namespaces = True, + deps = [":wasm_mediapipe_lib_ts"], +) + +mediapipe_ts_library( + name = "register_model_resources_graph_service_ts", + srcs = [ + ":register_model_resources_graph_service.ts", + ], + allow_unoptimized_namespaces = True, + deps = [":wasm_mediapipe_lib_ts"], +) diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts new file mode 100644 index 000000000..e85d63b06 --- /dev/null +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -0,0 +1,41 @@ +import {WasmMediaPipeLib} from './wasm_mediapipe_lib'; + +/** + * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * access to the wasmModule, among other things. The `any` type is required for + * mixin constructors. + */ +// tslint:disable-next-line:no-any +type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; + +/** + * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler + * doesn't break our JS/C++ bridge. + */ +export declare interface WasmModuleRegisterModelResources { + _registerModelResourcesGraphService: () => void; +} + +/** + * An implementation of WasmMediaPipeLib that supports registering model + * resources to a cache, in the form of a GraphService C++-side. We implement as + * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: + * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( + * WasmMediaPipeLib);` + */ +// tslint:disable:enforce-name-casing +export function SupportModelResourcesGraphService( + Base: TBase) { + return class extends Base { + // tslint:enable:enforce-name-casing + /** + * Instructs the graph runner to use the model resource caching graph + * service for both graph expansion/inintialization, as well as for graph + * run. + */ + registerModelResourcesGraphService(): void { + (this.wasmModule as unknown as WasmModuleRegisterModelResources) + ._registerModelResourcesGraphService(); + } + }; +} diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts new file mode 100644 index 000000000..3b45e8230 --- /dev/null +++ b/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts @@ -0,0 +1,52 @@ +import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib'; + +/** + * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * access to the wasmModule, among other things. The `any` type is required for + * mixin constructors. + */ +// tslint:disable-next-line:no-any +type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; + +/** + * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler + * doesn't break our JS/C++ bridge. + */ +export declare interface WasmImageModule { + _addBoundTextureAsImageToStream: + (streamNamePtr: number, width: number, height: number, + timestamp: number) => void; +} + +/** + * An implementation of WasmMediaPipeLib that supports binding GPU image data as + * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for + * effective multiple inheritance. Example usage: + * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);` + */ +// tslint:disable-next-line:enforce-name-casing +export function SupportImage(Base: TBase) { + return class extends Base { + /** + * Takes the relevant information from the HTML video or image element, and + * passes it into the WebGL-based graph for processing on the given stream + * at the given timestamp as a MediaPipe image. Processing will not occur + * until a blocking call (like processVideoGl or finishProcessing) is made. + * @param imageSource Reference to the video frame we wish to add into our + * graph. + * @param streamName The name of the MediaPipe graph stream to add the frame + * to. + * @param timestamp The timestamp of the input frame, in ms. + */ + addGpuBufferAsImageToStream( + imageSource: ImageSource, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + const [width, height] = + this.bindTextureToStream(imageSource, streamNamePtr); + (this.wasmModule as unknown as WasmImageModule) + ._addBoundTextureAsImageToStream( + streamNamePtr, width, height, timestamp); + }); + } + }; +} diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts new file mode 100644 index 000000000..714f42134 --- /dev/null +++ b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts @@ -0,0 +1,1044 @@ +// Placeholder for internal dependency on assertTruthy +// Placeholder for internal dependency on jsloader +// Placeholder for internal dependency on trusted resource url + +// This file can serve as a common interface for most simple TypeScript +// libraries-- additionally, it can hook automatically into wasm_mediapipe_demo +// to autogenerate simple TS APIs from demos for instantaneous 1P integrations. + +/** + * Simple interface for allowing users to set the directory where internal + * wasm-loading and asset-loading code looks (e.g. for .wasm and .data file + * locations). + */ +export declare interface FileLocator { + locateFile: (filename: string) => string; +} + +/** Listener to be passed in by user for handling output audio data. */ +export type AudioOutputListener = (output: Float32Array) => void; + +/** + * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler + * doesn't break our JS/C++ bridge. + */ +export declare interface WasmModule { + canvas: HTMLCanvasElement|OffscreenCanvas|null; + HEAPU8: Uint8Array; + HEAPU32: Uint32Array; + HEAPF32: Float32Array; + HEAPF64: Float64Array; + errorListener?: ErrorListener; + _bindTextureToCanvas: () => boolean; + _changeBinaryGraph: (size: number, dataPtr: number) => void; + _changeTextGraph: (size: number, dataPtr: number) => void; + _configureAudio: + (channels: number, samples: number, sampleRate: number) => void; + _free: (ptr: number) => void; + _malloc: (size: number) => number; + _processAudio: (dataPtr: number, timestamp: number) => void; + _processFrame: (width: number, height: number, timestamp: number) => void; + _setAutoRenderToScreen: (enabled: boolean) => void; + _waitUntilIdle: () => void; + + // Exposed so that clients of this lib can access this field + dataFileDownloads?: {[url: string]: {loaded: number, total: number}}; + // Wasm module will call us back at this function when given audio data. + onAudioOutput?: AudioOutputListener; + + // Wasm Module multistream entrypoints. Require + // gl_graph_runner_internal_multi_input as a build dependency. + stringToNewUTF8: (data: string) => number; + _bindTextureToStream: (streamNamePtr: number) => void; + _addBoundTextureToStream: + (streamNamePtr: number, width: number, height: number, + timestamp: number) => void; + _addBoolToInputStream: + (data: boolean, streamNamePtr: number, timestamp: number) => void; + _addDoubleToInputStream: + (data: number, streamNamePtr: number, timestamp: number) => void; + _addFloatToInputStream: + (data: number, streamNamePtr: number, timestamp: number) => void; + _addIntToInputStream: + (data: number, streamNamePtr: number, timestamp: number) => void; + _addStringToInputStream: + (dataPtr: number, streamNamePtr: number, timestamp: number) => void; + _addFlatHashMapToInputStream: + (keysPtr: number, valuesPtr: number, count: number, streamNamePtr: number, + timestamp: number) => void; + _addProtoToInputStream: + (dataPtr: number, dataSize: number, protoNamePtr: number, + streamNamePtr: number, timestamp: number) => void; + // Input side packets + _addBoolToInputSidePacket: (data: boolean, streamNamePtr: number) => void; + _addDoubleToInputSidePacket: (data: number, streamNamePtr: number) => void; + _addFloatToInputSidePacket: (data: number, streamNamePtr: number) => void; + _addIntToInputSidePacket: (data: number, streamNamePtr: number) => void; + _addStringToInputSidePacket: (dataPtr: number, streamNamePtr: number) => void; + _addProtoToInputSidePacket: + (dataPtr: number, dataSize: number, protoNamePtr: number, + streamNamePtr: number) => void; + + // Wasm Module output listener entrypoints. Also built as part of + // gl_graph_runner_internal_multi_input. + simpleListeners?: {[outputStreamName: string]: (data: unknown) => void}; + vectorListeners?: { + [outputStreamName: string]: ( + data: unknown, index: number, length: number) => void + }; + _attachBoolListener: (streamNamePtr: number) => void; + _attachBoolVectorListener: (streamNamePtr: number) => void; + _attachDoubleListener: (streamNamePtr: number) => void; + _attachDoubleVectorListener: (streamNamePtr: number) => void; + _attachFloatListener: (streamNamePtr: number) => void; + _attachFloatVectorListener: (streamNamePtr: number) => void; + _attachIntListener: (streamNamePtr: number) => void; + _attachIntVectorListener: (streamNamePtr: number) => void; + _attachStringListener: (streamNamePtr: number) => void; + _attachStringVectorListener: (streamNamePtr: number) => void; + _attachProtoListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void; + _attachProtoVectorListener: + (streamNamePtr: number, makeDeepCopy?: boolean) => void; + + // Requires dependency ":gl_graph_runner_audio_out", and will register an + // audio output listening function which can be tapped into dynamically during + // graph running via onAudioOutput. This call must be made before graph is + // initialized, but after wasmModule is instantiated. + _attachAudioOutputListener: () => void; + + // TODO: Refactor to just use a few numbers (perhaps refactor away + // from gl_graph_runner_internal.cc entirely to use something a little more + // streamlined; new version is _processFrame above). + _processGl: (frameDataPtr: number) => number; +} + +// Global declarations, for tapping into Window for Wasm blob running +declare global { + interface Window { + // Created by us using wasm-runner script + Module?: WasmModule|FileLocator; + // Created by wasm-runner script + ModuleFactory?: (fileLocator: FileLocator) => Promise; + } +} + +/** + * Fetches each URL in urls, executes them one-by-one in the order they are + * passed, and then returns (or throws if something went amiss). + */ +declare function importScripts(...urls: Array): void; + +/** + * Valid types of image sources which we can run our WasmMediaPipeLib over. + */ +export type ImageSource = + HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap; + + +/** A listener that will be invoked with an absl::StatusCode and message. */ +export type ErrorListener = (code: number, message: string) => void; + +// Internal type of constructors used for initializing WasmMediaPipeLib and +// subclasses. +type WasmMediaPipeConstructor = + (new ( + module: WasmModule, canvas?: HTMLCanvasElement|OffscreenCanvas|null) => + LibType); + +/** + * Simple class to run an arbitrary image-in/image-out MediaPipe graph (i.e. + * as created by wasm_mediapipe_demo BUILD macro), and either render results + * into canvas, or else return the output WebGLTexture. Takes a WebAssembly + * Module (must be instantiated to self.Module). + */ +export class WasmMediaPipeLib { + // TODO: These should be protected/private, but are left exposed for + // now so that we can use proper TS mixins with this class as a base. This + // should be somewhat fixed when we create our .d.ts files. + readonly wasmModule: WasmModule; + readonly hasMultiStreamSupport: boolean; + autoResizeCanvas: boolean = true; + audioPtr: number|null; + audioSize: number; + + /** + * Creates a new MediaPipe WASM module. Must be called *after* wasm Module has + * initialized. Note that we take control of the GL canvas from here on out, + * and will resize it to fit input. + * + * @param module The underlying Wasm Module to use. + * @param glCanvas The type of the GL canvas to use, or `null` if no GL + * canvas should be initialzed. Initializes an offscreen canvas if not + * provided. + */ + constructor( + module: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + this.wasmModule = module; + this.audioPtr = null; + this.audioSize = 0; + this.hasMultiStreamSupport = + (typeof this.wasmModule._addIntToInputStream === 'function'); + + if (glCanvas !== undefined) { + this.wasmModule.canvas = glCanvas; + } else { + // If no canvas is provided, assume Chrome/Firefox and just make an + // OffscreenCanvas for GPU processing. + this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } + } + + /** + * Convenience helper to load a MediaPipe graph from a file and pass it to + * setGraph. + * @param graphFile The url of the MediaPipe graph file to load. + */ + async initializeGraph(graphFile: string): Promise { + // Fetch and set graph + const response = await fetch(graphFile); + const graphData = await response.arrayBuffer(); + const isBinary = + !(graphFile.endsWith('.pbtxt') || graphFile.endsWith('.textproto')); + this.setGraph(new Uint8Array(graphData), isBinary); + } + + /** + * Convenience helper for calling setGraph with a string representing a text + * proto config. + * @param graphConfig The text proto graph config, expected to be a string in + * default JavaScript UTF-16 format. + */ + setGraphFromString(graphConfig: string): void { + this.setGraph((new TextEncoder()).encode(graphConfig), false); + } + + /** + * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run + * over the video stream. Will replace the previously running MediaPipe graph, + * if there is one. + * @param graphData The raw MediaPipe graph data, either in binary + * protobuffer format (.binarypb), or else in raw text format (.pbtxt or + * .textproto). + * @param isBinary This should be set to true if the graph is in + * binary format, and false if it is in human-readable text format. + */ + setGraph(graphData: Uint8Array, isBinary: boolean): void { + const size = graphData.length; + const dataPtr = this.wasmModule._malloc(size); + this.wasmModule.HEAPU8.set(graphData, dataPtr); + if (isBinary) { + this.wasmModule._changeBinaryGraph(size, dataPtr); + } else { + this.wasmModule._changeTextGraph(size, dataPtr); + } + this.wasmModule._free(dataPtr); + } + + /** + * Configures the current graph to handle audio in a certain way. Must be + * called before the graph is set/started in order to use processAudio. + * @param numChannels The number of channels of audio input. Only 1 + * is supported for now. + * @param numSamples The number of samples that are taken in each + * audio capture. + * @param sampleRate The rate, in Hz, of the sampling. + */ + configureAudio(numChannels: number, numSamples: number, sampleRate: number) { + this.wasmModule._configureAudio(numChannels, numSamples, sampleRate); + if (this.wasmModule._attachAudioOutputListener) { + this.wasmModule._attachAudioOutputListener(); + } + } + + /** + * Allows disabling automatic canvas resizing, in case clients want to control + * control this. + * @param resize True will re-enable automatic canvas resizing, while false + * will disable the feature. + */ + setAutoResizeCanvas(resize: boolean): void { + this.autoResizeCanvas = resize; + } + + /** + * Allows disabling the automatic render-to-screen code, in case clients don't + * need/want this. In particular, this removes the requirement for pipelines + * to have access to GPU resources, as well as the requirement for graphs to + * have "input_frames_gpu" and "output_frames_gpu" streams defined, so pure + * CPU pipelines and non-video pipelines can be created. + * NOTE: This only affects future graph initializations (via setGraph or + * initializeGraph), and does NOT affect the currently running graph, so + * calls to this should be made *before* setGraph/initializeGraph for the + * graph file being targeted. + * @param enabled True will re-enable automatic render-to-screen code and + * cause GPU resources to once again be requested, while false will + * disable the feature. + */ + setAutoRenderToScreen(enabled: boolean): void { + this.wasmModule._setAutoRenderToScreen(enabled); + } + + /** + * Bind texture to our internal canvas, and upload image source to GPU. + * Returns tuple [width, height] of texture. Intended for internal usage. + */ + bindTextureToStream(imageSource: ImageSource, streamNamePtr?: number): + [number, number] { + if (!this.wasmModule.canvas) { + throw new Error('No OpenGL canvas configured.'); + } + + if (!streamNamePtr) { + // TODO: Remove this path once completely refactored away. + console.assert(this.wasmModule._bindTextureToCanvas()); + } else { + this.wasmModule._bindTextureToStream(streamNamePtr); + } + const gl: any = + this.wasmModule.canvas.getContext('webgl2') || + this.wasmModule.canvas.getContext('webgl'); + console.assert(gl); + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, imageSource); + + let width, height; + if ((imageSource as HTMLVideoElement).videoWidth) { + width = (imageSource as HTMLVideoElement).videoWidth; + height = (imageSource as HTMLVideoElement).videoHeight; + } else { + width = imageSource.width; + height = imageSource.height; + } + + if (this.autoResizeCanvas && + (width !== this.wasmModule.canvas.width || + height !== this.wasmModule.canvas.height)) { + this.wasmModule.canvas.width = width; + this.wasmModule.canvas.height = height; + } + + return [width, height]; + } + + /** + * Takes the raw data from a JS image source, and sends it to C++ to be + * processed, waiting synchronously for the response. Note that we will resize + * our GL canvas to fit the input, so input size should only change + * infrequently. + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return texture? The WebGL texture reference, if one was produced. + */ + processGl(imageSource: ImageSource, timestamp: number): WebGLTexture + |undefined { + // Bind to default input stream + const [width, height] = this.bindTextureToStream(imageSource); + + // 2 ints and a ll (timestamp) + const frameDataPtr = this.wasmModule._malloc(16); + this.wasmModule.HEAPU32[frameDataPtr / 4] = width; + this.wasmModule.HEAPU32[(frameDataPtr / 4) + 1] = height; + this.wasmModule.HEAPF64[(frameDataPtr / 8) + 1] = timestamp; + // outputPtr points in HEAPF32-space to running mspf calculations, which we + // don't use at the moment. + // tslint:disable-next-line:no-unused-variable + const outputPtr = this.wasmModule._processGl(frameDataPtr) / 4; + this.wasmModule._free(frameDataPtr); + + // TODO: Hook up WebGLTexture output, when given. + // TODO: Allow user to toggle whether or not to render output into canvas. + return undefined; + } + + /** + * Converts JavaScript string input parameters into C++ c-string pointers. + * See b/204830158 for more details. Intended for internal usage. + */ + wrapStringPtr(stringData: string, stringPtrFunc: (ptr: number) => void): + void { + if (!this.hasMultiStreamSupport) { + console.error( + 'No wasm multistream support detected: ensure dependency ' + + 'inclusion of :gl_graph_runner_internal_multi_input target'); + } + const stringDataPtr = this.wasmModule.stringToNewUTF8(stringData); + stringPtrFunc(stringDataPtr); + this.wasmModule._free(stringDataPtr); + } + + /** + * Converts JavaScript string input parameters into C++ c-string pointers. + * See b/204830158 for more details. + */ + wrapStringPtrPtr(stringData: string[], ptrFunc: (ptr: number) => void): void { + if (!this.hasMultiStreamSupport) { + console.error( + 'No wasm multistream support detected: ensure dependency ' + + 'inclusion of :gl_graph_runner_internal_multi_input target'); + } + const uint32Array = new Uint32Array(stringData.length); + for (let i = 0; i < stringData.length; i++) { + uint32Array[i] = this.wasmModule.stringToNewUTF8(stringData[i]); + } + const heapSpace = this.wasmModule._malloc(uint32Array.length * 4); + this.wasmModule.HEAPU32.set(uint32Array, heapSpace >> 2); + + ptrFunc(heapSpace); + for (const uint32ptr of uint32Array) { + this.wasmModule._free(uint32ptr); + } + this.wasmModule._free(heapSpace); + } + + /** + * Ensures existence of the simple listeners table and registers the callback. + * Intended for internal usage. + */ + setListener(outputStreamName: string, callbackFcn: (data: T) => void) { + this.wasmModule.simpleListeners = this.wasmModule.simpleListeners || {}; + this.wasmModule.simpleListeners[outputStreamName] = + callbackFcn as (data: unknown) => void; + } + + /** + * Ensures existence of the vector listeners table and registers the callback. + * Intended for internal usage. + */ + setVectorListener( + outputStreamName: string, callbackFcn: (data: T[]) => void) { + const buffer: T[] = []; + this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {}; + this.wasmModule.vectorListeners[outputStreamName] = + (data: unknown, index: number, length: number) => { + // The Wasm listener gets invoked once for each element. Once we + // receive all elements, we invoke the registered callback with the + // full array. + buffer[index] = data as T; + if (index === length - 1) { + // Invoke the user callback directly, as the Wasm layer may clean up + // the underlying data elements once we leave the scope of the + // listener. + callbackFcn(buffer); + } + }; + } + + /** + * Attaches a listener that will be invoked when the MediaPipe framework + * returns an error. + */ + attachErrorListener(callbackFcn: (code: number, message: string) => void) { + this.wasmModule.errorListener = callbackFcn; + } + + /** + * Takes the raw data from a JS audio capture array, and sends it to C++ to be + * processed. + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param timestamp The timestamp of the current frame, in ms. + */ + addAudioToStream(audioData: Float32Array, timestamp: number) { + // 4 bytes for each F32 + const size = audioData.length * 4; + if (this.audioSize !== size) { + if (this.audioPtr) { + this.wasmModule._free(this.audioPtr); + } + this.audioPtr = this.wasmModule._malloc(size); + this.audioSize = size; + } + this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); + this.wasmModule._processAudio(this.audioPtr!, timestamp); + } + + /** + * Takes the relevant information from the HTML video or image element, and + * passes it into the WebGL-based graph for processing on the given stream at + * the given timestamp. Can be used for additional auxiliary GpuBuffer input + * streams. Processing will not occur until a blocking call (like + * processVideoGl or finishProcessing) is made. For use with + * 'gl_graph_runner_internal_multi_input'. + * @param imageSource Reference to the video frame we wish to add into our + * graph. + * @param streamName The name of the MediaPipe graph stream to add the frame + * to. + * @param timestamp The timestamp of the input frame, in ms. + */ + addGpuBufferToStream( + imageSource: ImageSource, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + const [width, height] = + this.bindTextureToStream(imageSource, streamNamePtr); + this.wasmModule._addBoundTextureToStream( + streamNamePtr, width, height, timestamp); + }); + } + + /** + * Sends a boolean packet into the specified stream at the given timestamp. + * @param data The boolean data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addBoolToStream(data: boolean, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addBoolToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends a double packet into the specified stream at the given timestamp. + * @param data The double data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addDoubleToStream(data: number, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addDoubleToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends a float packet into the specified stream at the given timestamp. + * @param data The float data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addFloatToStream(data: number, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + // NOTE: _addFloatToStream and _addIntToStream are reserved for JS + // Calculators currently; we may want to revisit this naming scheme in the + // future. + this.wasmModule._addFloatToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends an integer packet into the specified stream at the given timestamp. + * @param data The integer data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addIntToStream(data: number, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addIntToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends a string packet into the specified stream at the given timestamp. + * @param data The string data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addStringToStream(data: string, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wrapStringPtr(data, (dataPtr: number) => { + this.wasmModule._addStringToInputStream( + dataPtr, streamNamePtr, timestamp); + }); + }); + } + + /** + * Sends a Record packet into the specified stream at the + * given timestamp. + * @param data The records to send (will become a + * std::flat_hash_map, streamName: string, + timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wrapStringPtrPtr(Object.keys(data), (keyList: number) => { + this.wrapStringPtrPtr(Object.values(data), (valueList: number) => { + this.wasmModule._addFlatHashMapToInputStream( + keyList, valueList, Object.keys(data).length, streamNamePtr, + timestamp); + }); + }); + }); + } + + /** + * Sends a serialized protobuffer packet into the specified stream at the + * given timestamp, to be parsed into the specified protobuffer type. + * @param data The binary (serialized) raw protobuffer data. + * @param protoType The C++ namespaced type this protobuffer data corresponds + * to. It will be converted to this type when output as a packet into the + * graph. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addProtoToStream( + data: Uint8Array, protoType: string, streamName: string, + timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wrapStringPtr(protoType, (protoTypePtr: number) => { + // Deep-copy proto data into Wasm heap + const dataPtr = this.wasmModule._malloc(data.length); + // TODO: Ensure this is the fastest way to copy this data. + this.wasmModule.HEAPU8.set(data, dataPtr); + this.wasmModule._addProtoToInputStream( + dataPtr, data.length, protoTypePtr, streamNamePtr, timestamp); + this.wasmModule._free(dataPtr); + }); + }); + } + + /** + * Attaches a boolean packet to the specified input_side_packet. + * @param data The boolean data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addBoolToInputSidePacket(data: boolean, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addBoolToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a double packet to the specified input_side_packet. + * @param data The double data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addDoubleToInputSidePacket(data: number, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addDoubleToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a float packet to the specified input_side_packet. + * @param data The float data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addFloatToInputSidePacket(data: number, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addFloatToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a integer packet to the specified input_side_packet. + * @param data The integer data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addIntToInputSidePacket(data: number, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addIntToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a string packet to the specified input_side_packet. + * @param data The string data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addStringToInputSidePacket(data: string, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wrapStringPtr(data, (dataPtr: number) => { + this.wasmModule._addStringToInputSidePacket(dataPtr, sidePacketNamePtr); + }); + }); + } + + /** + * Attaches a serialized proto packet to the specified input_side_packet. + * @param data The binary (serialized) raw protobuffer data. + * @param protoType The C++ namespaced type this protobuffer data corresponds + * to. It will be converted to this type for use in the graph. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addProtoToInputSidePacket( + data: Uint8Array, protoType: string, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wrapStringPtr(protoType, (protoTypePtr: number) => { + // Deep-copy proto data into Wasm heap + const dataPtr = this.wasmModule._malloc(data.length); + // TODO: Ensure this is the fastest way to copy this data. + this.wasmModule.HEAPU8.set(data, dataPtr); + this.wasmModule._addProtoToInputSidePacket( + dataPtr, data.length, protoTypePtr, sidePacketNamePtr); + this.wasmModule._free(dataPtr); + }); + }); + } + + /** + * Attaches a boolean packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab boolean + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachBoolListener( + outputStreamName: string, callbackFcn: (data: boolean) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for bool packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachBoolListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a bool[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachBoolVectorListener( + outputStreamName: string, callbackFcn: (data: boolean[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachBoolVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches an int packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab int + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachIntListener( + outputStreamName: string, callbackFcn: (data: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for int packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachIntListener(outputStreamNamePtr); + }); + } + + /** + * Attaches an int[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachIntVectorListener( + outputStreamName: string, callbackFcn: (data: number[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachIntVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a double packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab double + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachDoubleListener( + outputStreamName: string, callbackFcn: (data: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for double packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachDoubleListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a double[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachDoubleVectorListener( + outputStreamName: string, callbackFcn: (data: number[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachDoubleVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a float packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab float + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachFloatListener( + outputStreamName: string, callbackFcn: (data: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for float packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachFloatListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a float[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachFloatVectorListener( + outputStreamName: string, callbackFcn: (data: number[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachFloatVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a string packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab string + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachStringListener( + outputStreamName: string, callbackFcn: (data: string) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for string packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachStringListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a string[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachStringVectorListener( + outputStreamName: string, callbackFcn: (data: string[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachStringVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a serialized proto packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab binary + * serialized proto data from (in Uint8Array format). + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that by default the data is only guaranteed to + * exist for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. If the proto data needs to be able to outlive the call, you + * may set the optional makeDeepCopy parameter to true, or can manually + * deep-copy the data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). + */ + attachProtoListener( + outputStreamName: string, callbackFcn: (data: Uint8Array) => void, + makeDeepCopy?: boolean): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for binary serialized proto data packets on this + // stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachProtoListener( + outputStreamNamePtr, makeDeepCopy || false); + }); + } + + /** + * Attaches a listener for an array of serialized proto packets to the + * specified output_stream. + * @param outputStreamName The name of the graph output stream to grab a + * vector of binary serialized proto data from (in Uint8Array[] format). + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that by default the data is only guaranteed to + * exist for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. If the proto data needs to be able to outlive the call, you + * may set the optional makeDeepCopy parameter to true, or can manually + * deep-copy the data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). + */ + attachProtoVectorListener( + outputStreamName: string, callbackFcn: (data: Uint8Array[]) => void, + makeDeepCopy?: boolean): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for a vector of binary serialized proto packets + // on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachProtoVectorListener( + outputStreamNamePtr, makeDeepCopy || false); + }); + } + + /** + * Sets a listener to be called back with audio output packet data, as a + * Float32Array, when graph has finished processing it. + * @param audioOutputListener The caller's listener function. + */ + setOnAudioOutput(audioOutputListener: AudioOutputListener) { + this.wasmModule.onAudioOutput = audioOutputListener; + if (!this.wasmModule._attachAudioOutputListener) { + console.warn( + 'Attempting to use AudioOutputListener without support for ' + + 'output audio. Is build dep ":gl_graph_runner_audio_out" missing?'); + } + } + + /** + * Forces all queued-up packets to be pushed through the MediaPipe graph as + * far as possible, performing all processing until no more processing can be + * done. + */ + finishProcessing(): void { + this.wasmModule._waitUntilIdle(); + } +} + +// Quick private helper to run the given script safely +async function runScript(scriptUrl: string) { + if (typeof importScripts === 'function') { + importScripts(scriptUrl.toString()); + } else { + await new Promise((resolve, reject) => { + fetch(scriptUrl).then(response => response.text()).then(text => Function(text)).then(resolve, reject); + }); + } +} + +/** + * Global function to initialize Wasm blob and load runtime assets for a + * specialized MediaPipe library. This allows us to create a requested + * subclass inheriting from WasmMediaPipeLib. + * @param constructorFcn The name of the class to instantiate via "new". + * @param wasmLoaderScript Url for the wasm-runner script; produced by the build + * process. + * @param assetLoaderScript Url for the asset-loading script; produced by the + * build process. + * @param fileLocator A function to override the file locations for assets + * loaded by the MediaPipe library. + * @return promise A promise which will resolve when initialization has + * completed successfully. + */ +export async function createMediaPipeLib( + constructorFcn: WasmMediaPipeConstructor, + wasmLoaderScript?: string, + assetLoaderScript?: string, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + fileLocator?: FileLocator): Promise { + const scripts = []; + // Run wasm-loader script here + if (wasmLoaderScript) { + scripts.push(wasmLoaderScript); + } + // Run asset-loader script here + if (assetLoaderScript) { + scripts.push(assetLoaderScript); + } + // Load scripts in parallel, browser will execute them in sequence. + if (scripts.length) { + await Promise.all(scripts.map(runScript)); + } + if (!self.ModuleFactory) { + throw new Error('ModuleFactory not set.'); + } + // TODO: Ensure that fileLocator is passed in by all users + // and make it required + const module = + await self.ModuleFactory(fileLocator || self.Module as FileLocator); + // Don't reuse factory or module seed + self.ModuleFactory = self.Module = undefined; + return new constructorFcn(module, glCanvas); +} + +/** + * Global function to initialize Wasm blob and load runtime assets for a generic + * MediaPipe library. + * @param wasmLoaderScript Url for the wasm-runner script; produced by the build + * process. + * @param assetLoaderScript Url for the asset-loading script; produced by the + * build process. + * @param fileLocator A function to override the file locations for assets + * loaded by the MediaPipe library. + * @return promise A promise which will resolve when initialization has + * completed successfully. + */ +export async function createWasmMediaPipeLib( + wasmLoaderScript?: string, + assetLoaderScript?: string, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + fileLocator?: FileLocator): Promise { + return createMediaPipeLib( + WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas, + fileLocator); +}