Merge branch 'master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-08 02:53:20 +05:30 committed by GitHub
commit ba1ee5b404
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 4040 additions and 189 deletions

View File

@ -936,6 +936,7 @@ cc_test(
"//mediapipe/framework/tool:simulation_clock", "//mediapipe/framework/tool:simulation_clock",
"//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:simulation_clock_executor",
"//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:sink",
"//mediapipe/util:packet_test_util",
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
], ],
) )

View File

@ -18,7 +18,6 @@
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h" #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/util/header_util.h" #include "mediapipe/util/header_util.h"
@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS";
// FlowLimiterCalculator provides limited support for multiple input streams. // FlowLimiterCalculator provides limited support for multiple input streams.
// The first input stream is treated as the main input stream and successive // The first input stream is treated as the main input stream and successive
// input streams are treated as auxiliary input streams. The auxiliary input // input streams are treated as auxiliary input streams. The auxiliary input
// streams are limited to timestamps passed on the main input stream. // streams are limited to timestamps allowed by the "ALLOW" stream.
// //
class FlowLimiterCalculator : public CalculatorBase { class FlowLimiterCalculator : public CalculatorBase {
public: public:
@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase {
cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>()); cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
} }
input_queues_.resize(cc->Inputs().NumEntries("")); input_queues_.resize(cc->Inputs().NumEntries(""));
allowed_[Timestamp::Unset()] = true;
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
return absl::OkStatus(); return absl::OkStatus();
} }
// Returns true if an additional frame can be released for processing.
// The "ALLOW" output stream indicates this condition at each input frame.
bool ProcessingAllowed() {
return frames_in_flight_.size() < options_.max_in_flight();
}
// Outputs a packet indicating whether a frame was sent or dropped.
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
if (cc->Outputs().HasTag(kAllowTag)) {
cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
}
}
// Sets the timestamp bound or closes an output stream.
void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
if (bound > Timestamp::Max()) {
stream->Close();
} else {
stream->SetNextTimestampBound(bound);
}
}
// Returns true if a certain timestamp is being processed.
bool IsInFlight(Timestamp timestamp) {
return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
timestamp) != frames_in_flight_.end();
}
// Releases input packets up to the latest settled input timestamp.
void ProcessAuxiliaryInputs(CalculatorContext* cc) {
Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
// Release settled frames from each input queue.
while (!input_queues_[i].empty() &&
input_queues_[i].front().Timestamp() < settled_bound) {
Packet packet = input_queues_[i].front();
input_queues_[i].pop_front();
if (IsInFlight(packet.Timestamp())) {
cc->Outputs().Get("", i).AddPacket(packet);
}
}
// Propagate each input timestamp bound.
if (!input_queues_[i].empty()) {
Timestamp bound = input_queues_[i].front().Timestamp();
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
} else {
Timestamp bound =
cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
}
}
}
// Releases input packets allowed by the max_in_flight constraint. // Releases input packets allowed by the max_in_flight constraint.
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
options_ = tool::RetrieveOptions(options_, cc->Inputs()); options_ = tool::RetrieveOptions(options_, cc->Inputs());
@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase {
} }
ProcessAuxiliaryInputs(cc); ProcessAuxiliaryInputs(cc);
// Discard old ALLOW ranges.
Timestamp input_bound = InputTimestampBound(cc);
auto first_range = std::prev(allowed_.upper_bound(input_bound));
allowed_.erase(allowed_.begin(), first_range);
return absl::OkStatus(); return absl::OkStatus();
} }
// Returns the number of entries this calculator is currently holding on to:
// the in-flight frame timestamps, the recorded ALLOW ranges, and every packet
// still waiting in an input queue.
int LedgerSize() {
  int total = frames_in_flight_.size() + allowed_.size();
  for (auto queue = input_queues_.begin(); queue != input_queues_.end();
       ++queue) {
    total += queue->size();
  }
  return total;
}
private:
// Reports whether one more frame may be released for processing, i.e. whether
// fewer than `max_in_flight` frames are currently in flight.
// The "ALLOW" output stream indicates this condition at each input frame.
bool ProcessingAllowed() {
  const auto in_flight_count = frames_in_flight_.size();
  return in_flight_count < options_.max_in_flight();
}
// Publishes the allow/drop decision for the frame at timestamp `ts`.
// When the kAllowTag output stream is connected, the decision is emitted there
// as a bool packet; in all cases it is recorded in `allowed_` so that later
// lookups by timestamp can find it.
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
  auto& outputs = cc->Outputs();
  if (outputs.HasTag(kAllowTag)) {
    outputs.Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
  }
  allowed_[ts] = allow;
}
// Returns the allow/drop decision of the range that contains `timestamp`,
// i.e. the value stored under the greatest key of `allowed_` that is not
// greater than `timestamp`.
// NOTE(review): std::prev is only valid if such a key exists. Open() seeds
// `allowed_` with Timestamp::Unset(), and Process() keeps the range preceding
// the input bound when pruning -- confirm no caller can pass an earlier
// timestamp.
bool IsAllowed(Timestamp timestamp) {
  const auto following_range = allowed_.upper_bound(timestamp);
  const auto containing_range = std::prev(following_range);
  return containing_range->second;
}
// Advances the next timestamp bound of `stream` to `bound`, or closes the
// stream when `bound` lies beyond Timestamp::Max().
void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
  const bool past_max = bound > Timestamp::Max();
  if (!past_max) {
    stream->SetNextTimestampBound(bound);
    return;
  }
  stream->Close();
}
// Returns the lowest unprocessed input Timestamp across all untagged inputs.
// For an input with queued packets this is the timestamp of its oldest queued
// packet; for an input with an empty queue it is the next timestamp allowed
// after the stream's current value. Starts from Timestamp::Done(), which is
// returned unchanged when no input yields an earlier bound.
Timestamp InputTimestampBound(CalculatorContext* cc) {
  Timestamp lowest = Timestamp::Done();
  for (int i = 0; i < input_queues_.size(); ++i) {
    auto& pending = input_queues_[i];
    auto& input = cc->Inputs().Get("", i);
    Timestamp candidate =
        pending.empty() ? input.Value().Timestamp().NextAllowedInStream()
                        : pending.front().Timestamp();
    if (candidate < lowest) {
      lowest = candidate;
    }
  }
  return lowest;
}
// Releases queued auxiliary packets (untagged inputs 1..N-1) whose timestamps
// are already settled on the main output stream, i.e. lie below the main
// output's next timestamp bound. A released packet is forwarded only if its
// timestamp falls in an allowed range; otherwise it is dropped. Afterwards the
// timestamp bound of each auxiliary output is advanced to match its input.
void ProcessAuxiliaryInputs(CalculatorContext* cc) {
  Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
  for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
    auto& pending = input_queues_[i];

    // Drain every packet that is older than the settled bound.
    while (!pending.empty() && pending.front().Timestamp() < settled_bound) {
      Packet packet = pending.front();
      pending.pop_front();
      if (!IsAllowed(packet.Timestamp())) {
        continue;
      }
      cc->Outputs().Get("", i).AddPacket(packet);
    }

    // The next bound is the oldest packet still queued, or, with nothing
    // queued, the timestamp following the input stream's current value.
    Timestamp bound =
        pending.empty()
            ? cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream()
            : pending.front().Timestamp();
    SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
  }
}
private: private:
FlowLimiterCalculatorOptions options_; FlowLimiterCalculatorOptions options_;
std::vector<std::deque<Packet>> input_queues_; std::vector<std::deque<Packet>> input_queues_;
std::deque<Timestamp> frames_in_flight_; std::deque<Timestamp> frames_in_flight_;
std::map<Timestamp, bool> allowed_;
}; };
REGISTER_CALCULATOR(FlowLimiterCalculator); REGISTER_CALCULATOR(FlowLimiterCalculator);

View File

@ -15,6 +15,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "absl/time/clock.h" #include "absl/time/clock.h"
@ -32,6 +33,7 @@
#include "mediapipe/framework/tool/simulation_clock.h" #include "mediapipe/framework/tool/simulation_clock.h"
#include "mediapipe/framework/tool/simulation_clock_executor.h" #include "mediapipe/framework/tool/simulation_clock_executor.h"
#include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/sink.h"
#include "mediapipe/util/packet_test_util.h"
namespace mediapipe { namespace mediapipe {
@ -77,6 +79,77 @@ std::vector<T> PacketValues(const std::vector<Packet>& packets) {
return result; return result;
} }
template <typename T>
std::vector<Packet> MakePackets(std::vector<std::pair<Timestamp, T>> contents) {
std::vector<Packet> result;
for (auto& entry : contents) {
result.push_back(MakePacket<T>(entry.second).At(entry.first));
}
return result;
}
// Formats a timestamp for test output: special values are rendered with their
// debug string alone, all other values as "Timestamp(<debug string>)".
std::string SourceString(Timestamp t) {
  if (t.IsSpecialValue()) {
    return t.DebugString();
  }
  return absl::StrCat("Timestamp(", t.DebugString(), ")");
}
// A matcher that compares a container of packets against an expected
// container, element by element, on both timestamp and payload.
//
// PacketContainer: a sized, iterable container whose elements are Packets.
// PacketContent: the payload type read from every packet via Get<>(); it must
//     support `!=` and streaming to std::ostream.
template <typename PacketContainer, typename PacketContent>
class PacketsEqMatcher
    : public ::testing::MatcherInterface<const PacketContainer&> {
 public:
  // `explicit` prevents a container from silently converting into a matcher;
  // the by-value argument is moved into place instead of being copied twice.
  explicit PacketsEqMatcher(PacketContainer packets)
      : packets_(std::move(packets)) {}

  // Describes the expectation by listing the expected packets.
  void DescribeTo(::std::ostream* os) const override {
    *os << "The expected packet contents: \n";
    Print(packets_, os);
  }

  // Returns true when `value` equals the expected packets. On mismatch the
  // actual packets are written to `listener` if it is interested.
  bool MatchAndExplain(
      const PacketContainer& value,
      ::testing::MatchResultListener* listener) const override {
    if (Equals(packets_, value)) {
      return true;
    }
    if (listener->IsInterested()) {
      *listener << "The actual packet contents: \n";
      Print(value, listener->stream());
    }
    return false;
  }

 private:
  // Two containers are equal when they have the same size and every pair of
  // corresponding packets agrees on timestamp and payload.
  bool Equals(const PacketContainer& c1, const PacketContainer& c2) const {
    if (c1.size() != c2.size()) {
      return false;
    }
    for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) {
      // Bind by reference: comparing does not require copying the packets.
      const Packet& p1 = *i1;
      const Packet& p2 = *i2;
      if (p1.Timestamp() != p2.Timestamp() ||
          p1.Get<PacketContent>() != p2.Get<PacketContent>()) {
        return false;
      }
    }
    return true;
  }

  // Writes `packets` as "{{<timestamp>, <value>}, ...}".
  void Print(const PacketContainer& packets, ::std::ostream* os) const {
    // An empty container is printed as "{}" so that the description is never
    // blank.
    if (packets.begin() == packets.end()) {
      *os << "{}";
      return;
    }
    for (auto it = packets.begin(); it != packets.end(); ++it) {
      const Packet& packet = *it;
      *os << (it == packets.begin() ? "{" : "") << "{"
          << SourceString(packet.Timestamp()) << ", "
          << packet.Get<PacketContent>() << "}"
          << (std::next(it) == packets.end() ? "}" : ", ");
    }
  }

  const PacketContainer packets_;
};
// Creates a matcher that checks a packet container against `packets`,
// comparing timestamps and PacketContent payloads.
// The spelling "PackestEq" is kept as is because the tests in this file refer
// to the function by that name.
template <typename PacketContainer, typename PacketContent>
::testing::Matcher<const PacketContainer&> PackestEq(
    const PacketContainer& packets) {
  using MatcherImpl = PacketsEqMatcher<PacketContainer, PacketContent>;
  return MakeMatcher(new MatcherImpl(packets));
}
// A Calculator::Process callback function. // A Calculator::Process callback function.
typedef std::function<absl::Status(const InputStreamShardSet&, typedef std::function<absl::Status(const InputStreamShardSet&,
OutputStreamShardSet*)> OutputStreamShardSet*)>
@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
input_packets_[17], input_packets_[19], input_packets_[20], input_packets_[17], input_packets_[19], input_packets_[20],
}; };
EXPECT_EQ(out_1_packets_, expected_output); EXPECT_EQ(out_1_packets_, expected_output);
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. // The timestamps released by FlowLimiterCalculator for in_1_sampled,
// plus input_packets_[21].
std::vector<Packet> expected_output_2 = { std::vector<Packet> expected_output_2 = {
input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[0], input_packets_[2], input_packets_[4],
input_packets_[14], input_packets_[17], input_packets_[19], input_packets_[14], input_packets_[17], input_packets_[19],
input_packets_[20], input_packets_[20], input_packets_[21],
}; };
EXPECT_EQ(out_2_packets, expected_output_2); EXPECT_EQ(out_2_packets, expected_output_2);
} }
@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
// The processing time "sleep_time" is reduced from 22ms to 12ms to create // The processing time "sleep_time" is reduced from 22ms to 12ms to create
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
// Configure the test. // Configure the test.
SetUpInputData(); SetUpInputData();
SetUpSimulationClock(); SetUpSimulationClock();
@ -699,10 +776,9 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
} }
)pb"); )pb");
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb( auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
max_in_flight: 1 R"pb(
max_in_queue: 0 max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000 # 100 ms
in_flight_timeout: 100000 # 100 ms
)pb"); )pb");
std::map<std::string, Packet> side_packets = { std::map<std::string, Packet> side_packets = {
{"limiter_options", {"limiter_options",
@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
input_packets_[0], input_packets_[2], input_packets_[15], input_packets_[0], input_packets_[2], input_packets_[15],
input_packets_[17], input_packets_[19], input_packets_[17], input_packets_[19],
}; };
EXPECT_EQ(out_1_packets_, expected_output); EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled. // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
std::vector<Packet> expected_output_2 = { std::vector<Packet> expected_output_2 = {
input_packets_[0], input_packets_[2], input_packets_[4], input_packets_[0], input_packets_[2], input_packets_[4],
input_packets_[15], input_packets_[17], input_packets_[19], input_packets_[15], input_packets_[17], input_packets_[19],
}; };
EXPECT_EQ(out_2_packets, expected_output_2); EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
// Validate the ALLOW stream output.
std::vector<Packet> expected_allow = MakePackets<bool>( //
{{Timestamp(0), true}, {Timestamp(10000), false},
{Timestamp(20000), true}, {Timestamp(30000), false},
{Timestamp(40000), true}, {Timestamp(50000), false},
{Timestamp(60000), false}, {Timestamp(70000), false},
{Timestamp(80000), false}, {Timestamp(90000), false},
{Timestamp(100000), false}, {Timestamp(110000), false},
{Timestamp(120000), false}, {Timestamp(130000), false},
{Timestamp(140000), false}, {Timestamp(150000), true},
{Timestamp(160000), false}, {Timestamp(170000), true},
{Timestamp(180000), false}, {Timestamp(190000), true},
{Timestamp(200000), false}});
EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
}
// Shows how FlowLimiterCalculator releases auxiliary input packets.
// In this test, auxiliary input packets arrive at twice the primary rate.
TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
// Configure the test.
SetUpInputData();
SetUpSimulationClock();
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in_1'
input_stream: 'in_2'
node {
calculator: 'FlowLimiterCalculator'
input_side_packet: 'OPTIONS:limiter_options'
input_stream: 'in_1'
input_stream: 'in_2'
input_stream: 'FINISHED:out_1'
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
output_stream: 'in_1_sampled'
output_stream: 'in_2_sampled'
output_stream: 'ALLOW:allow'
}
node {
calculator: 'SleepCalculator'
input_side_packet: 'WARMUP_TIME:warmup_time'
input_side_packet: 'SLEEP_TIME:sleep_time'
input_side_packet: 'CLOCK:clock'
input_stream: 'PACKET:in_1_sampled'
output_stream: 'PACKET:out_1'
}
)pb");
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
R"pb(
max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000 # 1s
)pb");
std::map<std::string, Packet> side_packets = {
{"limiter_options",
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
{"warmup_time", MakePacket<int64>(22000)},
{"sleep_time", MakePacket<int64>(22000)},
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
};
// Start the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
out_1_packets_.push_back(p);
return absl::OkStatus();
}));
std::vector<Packet> out_2_packets;
MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
out_2_packets.push_back(p);
return absl::OkStatus();
}));
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
allow_packets_.push_back(p);
return absl::OkStatus();
}));
simulation_clock_->ThreadStart();
MP_ASSERT_OK(graph_.StartRun(side_packets));
// Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2.
clock_->Sleep(absl::Microseconds(10000));
for (int i = 1; i < 10; ++i) {
if (i % 2 == 0) {
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
}
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i]));
clock_->Sleep(absl::Microseconds(10000));
}
// Finish the graph run.
MP_EXPECT_OK(graph_.CloseAllPacketSources());
clock_->Sleep(absl::Microseconds(40000));
MP_EXPECT_OK(graph_.WaitUntilDone());
simulation_clock_->ThreadFinish();
// Validate the output.
// Input packets 4 and 8 are dropped due to max_in_flight.
std::vector<Packet> expected_output = {
input_packets_[2],
input_packets_[6],
};
EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
// Packets following input packets 2 and 6, and not input packets 4 and 8.
std::vector<Packet> expected_output_2 = {
input_packets_[1], input_packets_[2], input_packets_[3],
input_packets_[6], input_packets_[7],
};
EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
// Validate the ALLOW stream output.
std::vector<Packet> expected_allow =
MakePackets<bool>({{Timestamp(20000), 1},
{Timestamp(40000), 0},
{Timestamp(60000), 1},
{Timestamp(80000), 0}});
EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
} }
} // anonymous namespace } // anonymous namespace

View File

@ -14,45 +14,43 @@
#include "mediapipe/framework/deps/status_builder.h" #include "mediapipe/framework/deps/status_builder.h"
#include <memory>
#include <sstream>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_cat.h"
namespace mediapipe { namespace mediapipe {
StatusBuilder::StatusBuilder(const StatusBuilder& sb) { StatusBuilder::StatusBuilder(const StatusBuilder& sb)
status_ = sb.status_; : impl_(sb.impl_ ? std::make_unique<Impl>(*sb.impl_) : nullptr) {}
file_ = sb.file_;
line_ = sb.line_;
no_logging_ = sb.no_logging_;
stream_ = sb.stream_
? absl::make_unique<std::ostringstream>(sb.stream_->str())
: nullptr;
join_style_ = sb.join_style_;
}
StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
status_ = sb.status_; if (!sb.impl_) {
file_ = sb.file_; impl_ = nullptr;
line_ = sb.line_; return *this;
no_logging_ = sb.no_logging_; }
stream_ = sb.stream_ if (impl_) {
? absl::make_unique<std::ostringstream>(sb.stream_->str()) *impl_ = *sb.impl_;
: nullptr; return *this;
join_style_ = sb.join_style_; }
impl_ = std::make_unique<Impl>(*sb.impl_);
return *this; return *this;
} }
StatusBuilder& StatusBuilder::SetAppend() & { StatusBuilder& StatusBuilder::SetAppend() & {
if (status_.ok()) return *this; if (!impl_) return *this;
join_style_ = MessageJoinStyle::kAppend; impl_->join_style = Impl::MessageJoinStyle::kAppend;
return *this; return *this;
} }
StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); } StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); }
StatusBuilder& StatusBuilder::SetPrepend() & { StatusBuilder& StatusBuilder::SetPrepend() & {
if (status_.ok()) return *this; if (!impl_) return *this;
join_style_ = MessageJoinStyle::kPrepend; impl_->join_style = Impl::MessageJoinStyle::kPrepend;
return *this; return *this;
} }
@ -61,7 +59,8 @@ StatusBuilder&& StatusBuilder::SetPrepend() && {
} }
StatusBuilder& StatusBuilder::SetNoLogging() & { StatusBuilder& StatusBuilder::SetNoLogging() & {
no_logging_ = true; if (!impl_) return *this;
impl_->no_logging = true;
return *this; return *this;
} }
@ -70,34 +69,72 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
} }
StatusBuilder::operator Status() const& { StatusBuilder::operator Status() const& {
if (!stream_ || stream_->str().empty() || no_logging_) {
return status_;
}
return StatusBuilder(*this).JoinMessageToStatus(); return StatusBuilder(*this).JoinMessageToStatus();
} }
StatusBuilder::operator Status() && { StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
if (!stream_ || stream_->str().empty() || no_logging_) {
return status_;
}
return JoinMessageToStatus();
}
absl::Status StatusBuilder::JoinMessageToStatus() { absl::Status StatusBuilder::JoinMessageToStatus() {
if (!stream_) { if (!impl_) {
return absl::OkStatus(); return absl::OkStatus();
} }
std::string message; return impl_->JoinMessageToStatus();
if (join_style_ == MessageJoinStyle::kAnnotate) { }
if (!status_.ok()) {
message = absl::StrCat(status_.message(), "; ", stream_->str()); absl::Status StatusBuilder::Impl::JoinMessageToStatus() {
if (stream.str().empty() || no_logging) {
return status;
} }
} else { return absl::Status(status.code(), [this]() {
message = join_style_ == MessageJoinStyle::kPrepend switch (join_style) {
? absl::StrCat(stream_->str(), status_.message()) case MessageJoinStyle::kAnnotate:
: absl::StrCat(status_.message(), stream_->str()); return absl::StrCat(status.message(), "; ", stream.str());
case MessageJoinStyle::kAppend:
return absl::StrCat(status.message(), stream.str());
case MessageJoinStyle::kPrepend:
return absl::StrCat(stream.str(), status.message());
} }
return Status(status_.code(), message); }());
}
// Stores a copy of `status` together with the file and line to report.
// The message stream starts out empty.
StatusBuilder::Impl::Impl(const absl::Status& status, const char* file,
                          int line)
    : status(status), line(line), file(file) {}

// Same as above, but takes ownership of `status` instead of copying it.
StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line)
    : status(std::move(status)), line(line), file(file) {}

// Takes the file and line from `location` and defers to the file/line
// constructor.
StatusBuilder::Impl::Impl(const absl::Status& status,
                          mediapipe::source_location location)
    : Impl(status, location.file_name(), location.line()) {}

// Moving counterpart of the `location` constructor.
StatusBuilder::Impl::Impl(absl::Status&& status,
                          mediapipe::source_location location)
    : Impl(std::move(status), location.file_name(), location.line()) {}
// Copy constructor. std::ostringstream is not copyable, so the new stream is
// built from the text accumulated in `other`.
StatusBuilder::Impl::Impl(const Impl& other)
    : status(other.status),
      line(other.line),
      file(other.file),
      no_logging(other.no_logging),
      // Open with `ate` so the write position starts after the copied text.
      // With the default mode the position starts at 0 and anything streamed
      // into the copy would overwrite the copied message.
      stream(other.stream.str(), std::ios_base::out | std::ios_base::ate),
      join_style(other.join_style) {}
// Copy assignment. Every field is copied from `other`; the message stream is
// replaced by a new stream holding the text accumulated in `other`.
StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) {
  status = other.status;
  line = other.line;
  file = other.file;
  no_logging = other.no_logging;
  // Open with `ate` so the write position starts after the copied text. With
  // the default mode the position starts at 0 and anything streamed afterwards
  // would overwrite the copied message.
  stream = std::ostringstream(other.stream.str(),
                              std::ios_base::out | std::ios_base::ate);
  join_style = other.join_style;
  return *this;
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -22,7 +22,6 @@
#include "absl/base/attributes.h" #include "absl/base/attributes.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/source_location.h" #include "mediapipe/framework/deps/source_location.h"
#include "mediapipe/framework/deps/status.h" #include "mediapipe/framework/deps/status.h"
@ -42,34 +41,37 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
// occurs. A typical user will call this with `MEDIAPIPE_LOC`. // occurs. A typical user will call this with `MEDIAPIPE_LOC`.
StatusBuilder(const absl::Status& original_status, StatusBuilder(const absl::Status& original_status,
mediapipe::source_location location) mediapipe::source_location location)
: status_(original_status), : impl_(original_status.ok()
line_(location.line()), ? nullptr
file_(location.file_name()), : std::make_unique<Impl>(original_status, location)) {}
stream_(InitStream(status_)) {}
StatusBuilder(absl::Status&& original_status, StatusBuilder(absl::Status&& original_status,
mediapipe::source_location location) mediapipe::source_location location)
: status_(std::move(original_status)), : impl_(original_status.ok()
line_(location.line()), ? nullptr
file_(location.file_name()), : std::make_unique<Impl>(std::move(original_status),
stream_(InitStream(status_)) {} location)) {}
// Creates a `StatusBuilder` from a mediapipe status code. If logging is // Creates a `StatusBuilder` from a mediapipe status code. If logging is
// enabled, it will use `location` as the location from which the log message // enabled, it will use `location` as the location from which the log message
// occurs. A typical user will call this with `MEDIAPIPE_LOC`. // occurs. A typical user will call this with `MEDIAPIPE_LOC`.
StatusBuilder(absl::StatusCode code, mediapipe::source_location location) StatusBuilder(absl::StatusCode code, mediapipe::source_location location)
: status_(code, ""), : impl_(code == absl::StatusCode::kOk
line_(location.line()), ? nullptr
file_(location.file_name()), : std::make_unique<Impl>(absl::Status(code, ""), location)) {}
stream_(InitStream(status_)) {}
StatusBuilder(const absl::Status& original_status, const char* file, int line) StatusBuilder(const absl::Status& original_status, const char* file, int line)
: status_(original_status), : impl_(original_status.ok()
line_(line), ? nullptr
file_(file), : std::make_unique<Impl>(original_status, file, line)) {}
stream_(InitStream(status_)) {}
bool ok() const { return status_.ok(); } StatusBuilder(absl::Status&& original_status, const char* file, int line)
: impl_(original_status.ok()
? nullptr
: std::make_unique<Impl>(std::move(original_status), file,
line)) {}
bool ok() const { return !impl_; }
StatusBuilder& SetAppend() &; StatusBuilder& SetAppend() &;
StatusBuilder&& SetAppend() &&; StatusBuilder&& SetAppend() &&;
@ -82,8 +84,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
template <typename T> template <typename T>
StatusBuilder& operator<<(const T& msg) & { StatusBuilder& operator<<(const T& msg) & {
if (!stream_) return *this; if (!impl_) return *this;
*stream_ << msg; impl_->stream << msg;
return *this; return *this;
} }
@ -98,6 +100,7 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
absl::Status JoinMessageToStatus(); absl::Status JoinMessageToStatus();
private: private:
struct Impl {
// Specifies how to join the error message in the original status and any // Specifies how to join the error message in the original status and any
// additional message that has been streamed into the builder. // additional message that has been streamed into the builder.
enum class MessageJoinStyle { enum class MessageJoinStyle {
@ -106,27 +109,33 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
kPrepend, kPrepend,
}; };
// Conditionally creates an ostringstream if the status is not ok. Impl(const absl::Status& status, const char* file, int line);
static std::unique_ptr<std::ostringstream> InitStream( Impl(absl::Status&& status, const char* file, int line);
const absl::Status status) { Impl(const absl::Status& status, mediapipe::source_location location);
if (status.ok()) { Impl(absl::Status&& status, mediapipe::source_location location);
return nullptr; Impl(const Impl&);
} Impl& operator=(const Impl&);
return absl::make_unique<std::ostringstream>();
} absl::Status JoinMessageToStatus();
// The status that the result will be based on. // The status that the result will be based on.
absl::Status status_; absl::Status status;
// The line to record if this file is logged. // The line to record if this file is logged.
int line_; int line;
// Not-owned: The file to record if this status is logged. // Not-owned: The file to record if this status is logged.
const char* file_; const char* file;
bool no_logging_ = false; // Logging disabled if true.
bool no_logging = false;
// The additional messages added with `<<`. This is nullptr when status_ is // The additional messages added with `<<`. This is nullptr when status_ is
// ok. // ok.
std::unique_ptr<std::ostringstream> stream_; std::ostringstream stream;
// Specifies how to join the message in `status_` and `stream_`. // Specifies how to join the message in `status_` and `stream_`.
MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate; MessageJoinStyle join_style = MessageJoinStyle::kAnnotate;
};
// Internal store of data for the class. An invariant of the class is that
// this is null when the original status is okay, and not-null otherwise.
std::unique_ptr<Impl> impl_;
}; };
inline StatusBuilder AlreadyExistsErrorBuilder( inline StatusBuilder AlreadyExistsErrorBuilder(

View File

@ -33,6 +33,21 @@ TEST(StatusBuilder, OkStatusRvalue) {
ASSERT_EQ(status, absl::OkStatus()); ASSERT_EQ(status, absl::OkStatus());
} }
// A builder created from a temporary OK status with an explicit file and line
// converts back to an OK status, even though messages were streamed into it.
TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) {
absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234)
<< "annotated message1 "
<< "annotated message2";
ASSERT_EQ(status, absl::OkStatus());
}
// Same as the rvalue variant, but the OK status is passed as a named const
// object, so the const-reference (file, line) constructor is the one used.
TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) {
const auto original_status = absl::OkStatus();
absl::Status status = StatusBuilder(original_status, "hello.cc", 1234)
<< "annotated message1 "
<< "annotated message2";
ASSERT_EQ(status, absl::OkStatus());
}
TEST(StatusBuilder, AnnotateMode) { TEST(StatusBuilder, AnnotateMode) {
absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
"original message"), "original message"),
@ -45,6 +60,30 @@ TEST(StatusBuilder, AnnotateMode) {
"original message; annotated message1 annotated message2"); "original message; annotated message1 annotated message2");
} }
// A builder created from a temporary error status with an explicit file and
// line keeps the status code and, with no join style selected, joins the
// original message and the streamed text with "; ".
TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) {
absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
"original message"),
"hello.cc", 1234)
<< "annotated message1 "
<< "annotated message2";
ASSERT_FALSE(status.ok());
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_EQ(status.message(),
"original message; annotated message1 annotated message2");
}
// Same as the rvalue variant, but the error status is passed as a named const
// object, so the const-reference (file, line) constructor is the one used.
TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) {
const auto original_status =
absl::Status(absl::StatusCode::kNotFound, "original message");
absl::Status status = StatusBuilder(original_status, "hello.cc", 1234)
<< "annotated message1 "
<< "annotated message2";
ASSERT_FALSE(status.ok());
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_EQ(status.message(),
"original message; annotated message1 annotated message2");
}
TEST(StatusBuilder, PrependModeLvalue) { TEST(StatusBuilder, PrependModeLvalue) {
StatusBuilder builder( StatusBuilder builder(
absl::Status(absl::StatusCode::kInvalidArgument, "original message"), absl::Status(absl::StatusCode::kInvalidArgument, "original message"),

View File

@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel): class Classifier(custom_model.CustomModel):
"""An abstract base class that represents a TensorFlow classifier.""" """An abstract base class that represents a TensorFlow classifier."""
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool, def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
full_train: bool):
"""Initilizes a classifier with its specifications. """Initilizes a classifier with its specifications.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
label_names: A list of label names for the classes. label_names: A list of label names for the classes.
shuffle: Whether the dataset should be shuffled. shuffle: Whether the dataset should be shuffled.
full_train: If true, train the model end-to-end including the backbone
and the classification layers on top. Otherwise, only train the top
classification layers.
""" """
super(Classifier, self).__init__(model_spec, shuffle) super(Classifier, self).__init__(model_spec, shuffle)
self._label_names = label_names self._label_names = label_names
self._full_train = full_train
self._num_classes = len(label_names) self._num_classes = len(label_names)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:

View File

@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase):
super(ClassifierTest, self).setUp() super(ClassifierTest, self).setUp()
label_names = ['cat', 'dog'] label_names = ['cat', 'dog']
self.model = MockClassifier( self.model = MockClassifier(
model_spec=None, model_spec=None, label_names=label_names, shuffle=False)
label_names=label_names,
shuffle=False,
full_train=False)
self.model.model = test_util.build_model(input_shape=[4], num_classes=2) self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
def _check_nonempty_file(self, filepath): def _check_nonempty_file(self, filepath):

View File

@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier):
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
""" """
super().__init__( super().__init__(
model_spec=model_spec, model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
label_names=label_names,
shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning)
self._hparams = hparams self._hparams = hparams
self._preprocess = image_preprocessing.Preprocessor( self._preprocess = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape, input_shape=self._model_spec.input_image_shape,

View File

@ -93,6 +93,13 @@ cc_library(
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
], ],
] + select({
# TODO: Build text_classifier_graph on Windows.
"//mediapipe:windows": [],
"//conditions:default": [
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
],
}),
) )
py_library( py_library(

View File

@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) {
// TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&" // TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
// as the input argument. Investigate why bazel non-optimized mode // as the input argument. Investigate why bazel non-optimized mode
// triggers a memory allocation bug in Eigen::internal::aligned_free(). // triggers a memory allocation bug in Eigen::internal::aligned_free().
[](const Eigen::MatrixXf& matrix) { [](const Eigen::MatrixXf& matrix, bool transpose) {
// MakePacket copies the data. // MakePacket copies the data.
if (transpose) {
return MakePacket<Matrix>(matrix.transpose());
}
return MakePacket<Matrix>(matrix); return MakePacket<Matrix>(matrix);
}, },
R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray. R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) {
Args: Args:
matrix: A 2d numpy float ndarray. matrix: A 2d numpy float ndarray.
transpose: A boolean to indicate if the input matrix needs to be transposed.
Default to False.
Returns: Returns:
A MediaPipe Matrix Packet. A MediaPipe Matrix Packet.
@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) {
np.array([[.1, .2, .3], [.4, .5, .6]]) np.array([[.1, .2, .3], [.4, .5, .6]])
matrix = mp.packet_getter.get_matrix(packet) matrix = mp.packet_getter.get_matrix(packet)
)doc", )doc",
py::arg("matrix"), py::arg("transpose") = false,
py::return_value_policy::move); py::return_value_policy::move);
} // NOLINT(readability/fn_size) } // NOLINT(readability/fn_size)

View File

@ -31,13 +31,13 @@ namespace core {
// Base options for MediaPipe C++ Tasks. // Base options for MediaPipe C++ Tasks.
struct BaseOptions { struct BaseOptions {
// The model asset file contents as as string. // The model asset file contents as a string.
std::unique_ptr<std::string> model_asset_buffer; std::unique_ptr<std::string> model_asset_buffer;
// The path to the model asset to open and mmap in memory. // The path to the model asset to open and mmap in memory.
std::string model_asset_path = ""; std::string model_asset_path = "";
// The delegate to run MediaPipe. If the delegate is not set, default // The delegate to run MediaPipe. If the delegate is not set, the default
// delegate CPU is used. // delegate CPU is used.
enum Delegate { enum Delegate {
CPU = 0, CPU = 0,

View File

@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
hand_gesture_subgraph[Output<std::vector<ClassificationList>>( hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
kHandGesturesTag)]; kHandGesturesTag)];
return {{.gesture = hand_gestures, return GestureRecognizerOutputs{
.handedness = handedness, /*gesture=*/hand_gestures,
.hand_landmarks = hand_landmarks, /*handedness=*/handedness,
.hand_world_landmarks = hand_world_landmarks, /*hand_landmarks=*/hand_landmarks,
.image = hand_landmarker_graph[Output<Image>(kImageTag)]}}; /*hand_world_landmarks=*/hand_world_landmarks,
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
} }
// Populate normalized non rotated face bounding box // Populate normalized non rotated face bounding box
return {.left = bounding_box_left, return Rect{/*left=*/bounding_box_left,
.top = bounding_box_top, /*top=*/bounding_box_top,
.right = bounding_box_right, /*right=*/bounding_box_right,
.bottom = bounding_box_bottom}; /*bottom=*/bounding_box_bottom};
} }
// Uses IoU and distance of some corresponding hand landmarks to detect // Uses IoU and distance of some corresponding hand landmarks to detect

View File

@ -48,7 +48,7 @@ public abstract class BaseOptions {
public abstract Builder setModelAssetBuffer(ByteBuffer value); public abstract Builder setModelAssetBuffer(ByteBuffer value);
/** /**
* Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default * Sets device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
* delegate CPU is used. * delegate CPU is used.
*/ */
public abstract Builder setDelegate(Delegate delegate); public abstract Builder setDelegate(Delegate delegate);

View File

@ -68,7 +68,11 @@ class AudioData(object):
ValueError: If the input array has an incorrect shape or if ValueError: If the input array has an incorrect shape or if
`offset` + `size` exceeds the length of the `src` array. `offset` + `size` exceeds the length of the `src` array.
""" """
if src.shape[1] != self._audio_format.num_channels: if len(src.shape) == 1:
if self._audio_format.num_channels != 1:
raise ValueError(f"Input audio is mono, but the audio data is expected "
f"to have {self._audio_format.num_channels} channels.")
elif src.shape[1] != self._audio_format.num_channels:
raise ValueError(f"Input audio contains an invalid number of channels. " raise ValueError(f"Input audio contains an invalid number of channels. "
f"Expect {self._audio_format.num_channels}.") f"Expect {self._audio_format.num_channels}.")
@ -93,6 +97,28 @@ class AudioData(object):
self._buffer = np.roll(self._buffer, -shift, axis=0) self._buffer = np.roll(self._buffer, -shift, axis=0)
self._buffer[-shift:, :] = src[offset:offset + size].copy() self._buffer[-shift:, :] = src[offset:offset + size].copy()
@classmethod
def create_from_array(cls,
                      src: np.ndarray,
                      sample_rate: Optional[float] = None) -> "AudioData":
  """Builds an `AudioData` object from the contents of a NumPy array.

  The buffer length is taken from the first dimension of `src`. A 1-D array
  is treated as single-channel audio; otherwise the size of the second
  dimension is used as the number of channels.

  Args:
    src: A NumPy array that contains the input audio.
    sample_rate: The optional audio sample rate.

  Returns:
    An `AudioData` object that has been loaded with `src` through
    `load_from_array`.
  """
  channel_count = src.shape[1] if len(src.shape) != 1 else 1
  fmt = AudioFormat(num_channels=channel_count, sample_rate=sample_rate)
  audio_data = cls(buffer_length=src.shape[0], audio_format=fmt)
  audio_data.load_from_array(src)
  return audio_data
@property @property
def audio_format(self) -> AudioFormat: def audio_format(self) -> AudioFormat:
"""Gets the audio format of the audio.""" """Gets the audio format of the audio."""

View File

@ -157,7 +157,7 @@ def _normalize_number_fields(pb):
descriptor.FieldDescriptor.TYPE_ENUM): descriptor.FieldDescriptor.TYPE_ENUM):
normalized_values = [int(x) for x in values] normalized_values = [int(x) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
normalized_values = [round(x, 5) for x in values] normalized_values = [round(x, 4) for x in values]
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE: elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
normalized_values = [round(float(x), 6) for x in values] normalized_values = [round(float(x), 6) for x in values]

View File

@ -0,0 +1,36 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Test target for the Python text classifier task. The BERT and regex-tokenizer
# test models are made available to the test through `data`.
py_test(
name = "text_classifier_test",
srcs = ["text_classifier_test.py"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
deps = [
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/text:text_classifier",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,244 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for text classifier."""
import enum
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.text import text_classifier
# Short aliases for the MediaPipe types used throughout the tests.
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_TextClassifierResult = classifications_module.ClassificationResult
_TextClassifier = text_classifier.TextClassifier
_TextClassifierOptions = text_classifier.TextClassifierOptions
# Model file names; the tests join them with _TEST_DATA_DIR.
_BERT_MODEL_FILE = 'bert_text_classifier.tflite'
_REGEX_MODEL_FILE = 'test_model_text_classifier_with_regex_tokenizer.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Input sentences passed to the classifier.
_NEGATIVE_TEXT = 'What a waste of my time.'
# NOTE(review): the two adjacent literals are concatenated without a space
# ("years.Strongly") and "Ive" has no apostrophe. The expected scores below
# were presumably produced with this exact string, so confirm before editing.
_POSITIVE_TEXT = ('This is the best movie Ive seen in recent years.'
'Strongly recommend it!')
# Expected result for the BERT model on _NEGATIVE_TEXT (paired in the
# test_classify parameters): 'negative' is listed first with the higher score.
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=0,
score=0.999479,
display_name='',
category_name='negative'),
_Category(
index=1,
score=0.00052154,
display_name='',
category_name='positive')
],
timestamp_ms=0)
],
head_index=0,
head_name='probability')
])
# Expected result for the BERT model on _POSITIVE_TEXT (paired in the
# test_classify parameters): 'positive' is listed first with the higher score.
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=1,
score=0.999466,
display_name='',
category_name='positive'),
_Category(
index=0,
score=0.000533596,
display_name='',
category_name='negative')
],
timestamp_ms=0)
],
head_index=0,
head_name='probability')
])
# Expected result for the regex-tokenizer model on _NEGATIVE_TEXT (paired in
# the test_classify parameters). Note the capitalized category names, which
# differ from the BERT model's lowercase labels.
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=0,
score=0.81313,
display_name='',
category_name='Negative'),
_Category(
index=1,
score=0.1868704,
display_name='',
category_name='Positive')
],
timestamp_ms=0)
],
head_index=0,
head_name='probability')
])
# Expected result for the regex-tokenizer model on _POSITIVE_TEXT (paired in
# the test_classify parameters): 'Positive' is listed first with the higher
# score.
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=1,
score=0.5134273,
display_name='',
category_name='Positive'),
_Category(
index=0,
score=0.486573,
display_name='',
category_name='Negative')
],
timestamp_ms=0)
],
head_index=0,
head_name='probability')
])
class ModelFileType(enum.Enum):
  """How a test hands the model to the classifier."""

  # The model bytes are read by the test and passed as a buffer.
  FILE_CONTENT = 1
  # Only the path of the model file is passed.
  FILE_NAME = 2
class ImageClassifierTest(parameterized.TestCase):
  """Tests creation of the text classifier task and its `classify` results.

  NOTE(review): the class name says "Image" although every test exercises
  `_TextClassifier`; this looks like a copy-paste leftover. The name is kept
  here so existing test filters keep working -- consider renaming it to
  `TextClassifierTest`.
  """

  def setUp(self):
    super().setUp()
    # Default model used by the creation tests.
    self.model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE))

  def _create_base_options(self, model_file_type, model_name):
    """Builds `_BaseOptions` that supply `model_name` by path or by content.

    Args:
      model_file_type: A `ModelFileType` selecting how the model is supplied.
      model_name: File name of the model inside `_TEST_DATA_DIR`.

    Returns:
      The `_BaseOptions` pointing at the model.

    Raises:
      ValueError: If `model_file_type` is not a known `ModelFileType`.
    """
    model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, model_name))
    if model_file_type is ModelFileType.FILE_NAME:
      return _BaseOptions(model_asset_path=model_path)
    if model_file_type is ModelFileType.FILE_CONTENT:
      with open(model_path, 'rb') as f:
        return _BaseOptions(model_asset_buffer=f.read())
    # Should never happen
    raise ValueError('model_file_type is invalid.')

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _TextClassifier.create_from_model_path(self.model_path) as classifier:
      self.assertIsInstance(classifier, _TextClassifier)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _TextClassifierOptions(base_options=base_options)
    with _TextClassifier.create_from_options(options) as classifier:
      self.assertIsInstance(classifier, _TextClassifier)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _TextClassifierOptions(base_options=base_options)
      _TextClassifier.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
    options = _TextClassifierOptions(base_options=base_options)
    # The context manager closes the classifier; previously it was created
    # here and never closed.
    with _TextClassifier.create_from_options(options) as classifier:
      self.assertIsInstance(classifier, _TextClassifier)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
       _BERT_NEGATIVE_RESULTS),
      (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
       _BERT_NEGATIVE_RESULTS),
      (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _POSITIVE_TEXT,
       _BERT_POSITIVE_RESULTS),
      (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _POSITIVE_TEXT,
       _BERT_POSITIVE_RESULTS),
      (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _NEGATIVE_TEXT,
       _REGEX_NEGATIVE_RESULTS),
      (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, _NEGATIVE_TEXT,
       _REGEX_NEGATIVE_RESULTS),
      (ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _POSITIVE_TEXT,
       _REGEX_POSITIVE_RESULTS),
      (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE, _POSITIVE_TEXT,
       _REGEX_POSITIVE_RESULTS))
  def test_classify(self, model_file_type, model_name, text,
                    expected_classification_result):
    # Creates classifier.
    base_options = self._create_base_options(model_file_type, model_name)
    options = _TextClassifierOptions(base_options=base_options)
    classifier = _TextClassifier.create_from_options(options)
    try:
      # Performs text classification on the input.
      text_result = classifier.classify(text)
      # Comparing results.
      test_utils.assert_proto_equals(self, text_result.to_pb2(),
                                     expected_classification_result.to_pb2())
    finally:
      # Closes the classifier explicitly when the classifier is not used in
      # a context. Doing it in `finally` releases it even when the comparison
      # above fails.
      classifier.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
       _BERT_NEGATIVE_RESULTS),
      (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
       _BERT_NEGATIVE_RESULTS))
  def test_classify_in_context(self, model_file_type, model_name, text,
                               expected_classification_result):
    # Creates classifier.
    base_options = self._create_base_options(model_file_type, model_name)
    options = _TextClassifierOptions(base_options=base_options)
    with _TextClassifier.create_from_options(options) as classifier:
      # Performs text classification on the input.
      text_result = classifier.classify(text)
      # Comparing results.
      test_utils.assert_proto_equals(self, text_result.to_pb2(),
                                     expected_classification_result.to_pb2())
if __name__ == '__main__':
absltest.main()

View File

@ -0,0 +1,38 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Python library target for the text classifier task (text_classifier.py).
py_library(
name = "text_classifier",
srcs = [
"text_classifier.py",
],
deps = [
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/text/core:base_text_task_api",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,31 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Python library target for the shared base class of the text task APIs
# (base_text_task_api.py).
py_library(
name = "base_text_task_api",
srcs = [
"base_text_task_api.py",
],
deps = [
"//mediapipe/framework:calculator_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -0,0 +1,16 @@
"""Copyright 2022 The MediaPipe Authors.
All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

View File

@ -0,0 +1,55 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe text task base api."""
from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings import task_runner
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_TaskRunner = task_runner.TaskRunner
class BaseTextTaskApi(object):
  """The base class of the user-facing mediapipe text task api classes.

  Instances own a task runner created from a graph config and can be used as
  context managers: leaving the `with` block closes the runner.
  """

  def __init__(self,
               graph_config: calculator_pb2.CalculatorGraphConfig) -> None:
    """Initializes the `BaseTextTaskApi` object.

    Args:
      graph_config: The mediapipe text task graph config proto.
    """
    self._runner = _TaskRunner.create(graph_config)

  def close(self) -> None:
    """Shuts down the mediapipe text task instance.

    Raises:
      RuntimeError: If the mediapipe text task failed to close.
    """
    self._runner.close()

  @doc_controls.do_not_generate_docs
  def __enter__(self):
    """Returns `self` upon entering the runtime context."""
    return self

  @doc_controls.do_not_generate_docs
  def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
    """Shuts down the mediapipe text task instance on exit of the context manager.

    Raises:
      RuntimeError: If the mediapipe text task failed to close.
    """
    self.close()

View File

@ -0,0 +1,140 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe text classifier task."""
import dataclasses
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.text.core import base_text_task_api
# Public alias: the result type returned by `TextClassifier.classify`.
TextClassifierResult = classifications.ClassificationResult
# Private short aliases for types used in this module.
_BaseOptions = base_options_module.BaseOptions
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_TaskInfo = task_info_module.TaskInfo
# Tag and stream name of the graph output; joined as "TAG:name" when the task
# info is built.
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
# Tag and stream name of the graph input; joined the same way.
_TEXT_IN_STREAM_NAME = 'text_in'
_TEXT_TAG = 'TEXT'
# Task graph name passed to `_TaskInfo` as `task_graph`.
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
@dataclasses.dataclass
class TextClassifierOptions:
  """Options for the text classifier task.

  Attributes:
    base_options: Base options for the text classifier task.
    classifier_options: Options for the text classification task.
  """
  base_options: _BaseOptions
  # Use a factory so every options object gets its own `_ClassifierOptions`.
  # A single instance created at class-definition time would be shared by all
  # `TextClassifierOptions` objects, and dataclasses reject unhashable default
  # values outright on Python 3.11+ (presumably `_ClassifierOptions` is such a
  # type -- its definition is outside this file).
  classifier_options: _ClassifierOptions = dataclasses.field(
      default_factory=_ClassifierOptions)

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _TextClassifierGraphOptionsProto:
    """Generates a `TextClassifierGraphOptions` protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    classifier_options_proto = self.classifier_options.to_pb2()
    return _TextClassifierGraphOptionsProto(
        base_options=base_options_proto,
        classifier_options=classifier_options_proto)
class TextClassifier(base_text_task_api.BaseTextTaskApi):
  """Task that runs a classification model on input text."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'TextClassifier':
    """Creates a `TextClassifier` from a model file with default options.

    Args:
      model_path: Path to the model.

    Returns:
      A `TextClassifier` object created from the model file and the default
      `TextClassifierOptions`.

    Raises:
      ValueError: If failed to create `TextClassifier` object from the provided
        file such as invalid file path.
      RuntimeError: If other types of error occurred.
    """
    return cls.create_from_options(
        TextClassifierOptions(
            base_options=_BaseOptions(model_asset_path=model_path)))

  @classmethod
  def create_from_options(cls,
                          options: TextClassifierOptions) -> 'TextClassifier':
    """Creates a `TextClassifier` from the given text classifier options.

    Args:
      options: Options for the text classifier task.

    Returns:
      A `TextClassifier` object created from `options`.

    Raises:
      ValueError: If failed to create `TextClassifier` object from
        `TextClassifierOptions` such as missing the model.
      RuntimeError: If other types of error occurred.
    """
    # Streams are declared to the graph as "TAG:stream_name".
    text_input = f'{_TEXT_TAG}:{_TEXT_IN_STREAM_NAME}'
    result_output = (f'{_CLASSIFICATION_RESULT_TAG}:'
                     f'{_CLASSIFICATION_RESULT_OUT_STREAM_NAME}')
    info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[text_input],
        output_streams=[result_output],
        task_options=options)
    graph_config = info.generate_graph_config()
    return cls(graph_config)

  def classify(self, text: str) -> TextClassifierResult:
    """Classifies the input `text`.

    Args:
      text: The input text.

    Returns:
      A `TextClassifierResult` object that contains a list of text
      classifications.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If text classification failed to run.
    """
    input_packets = {_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)}
    output_packets = self._runner.process(input_packets)

    # Copy the proto carried by the output packet into a fresh message.
    result_proto = classifications_pb2.ClassificationResult()
    result_packet = output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]
    result_proto.CopyFrom(packet_getter.get_proto(result_packet))

    converted = []
    for classification_proto in result_proto.classifications:
      converted.append(
          classifications.Classifications.create_from_pb2(classification_proto))
    return TextClassifierResult(converted)

View File

@ -0,0 +1,33 @@
# This contains the MediaPipe Audio Classifier Task.
#
# This task takes audio data and outputs the classification result.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TypeScript library target that bundles the audio classifier implementation
# with its options and result definitions.
mediapipe_ts_library(
name = "audio_classifier",
srcs = [
"audio_classifier.ts",
"audio_classifier_options.ts",
"audio_classifier_result.ts",
],
deps = [
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
],
)

View File

@ -0,0 +1,215 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url
import {AudioClassifierOptions} from './audio_classifier_options';
import {Classifications} from './audio_classifier_result';
// Name of the graph that the classifier node is configured to run (passed to
// `setCalculator()` when the graph config is built).
const MEDIAPIPE_GRAPH =
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
// Name of the graph input stream that carries the audio samples.
// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
// cannot be changed
// TODO: Change this to `audio_in` to match the name in the CC
// implementation
const AUDIO_STREAM = 'input_audio';
// Name of the graph input stream that carries the sample rate of each audio
// packet.
const SAMPLE_RATE_STREAM = 'sample_rate';
// Name of the graph output stream on which classification results are
// delivered.
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/** Performs audio classification. */
export class AudioClassifier extends TaskRunner {
  /** Results collected by the output-stream listener for the current call. */
  private classifications: Classifications[] = [];
  /** Sample rate applied when `classify()` is called without one. */
  private defaultSampleRate = 48000;
  /** Graph options that `refreshGraph()` embeds into the classifier node. */
  private readonly options = new AudioClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new audio classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param audioClassifierOptions The options for the audio classifier. Note
   *     that either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   * @return A promise for the classifier, resolved once the options have been
   *     applied.
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      audioClassifierOptions: AudioClassifierOptions):
      Promise<AudioClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file loaded with this mechanism is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        AudioClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(audioClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new audio classifier based on
   * the provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
    return AudioClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new audio classifier based on
   * the path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   * @throws {Error} If the model asset cannot be fetched successfully (the
   *     returned promise rejects).
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<AudioClassifier> {
    const response = await fetch(modelAssetPath.toString());
    // `fetch()` resolves for HTTP error statuses as well. Without this check
    // the body of e.g. a 404 page would be handed on as if it were the model.
    if (!response.ok) {
      throw new Error(
          `Failed to fetch model: ${modelAssetPath} (${response.status})`);
    }
    const graphData = await response.arrayBuffer();
    return AudioClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the audio classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the audio classifier.
   */
  async setOptions(options: AudioClassifierOptions): Promise<void> {
    // The base options are only replaced when the caller supplies them.
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    // The current classifier options are passed in as the starting point for
    // the conversion.
    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }

  /**
   * Sets the sample rate for all calls to `classify()` that omit an explicit
   * sample rate. `48000` is used as a default if this method is not called.
   *
   * @param sampleRate A sample rate (e.g. `44100`).
   */
  setDefaultSampleRate(sampleRate: number) {
    this.defaultSampleRate = sampleRate;
  }

  /**
   * Performs audio classification on the provided audio data and waits
   * synchronously for the response.
   *
   * @param audioData An array of raw audio capture data, like
   *     from a call to getChannelData on an AudioBuffer.
   * @param sampleRate The sample rate in Hz of the provided audio data. If not
   *     set, defaults to the sample rate set via `setDefaultSampleRate()` or
   *     `48000` if no custom default was set.
   * @return The classification result of the audio data.
   */
  classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
    sampleRate = sampleRate ?? this.defaultSampleRate;

    // Configures the number of samples in the WASM layer. We re-configure the
    // number of samples and the sample rate for every frame, but ignore other
    // side effects of this function (such as sending the input side packet and
    // the input stream header).
    this.configureAudio(
        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);

    // The sample rate and the audio samples share one timestamp.
    const timestamp = performance.now();
    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
    this.addAudioToStream(audioData, timestamp);

    // Start from an empty list; the listener attached in `refreshGraph()`
    // appends to it.
    this.classifications = [];
    this.finishProcessing();
    // Return a copy so callers do not share the internal array.
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   */
  private addJsAudioClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(AUDIO_STREAM);
    graphConfig.addInputStream(SAMPLE_RATE_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        AudioClassifierGraphOptions.ext, this.options);

    // Perform audio classification. Pre-processing and results post-processing
    // are built-in.
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(MEDIAPIPE_GRAPH);
    classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
    classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    // Every packet on the result stream is decoded and appended to
    // `this.classifications`.
    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsAudioClassification(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}

View File

@ -0,0 +1,17 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options';

View File

@ -0,0 +1,18 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';

View File

@ -29,31 +29,31 @@ export function convertClassifierOptionsToProto(
baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto { baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
const classifierOptions = const classifierOptions =
baseOptions ? baseOptions.clone() : new ClassifierOptionsProto(); baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
if (options.displayNamesLocale) { if (options.displayNamesLocale !== undefined) {
classifierOptions.setDisplayNamesLocale(options.displayNamesLocale); classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
} else if (options.displayNamesLocale === undefined) { } else if (options.displayNamesLocale === undefined) {
classifierOptions.clearDisplayNamesLocale(); classifierOptions.clearDisplayNamesLocale();
} }
if (options.maxResults) { if (options.maxResults !== undefined) {
classifierOptions.setMaxResults(options.maxResults); classifierOptions.setMaxResults(options.maxResults);
} else if ('maxResults' in options) { // Check for undefined } else if ('maxResults' in options) { // Check for undefined
classifierOptions.clearMaxResults(); classifierOptions.clearMaxResults();
} }
if (options.scoreThreshold) { if (options.scoreThreshold !== undefined) {
classifierOptions.setScoreThreshold(options.scoreThreshold); classifierOptions.setScoreThreshold(options.scoreThreshold);
} else if ('scoreThreshold' in options) { // Check for undefined } else if ('scoreThreshold' in options) { // Check for undefined
classifierOptions.clearScoreThreshold(); classifierOptions.clearScoreThreshold();
} }
if (options.categoryAllowlist) { if (options.categoryAllowlist !== undefined) {
classifierOptions.setCategoryAllowlistList(options.categoryAllowlist); classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
} else if ('categoryAllowlist' in options) { // Check for undefined } else if ('categoryAllowlist' in options) { // Check for undefined
classifierOptions.clearCategoryAllowlistList(); classifierOptions.clearCategoryAllowlistList();
} }
if (options.categoryDenylist) { if (options.categoryDenylist !== undefined) {
classifierOptions.setCategoryDenylistList(options.categoryDenylist); classifierOptions.setCategoryDenylistList(options.categoryDenylist);
} else if ('categoryDenylist' in options) { // Check for undefined } else if ('categoryDenylist' in options) { // Check for undefined
classifierOptions.clearCategoryDenylistList(); classifierOptions.clearCategoryDenylistList();

View File

@ -12,6 +12,18 @@ mediapipe_ts_library(
], ],
) )
# TypeScript library containing the shared TaskRunner source (task_runner.ts).
mediapipe_ts_library(
name = "task_runner",
srcs = [
"task_runner.ts",
],
deps = [
# Graph-runner targets from //mediapipe/web/graph_runner.
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
"//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts",
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
],
)
mediapipe_ts_library( mediapipe_ts_library(
name = "classifier_options", name = "classifier_options",
srcs = [ srcs = [

View File

@ -0,0 +1,83 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib';
import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib';
// Base class for TaskRunner: WasmMediaPipeLib wrapped first by SupportImage
// and then by SupportModelResourcesGraphService.
// tslint:disable-next-line:enforce-name-casing
const WasmMediaPipeImageLib =
SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib));
/** Base class for all MediaPipe Tasks. */
export abstract class TaskRunner extends WasmMediaPipeImageLib {
  /** Errors reported by the error listener that have not been thrown yet. */
  private processingErrors: Error[] = [];

  constructor(wasmModule: WasmModule) {
    super(wasmModule);

    // Disables the automatic render-to-screen code, which allows for pure
    // CPU processing.
    this.setAutoRenderToScreen(false);

    // Enables use of our model resource caching graph service.
    this.registerModelResourcesGraphService();
  }

  /**
   * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
   * over the video stream. Will replace the previously running MediaPipe graph,
   * if there is one.
   * @param graphData The raw MediaPipe graph data, either in binary
   *     protobuffer format (.binarypb), or else in raw text format (.pbtxt or
   *     .textproto).
   * @param isBinary This should be set to true if the graph is in
   *     binary format, and false if it is in human-readable text format.
   * @throws {Error} If the error listener recorded one or more errors while
   *     the graph was being set.
   */
  override setGraph(graphData: Uint8Array, isBinary: boolean): void {
    // Record reported errors so they can be surfaced as exceptions below.
    this.attachErrorListener((code, message) => {
      this.processingErrors.push(new Error(message));
    });
    super.setGraph(graphData, isBinary);
    this.handleErrors();
  }

  /**
   * Forces all queued-up packets to be pushed through the MediaPipe graph as
   * far as possible, performing all processing until no more processing can be
   * done.
   * @throws {Error} If the error listener recorded one or more errors during
   *     processing.
   */
  override finishProcessing(): void {
    super.finishProcessing();
    this.handleErrors();
  }

  /**
   * Throws the error from the error listener if an error was raised.
   *
   * The pending error list is always emptied, including when this method
   * throws, so an error is reported once and does not make later calls to
   * `setGraph()` / `finishProcessing()` fail again.
   */
  private handleErrors() {
    // Take the pending errors and reset the queue *before* throwing; resetting
    // after the `throw` statements would never be reached on failure.
    const pendingErrors = this.processingErrors;
    this.processingErrors = [];

    const errorCount = pendingErrors.length;
    if (errorCount === 1) {
      // Re-throw error to get a more meaningful stacktrace
      throw new Error(pendingErrors[0].message);
    } else if (errorCount > 1) {
      throw new Error(
          'Encountered multiple errors: ' +
          pendingErrors.map(e => e.message).join(', '));
    }
  }
}

View File

@ -0,0 +1,34 @@
# This contains the MediaPipe Text Classifier Task.
#
# This task takes text input and performs Natural Language classification
# (including BERT-based text classification).
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
# Targets in this package default to the "//mediapipe/tasks:internal"
# visibility label.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TypeScript library for the text classifier task: the task implementation
# plus declaration files for its options and result types.
mediapipe_ts_library(
name = "text_classifier",
srcs = [
"text_classifier.ts",
"text_classifier_options.d.ts",
"text_classifier_result.d.ts",
],
deps = [
# JSPB proto targets (graph config, calculator options, classification
# results and task options).
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
# Web-side containers, option/result processors and the shared task runner.
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
],
)

View File

@ -0,0 +1,180 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url
import {TextClassifierOptions} from './text_classifier_options';
import {Classifications} from './text_classifier_result';
// Name of the graph input stream that receives the text to classify.
const INPUT_STREAM = 'text_in';
// Name of the graph output stream on which classification results are
// delivered.
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
// Name of the graph that the classifier node is configured to run (passed to
// `setCalculator()` when the graph config is built).
const TEXT_CLASSIFIER_GRAPH =
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner {
  /** Results collected by the output-stream listener for the current call. */
  private classifications: Classifications[] = [];
  /** Graph options that `refreshGraph()` embeds into the classifier node. */
  private readonly options = new TextClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new text classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param textClassifierOptions The options for the text classifier. Note that
   *     either a path to the TFLite model or the model itself needs to be
   *     provided (via `baseOptions`).
   * @return A promise for the classifier, resolved once the options have been
   *     applied.
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        TextClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(textClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new text classifier based on the
   * provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
    return TextClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new text classifier based on the
   * path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   * @throws {Error} If the model asset cannot be fetched successfully (the
   *     returned promise rejects).
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<TextClassifier> {
    const response = await fetch(modelAssetPath.toString());
    // `fetch()` resolves for HTTP error statuses as well. Without this check
    // the body of e.g. a 404 page would be handed on as if it were the model.
    if (!response.ok) {
      throw new Error(
          `Failed to fetch model: ${modelAssetPath} (${response.status})`);
    }
    const graphData = await response.arrayBuffer();
    return TextClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the text classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the text classifier.
   */
  async setOptions(options: TextClassifierOptions): Promise<void> {
    // The base options are only replaced when the caller supplies them.
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    // The current classifier options are passed in as the starting point for
    // the conversion.
    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }

  /**
   * Performs Natural Language classification on the provided text and waits
   * synchronously for the response.
   *
   * @param text The text to process.
   * @return The classification result of the text
   */
  classify(text: string): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
    this.addStringToStream(
        text, INPUT_STREAM, /* timestamp= */ performance.now());
    this.finishProcessing();
    // Return a copy so callers do not share the internal array.
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   */
  private addJsTextClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        TextClassifierGraphOptions.ext, this.options);

    // A single node runs the text classifier graph on the input stream.
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    // Every packet on the result stream is decoded and appended to
    // `this.classifications`.
    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsTextClassification(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}

View File

@ -0,0 +1,17 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options';

View File

@ -0,0 +1,18 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';

View File

@ -0,0 +1,40 @@
# This contains the MediaPipe Gesture Recognizer Task.
#
# This task takes video frames and outputs synchronized frames along with
# the detection results for one or more gesture categories, using Gesture Recognizer.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
# Targets in this package default to the "//mediapipe/tasks:internal"
# visibility label.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TypeScript library for the gesture recognizer task: the task implementation
# plus declaration files for its options and result types.
mediapipe_ts_library(
name = "gesture_recognizer",
srcs = [
"gesture_recognizer.ts",
"gesture_recognizer_options.d.ts",
"gesture_recognizer_result.d.ts",
],
deps = [
# JSPB proto targets for the graph config and the framework data formats.
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/framework/formats:rect_jspb_proto",
# JSPB proto targets for the gesture recognizer, hand detector and hand
# landmarker graph options.
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto",
# Web-side containers, option processors and the shared task runner.
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
],
)

View File

@ -0,0 +1,374 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationList} from '../../../../framework/formats/classification_pb';
import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {NormalizedRect} from '../../../../framework/formats/rect_pb';
import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb';
import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb';
import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb';
import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb';
import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb';
import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark} from '../../../../tasks/web/components/containers/landmark';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url
import {GestureRecognizerOptions} from './gesture_recognizer_options';
import {GestureRecognitionResult} from './gesture_recognizer_result';
export {ImageSource};
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
// Stream names. NOTE(review): their uses are further down the file; going by
// the names, the first two are graph inputs and the remaining four are graph
// outputs -- confirm against the graph config.
const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const HAND_GESTURES_STREAM = 'hand_gestures';
const LANDMARKS_STREAM = 'hand_landmarks';
const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks';
const HANDEDNESS_STREAM = 'handedness';
// Name of the gesture recognizer graph.
const GESTURE_RECOGNIZER_GRAPH =
'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
// Default values. NOTE(review): presumably applied when the corresponding
// option is not set -- confirm where they are read.
const DEFAULT_NUM_HANDS = 1;
const DEFAULT_SCORE_THRESHOLD = 0.5;
const DEFAULT_CATEGORY_INDEX = -1;
// A rect centered at (0.5, 0.5) with width and height 1. Going by its name it
// is meant to cover the whole image.
const FULL_IMAGE_RECT = new NormalizedRect();
FULL_IMAGE_RECT.setXCenter(0.5);
FULL_IMAGE_RECT.setYCenter(0.5);
FULL_IMAGE_RECT.setWidth(1);
FULL_IMAGE_RECT.setHeight(1);
/** Performs hand gesture recognition on images. */
export class GestureRecognizer extends TaskRunner {
private gestures: Category[][] = [];
private landmarks: Landmark[][] = [];
private worldLandmarks: Landmark[][] = [];
private handednesses: Category[][] = [];
private readonly options: GestureRecognizerGraphOptions;
private readonly handLandmarkerGraphOptions: HandLandmarkerGraphOptions;
private readonly handLandmarksDetectorGraphOptions:
HandLandmarksDetectorGraphOptions;
private readonly handDetectorGraphOptions: HandDetectorGraphOptions;
private readonly handGestureRecognizerGraphOptions:
HandGestureRecognizerGraphOptions;
/**
* Initializes the Wasm runtime and creates a new gesture recognizer from the
* provided options.
* @param wasmLoaderOptions A configuration object that provides the location
* of the Wasm binary and its loader.
* @param gestureRecognizerOptions The options for the gesture recognizer.
* Note that either a path to the model asset or a model buffer needs to
* be provided (via `baseOptions`).
*/
static async createFromOptions(
    wasmLoaderOptions: WasmLoaderOptions,
    gestureRecognizerOptions: GestureRecognizerOptions):
    Promise<GestureRecognizer> {
  // The Wasm binary is the only file resolved through the locator, so it
  // answers every lookup with the configured binary path.
  const wasmLocator: FileLocator = {
    locateFile: () => wasmLoaderOptions.wasmBinaryPath.toString(),
  };

  // Instantiate the recognizer on top of the Wasm runtime (no asset loader
  // script and no canvas are supplied).
  const instance = await createMediaPipeLib(
      GestureRecognizer, wasmLoaderOptions.wasmLoaderPath,
      /* assetLoaderScript= */ undefined,
      /* glCanvas= */ undefined, wasmLocator);

  // Apply the caller's options before handing the instance out.
  await instance.setOptions(gestureRecognizerOptions);
  return instance;
}
/**
* Initializes the Wasm runtime and creates a new gesture recognizer based on
* the provided model asset buffer.
* @param wasmLoaderOptions A configuration object that provides the location
* of the Wasm binary and its loader.
* @param modelAssetBuffer A binary representation of the model.
*/
static createFromModelBuffer(
wasmLoaderOptions: WasmLoaderOptions,
modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
return GestureRecognizer.createFromOptions(
wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
}
/**
* Initializes the Wasm runtime and creates a new gesture recognizer based on
* the path to the model asset.
* @param wasmLoaderOptions A configuration object that provides the location
* of the Wasm binary and its loader.
* @param modelAssetPath The path to the model asset.
*/
static async createFromModelPath(
wasmLoaderOptions: WasmLoaderOptions,
modelAssetPath: string): Promise<GestureRecognizer> {
const response = await fetch(modelAssetPath.toString());
const graphData = await response.arrayBuffer();
return GestureRecognizer.createFromModelBuffer(
wasmLoaderOptions, new Uint8Array(graphData));
}
constructor(wasmModule: WasmModule) {
super(wasmModule);
this.options = new GestureRecognizerGraphOptions();
this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions();
this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions);
this.handLandmarksDetectorGraphOptions =
new HandLandmarksDetectorGraphOptions();
this.handLandmarkerGraphOptions.setHandLandmarksDetectorGraphOptions(
this.handLandmarksDetectorGraphOptions);
this.handDetectorGraphOptions = new HandDetectorGraphOptions();
this.handLandmarkerGraphOptions.setHandDetectorGraphOptions(
this.handDetectorGraphOptions);
this.handGestureRecognizerGraphOptions =
new HandGestureRecognizerGraphOptions();
this.options.setHandGestureRecognizerGraphOptions(
this.handGestureRecognizerGraphOptions);
this.initDefaults();
// Disables the automatic render-to-screen code, which allows for pure
// CPU processing.
this.setAutoRenderToScreen(false);
}
/**
* Sets new options for the gesture recognizer.
*
* Calling `setOptions()` with a subset of options only affects those options.
* You can reset an option back to its default value by explicitly setting it
* to `undefined`.
*
* @param options The options for the gesture recognizer.
*/
async setOptions(options: GestureRecognizerOptions): Promise<void> {
if (options.baseOptions) {
const baseOptionsProto =
await convertBaseOptionsToProto(options.baseOptions);
this.options.setBaseOptions(baseOptionsProto);
}
if ('numHands' in options) {
this.handDetectorGraphOptions.setNumHands(
options.numHands ?? DEFAULT_NUM_HANDS);
}
if ('minHandDetectionConfidence' in options) {
this.handDetectorGraphOptions.setMinDetectionConfidence(
options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if ('minHandPresenceConfidence' in options) {
this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if ('minTrackingConfidence' in options) {
this.handLandmarkerGraphOptions.setMinTrackingConfidence(
options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if (options.cannedGesturesClassifierOptions) {
// Note that we have to support both JSPB and ProtobufJS and cannot
// use JSPB's getMutableX() APIs.
const graphOptions = new GestureClassifierGraphOptions();
graphOptions.setClassifierOptions(convertClassifierOptionsToProto(
options.cannedGesturesClassifierOptions,
this.handGestureRecognizerGraphOptions
.getCannedGestureClassifierGraphOptions()
?.getClassifierOptions()));
this.handGestureRecognizerGraphOptions
.setCannedGestureClassifierGraphOptions(graphOptions);
} else if (options.cannedGesturesClassifierOptions === undefined) {
this.handGestureRecognizerGraphOptions
.getCannedGestureClassifierGraphOptions()
?.clearClassifierOptions();
}
if (options.customGesturesClassifierOptions) {
const graphOptions = new GestureClassifierGraphOptions();
graphOptions.setClassifierOptions(convertClassifierOptionsToProto(
options.customGesturesClassifierOptions,
this.handGestureRecognizerGraphOptions
.getCustomGestureClassifierGraphOptions()
?.getClassifierOptions()));
this.handGestureRecognizerGraphOptions
.setCustomGestureClassifierGraphOptions(graphOptions);
} else if (options.customGesturesClassifierOptions === undefined) {
this.handGestureRecognizerGraphOptions
.getCustomGestureClassifierGraphOptions()
?.clearClassifierOptions();
}
this.refreshGraph();
}
/**
* Performs gesture recognition on the provided single image and waits
* synchronously for the response.
* @param imageSource An image source to process.
* @param timestamp The timestamp of the current frame, in ms. If not
* provided, defaults to `performance.now()`.
* @return The detected gestures.
*/
recognize(imageSource: ImageSource, timestamp: number = performance.now()):
GestureRecognitionResult {
this.gestures = [];
this.landmarks = [];
this.worldLandmarks = [];
this.handednesses = [];
this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
this.addProtoToStream(
FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
NORM_RECT_STREAM, timestamp);
this.finishProcessing();
return {
gestures: this.gestures,
landmarks: this.landmarks,
worldLandmarks: this.worldLandmarks,
handednesses: this.handednesses
};
}
/** Sets the default values for the graph. */
private initDefaults(): void {
this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS);
this.handDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
this.handLandmarkerGraphOptions.setMinTrackingConfidence(
DEFAULT_SCORE_THRESHOLD);
}
/** Converts the proto data to a Category[][] structure. */
private toJsCategories(data: Uint8Array[]): Category[][] {
const result: Category[][] = [];
for (const binaryProto of data) {
const inputList = ClassificationList.deserializeBinary(binaryProto);
const outputList: Category[] = [];
for (const classification of inputList.getClassificationList()) {
outputList.push({
score: classification.getScore() ?? 0,
index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
});
}
result.push(outputList);
}
return result;
}
/** Converts raw data into a landmark, and adds it to our landmarks list. */
private addJsLandmarks(data: Uint8Array[]): void {
for (const binaryProto of data) {
const handLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto);
const landmarks: Landmark[] = [];
for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) {
landmarks.push({
x: handLandmarkProto.getX() ?? 0,
y: handLandmarkProto.getY() ?? 0,
z: handLandmarkProto.getZ() ?? 0,
normalized: true
});
}
this.landmarks.push(landmarks);
}
}
/**
* Converts raw data into a landmark, and adds it to our worldLandmarks
* list.
*/
private adddJsWorldLandmarks(data: Uint8Array[]): void {
for (const binaryProto of data) {
const handWorldLandmarksProto =
LandmarkList.deserializeBinary(binaryProto);
const worldLandmarks: Landmark[] = [];
for (const handWorldLandmarkProto of
handWorldLandmarksProto.getLandmarkList()) {
worldLandmarks.push({
x: handWorldLandmarkProto.getX() ?? 0,
y: handWorldLandmarkProto.getY() ?? 0,
z: handWorldLandmarkProto.getZ() ?? 0,
normalized: false
});
}
this.worldLandmarks.push(worldLandmarks);
}
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM);
graphConfig.addOutputStream(HAND_GESTURES_STREAM);
graphConfig.addOutputStream(LANDMARKS_STREAM);
graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM);
graphConfig.addOutputStream(HANDEDNESS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
GestureRecognizerGraphOptions.ext, this.options);
const recognizerNode = new CalculatorGraphConfig.Node();
recognizerNode.setCalculator(GESTURE_RECOGNIZER_GRAPH);
recognizerNode.addInputStream('IMAGE:' + IMAGE_STREAM);
recognizerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
recognizerNode.addOutputStream('HAND_GESTURES:' + HAND_GESTURES_STREAM);
recognizerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM);
recognizerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM);
recognizerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM);
recognizerNode.setOptions(calculatorOptions);
graphConfig.addNode(recognizerNode);
this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
this.addJsLandmarks(binaryProto);
});
this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
this.adddJsWorldLandmarks(binaryProto);
});
this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => {
this.gestures.push(...this.toJsCategories(binaryProto));
});
this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
this.handednesses.push(...this.toJsCategories(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}
}

View File

@ -0,0 +1,65 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions} from '../../../../tasks/web/core/base_options';
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
/** Options to configure the MediaPipe Gesture Recognizer Task */
export interface GestureRecognizerOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
/**
* The maximum number of hands that can be detected by the GestureRecognizer.
* Defaults to 1. Explicitly passing `undefined` restores the default.
*/
numHands?: number|undefined;
/**
* The minimum confidence score for the hand detection to be considered
* successful. Defaults to 0.5. Explicitly passing `undefined` restores the
* default.
*/
minHandDetectionConfidence?: number|undefined;
/**
* The minimum confidence score of hand presence score in the hand landmark
* detection. Defaults to 0.5. Explicitly passing `undefined` restores the
* default.
*/
minHandPresenceConfidence?: number|undefined;
/**
* The minimum confidence score for the hand tracking to be considered
* successful. Defaults to 0.5. Explicitly passing `undefined` restores the
* default.
*/
minTrackingConfidence?: number|undefined;
/**
* Sets the optional `ClassifierOptions` controlling the canned gestures
* classifier, such as score threshold, allow list and deny list of gestures.
* The categories for canned gesture
* classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up",
* "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"]
* Explicitly passing `undefined` clears previously set classifier options.
*/
// TODO: Note this option is subject to change
cannedGesturesClassifierOptions?: ClassifierOptions|undefined;
/**
* Options for configuring the custom gestures classifier, such as score
* threshold, allow list and deny list of gestures. Explicitly passing
* `undefined` clears previously set classifier options.
*/
// TODO(b/251816640): Note this option is subject to change.
customGesturesClassifierOptions?: ClassifierOptions|undefined;
}

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark} from '../../../../tasks/web/components/containers/landmark';
/**
* Represents the gesture recognition results generated by `GestureRecognizer`.
*/
export interface GestureRecognitionResult {
/** Hand landmarks of detected hands. Entries have `normalized: true`. */
landmarks: Landmark[][];
/**
* Hand landmarks in world coordinates of detected hands. Entries have
* `normalized: false`.
*/
worldLandmarks: Landmark[][];
/** Handedness of detected hands. */
handednesses: Category[][];
/** Recognized hand gestures of detected hands. */
gestures: Category[][];
}

View File

@ -0,0 +1,33 @@
# This contains the MediaPipe Image Classifier Task.
#
# This task takes video or image frames and outputs the classification result.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TypeScript library for the web ImageClassifier task: the task implementation
# plus its options and result type modules.
mediapipe_ts_library(
name = "image_classifier",
srcs = [
"image_classifier.ts",
"image_classifier_options.ts",
"image_classifier_result.ts",
],
# Generated JSPB protos first, then the shared web task components and the
# Wasm graph runner.
deps = [
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
],
)

View File

@ -0,0 +1,186 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url
import {ImageClassifierOptions} from './image_classifier_options';
import {Classifications} from './image_classifier_result';
const IMAGE_CLASSIFIER_GRAPH =
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
const INPUT_STREAM = 'input_image';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
export {ImageSource}; // Used in the public API
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Performs classification on images.
 *
 * Instances are created through the static `createFrom*()` factories, which
 * load the Wasm runtime, apply the provided options and (via `setOptions()`)
 * build the MediaPipe graph before the classifier is returned.
 */
export class ImageClassifier extends TaskRunner {
  // Result buffer; reset at the start of every `classify()` call and filled by
  // the stream listener registered in `refreshGraph()`.
  private classifications: Classifications[] = [];
  private readonly options = new ImageClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new image classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param imageClassifierOptions The options for the image classifier. Note
   *     that either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   * @return A promise that resolves to the configured classifier.
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      imageClassifierOptions: ImageClassifierOptions):
      Promise<ImageClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        ImageClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(imageClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new image classifier based on
   * the provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   * @return A promise that resolves to the configured classifier.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
    return ImageClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new image classifier based on
   * the path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   * @return A promise that resolves to the configured classifier.
   * @throws {Error} If the model asset cannot be downloaded successfully.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<ImageClassifier> {
    const response = await fetch(modelAssetPath.toString());
    // `fetch()` only rejects on network failures. Without this check the body
    // of an HTTP error page would be handed to the graph as model data.
    if (!response.ok) {
      throw new Error(
          `Failed to fetch model: ${modelAssetPath} (${response.status})`);
    }
    const graphData = await response.arrayBuffer();
    return ImageClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the image classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the image classifier.
   * @return A promise that resolves once the graph has been rebuilt.
   */
  async setOptions(options: ImageClassifierOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    // The current classifier options are passed along so the converter can
    // take the previously configured values into account.
    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }

  /**
   * Performs image classification on the provided image and waits synchronously
   * for the response.
   *
   * @param imageSource An image source to process.
   * @param timestamp The timestamp of the current frame, in ms. If not
   *     provided, defaults to `performance.now()`.
   * @return The classification result of the image
   */
  classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
    this.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
    this.finishProcessing();
    // Return a copy so callers cannot mutate the internal buffer.
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   * @param binaryProto A serialized `ClassificationResult` proto.
   */
  private addJsImageClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        ImageClassifierGraphOptions.ext, this.options);

    // Perform image classification. Pre-processing and results post-processing
    // are built-in.
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsImageClassification(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}

View File

@ -0,0 +1,17 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options';

View File

@ -0,0 +1,18 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';

View File

@ -0,0 +1,30 @@
# This contains the MediaPipe Object Detector Task.
#
# This task takes video frames and outputs synchronized frames along with
# the detection results for one or more object categories, using Object Detector.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TypeScript library for the web ObjectDetector task: the task implementation
# plus the declaration files for its options and result types.
mediapipe_ts_library(
name = "object_detector",
srcs = [
"object_detector.ts",
"object_detector_options.d.ts",
"object_detector_result.d.ts",
],
# Generated JSPB protos first, then the shared web task components and the
# Wasm graph runner.
deps = [
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/framework/formats:detection_jspb_proto",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
],
)

View File

@ -0,0 +1,233 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
// Placeholder for internal dependency on trusted resource url
import {ObjectDetectorOptions} from './object_detector_options';
import {Detection} from './object_detector_result';
const INPUT_STREAM = 'input_frame_gpu';
const DETECTIONS_STREAM = 'detections';
const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph';
const DEFAULT_CATEGORY_INDEX = -1;
export {ImageSource}; // Used in the public API
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Performs object detection on images.
 *
 * Instances are created through the static `createFrom*()` factories, which
 * load the Wasm runtime, apply the provided options and (via `setOptions()`)
 * build the MediaPipe graph before the detector is returned.
 */
export class ObjectDetector extends TaskRunner {
  // Result buffer; reset at the start of every `detect()` call and filled by
  // the stream listener registered in `refreshGraph()`.
  private detections: Detection[] = [];
  private readonly options = new ObjectDetectorOptionsProto();

  /**
   * Initializes the Wasm runtime and creates a new object detector from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param objectDetectorOptions The options for the Object Detector. Note that
   *     either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   * @return A promise that resolves to the configured detector.
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const detector = await createMediaPipeLib(
        ObjectDetector, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await detector.setOptions(objectDetectorOptions);
    return detector;
  }

  /**
   * Initializes the Wasm runtime and creates a new object detector based on the
   * provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   * @return A promise that resolves to the configured detector.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
    return ObjectDetector.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new object detector based on the
   * path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   * @return A promise that resolves to the configured detector.
   * @throws {Error} If the model asset cannot be downloaded successfully.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<ObjectDetector> {
    const response = await fetch(modelAssetPath.toString());
    // `fetch()` only rejects on network failures. Without this check the body
    // of an HTTP error page would be handed to the graph as model data.
    if (!response.ok) {
      throw new Error(
          `Failed to fetch model: ${modelAssetPath} (${response.status})`);
    }
    const graphData = await response.arrayBuffer();
    return ObjectDetector.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the object detector.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the object detector.
   * @return A promise that resolves once the graph has been rebuilt.
   */
  async setOptions(options: ObjectDetectorOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    // Note that we have to support both JSPB and ProtobufJS, hence we
    // have to explicitly clear the values instead of setting them to
    // `undefined`.
    if (options.displayNamesLocale !== undefined) {
      this.options.setDisplayNamesLocale(options.displayNamesLocale);
    } else if ('displayNamesLocale' in options) {  // Check for undefined
      this.options.clearDisplayNamesLocale();
    }

    if (options.maxResults !== undefined) {
      this.options.setMaxResults(options.maxResults);
    } else if ('maxResults' in options) {  // Check for undefined
      this.options.clearMaxResults();
    }

    if (options.scoreThreshold !== undefined) {
      this.options.setScoreThreshold(options.scoreThreshold);
    } else if ('scoreThreshold' in options) {  // Check for undefined
      this.options.clearScoreThreshold();
    }

    if (options.categoryAllowlist !== undefined) {
      this.options.setCategoryAllowlistList(options.categoryAllowlist);
    } else if ('categoryAllowlist' in options) {  // Check for undefined
      this.options.clearCategoryAllowlistList();
    }

    if (options.categoryDenylist !== undefined) {
      this.options.setCategoryDenylistList(options.categoryDenylist);
    } else if ('categoryDenylist' in options) {  // Check for undefined
      this.options.clearCategoryDenylistList();
    }

    this.refreshGraph();
  }

  /**
   * Performs object detection on the provided single image and waits
   * synchronously for the response.
   * @param imageSource An image source to process.
   * @param timestamp The timestamp of the current frame, in ms. If not
   *     provided, defaults to `performance.now()`.
   * @return The list of detected objects
   */
  detect(imageSource: ImageSource, timestamp?: number): Detection[] {
    // Get detections by running our MediaPipe graph.
    this.detections = [];
    this.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
    this.finishProcessing();
    // Return a copy so callers cannot mutate the internal buffer.
    return [...this.detections];
  }

  /**
   * Converts raw data into a Detection, and adds it to our detection list.
   * @param data Serialized `Detection` protos.
   */
  private addJsObjectDetections(data: Uint8Array[]): void {
    for (const binaryProto of data) {
      const detectionProto = DetectionProto.deserializeBinary(binaryProto);
      const scores = detectionProto.getScoreList();
      const indexes = detectionProto.getLabelIdList();
      const labels = detectionProto.getLabelList();
      const displayNames = detectionProto.getDisplayNameList();

      const detection: Detection = {categories: []};
      // The score list drives the category count; the other lists are read at
      // the same index and fall back to defaults where they have no entry.
      for (let i = 0; i < scores.length; i++) {
        detection.categories.push({
          score: scores[i],
          index: indexes[i] ?? DEFAULT_CATEGORY_INDEX,
          categoryName: labels[i] ?? '',
          displayName: displayNames[i] ?? '',
        });
      }

      // The bounding box is only set when the proto carries location data.
      const boundingBox = detectionProto.getLocationData()?.getBoundingBox();
      if (boundingBox) {
        detection.boundingBox = {
          originX: boundingBox.getXmin() ?? 0,
          originY: boundingBox.getYmin() ?? 0,
          width: boundingBox.getWidth() ?? 0,
          height: boundingBox.getHeight() ?? 0
        };
      }

      this.detections.push(detection);
    }
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(DETECTIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        ObjectDetectorOptionsProto.ext, this.options);

    // The graph consists of a single node running the detector subgraph.
    const detectorNode = new CalculatorGraphConfig.Node();
    detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH);
    detectorNode.addInputStream('IMAGE:' + INPUT_STREAM);
    detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
    detectorNode.setOptions(calculatorOptions);

    graphConfig.addNode(detectorNode);

    this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
      this.addJsObjectDetections(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}

View File

@ -0,0 +1,52 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions} from '../../../../tasks/web/core/base_options';
/**
* Options to configure the MediaPipe Object Detector Task.
*
* Every field is optional. Fields typed `T|undefined` may also be passed
* explicitly as `undefined`.
*/
export interface ObjectDetectorOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
/**
* The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English.
*/
displayNamesLocale?: string|undefined;
/** The maximum number of top-scored detection results to return. */
maxResults?: number|undefined;
/**
* Score threshold that overrides the value provided in the model metadata.
* Results below this value are rejected.
*/
scoreThreshold?: number|undefined;
/**
* Allowlist of category names. If non-empty, detection results whose category
* name is not in this set will be filtered out. Duplicate or unknown category
* names are ignored. Mutually exclusive with `categoryDenylist`.
*/
categoryAllowlist?: string[]|undefined;
/**
* Denylist of category names. If non-empty, detection results whose category
* name is in this set will be filtered out. Duplicate or unknown category
* names are ignored. Mutually exclusive with `categoryAllowlist`.
*/
categoryDenylist?: string[]|undefined;
}

View File

@ -0,0 +1,38 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Category} from '../../../../tasks/web/components/containers/category';
/**
* An integer bounding box, axis aligned.
*
* NOTE(review): the fields are typed as `number`, so integrality is a
* convention here and is not enforced by the type.
*/
export interface BoundingBox {
/** The X coordinate of the top-left corner, in pixels. */
originX: number;
/** The Y coordinate of the top-left corner, in pixels. */
originY: number;
/** The width of the bounding box, in pixels. */
width: number;
/** The height of the bounding box, in pixels. */
height: number;
}
/** Represents one object detected by the `ObjectDetector`. */
export interface Detection {
/** A list of `Category` objects. */
categories: Category[];
/**
* The bounding box of the detected object. Optional: it may be absent from
* a detection.
*/
boundingBox?: BoundingBox;
}

View File

@ -0,0 +1,41 @@
# The TypeScript graph runner used by all MediaPipe Web tasks.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
# Targets in this package are visible by default to the local `:internal`
# package group and to `//mediapipe/tasks:internal`.
package(default_visibility = [
":internal",
"//mediapipe/tasks:internal",
])
# Package group referenced as `:internal` in the default visibility above.
package_group(
name = "internal",
packages = [
"//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...",
],
)
# Base library built from wasm_mediapipe_lib.ts; the other targets in this
# file depend on it.
mediapipe_ts_library(
name = "wasm_mediapipe_lib_ts",
srcs = [
":wasm_mediapipe_lib.ts",
],
allow_unoptimized_namespaces = True,
)
# Library built from wasm_mediapipe_image_lib.ts, layered on top of the base
# wasm_mediapipe_lib_ts target.
mediapipe_ts_library(
name = "wasm_mediapipe_image_lib_ts",
srcs = [
":wasm_mediapipe_image_lib.ts",
],
allow_unoptimized_namespaces = True,
deps = [":wasm_mediapipe_lib_ts"],
)
# Library built from register_model_resources_graph_service.ts, layered on top
# of the base wasm_mediapipe_lib_ts target.
mediapipe_ts_library(
name = "register_model_resources_graph_service_ts",
srcs = [
":register_model_resources_graph_service.ts",
],
allow_unoptimized_namespaces = True,
deps = [":wasm_mediapipe_lib_ts"],
)

View File

@ -0,0 +1,41 @@
import {WasmMediaPipeLib} from './wasm_mediapipe_lib';
/**
* Constructor type for any class whose instances are a WasmMediaPipeLib. The
* mixin in this file constrains its base class to this type, which ensures the
* mixin has access to the wasmModule, among other things. The `any` type is
* required for mixin constructors.
*/
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;
/**
* Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
* doesn't break our JS/C++ bridge.
*/
export declare interface WasmModuleRegisterModelResources {
/**
* Entry point on the wasm module; takes no arguments and returns nothing.
* Invoked by `registerModelResourcesGraphService()` in this file.
*/
_registerModelResourcesGraphService: () => void;
}
/**
 * Mixin that extends a WasmMediaPipeLib-derived class with support for
 * registering model resources to a cache, in the form of a GraphService
 * C++-side. It is written as a proper TS mixin so that it can be combined
 * with other mixins for effective multiple inheritance. Sample usage:
 * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
 *      WasmMediaPipeLib);`
 */
// tslint:disable:enforce-name-casing
export function SupportModelResourcesGraphService<TBase extends LibConstructor>(
    Base: TBase) {
  return class extends Base {
    // tslint:enable:enforce-name-casing
    /**
     * Tells the graph runner to use the model resource caching graph service
     * for both graph expansion/initialization and graph run.
     */
    registerModelResourcesGraphService(): void {
      // The wasm module is viewed through the bridge interface so the
      // Emscripten-provided entry point is visible to the compiler.
      const resourcesModule =
          this.wasmModule as unknown as WasmModuleRegisterModelResources;
      resourcesModule._registerModelResourcesGraphService();
    }
  };
}

View File

@ -0,0 +1,52 @@
import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib';
/**
* Constructor type for any class whose instances are a WasmMediaPipeLib. The
* mixin in this file constrains its base class to this type, which ensures the
* mixin has access to the wasmModule, among other things. The `any` type is
* required for mixin constructors.
*/
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;
/**
* Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
* doesn't break our JS/C++ bridge.
*/
export declare interface WasmImageModule {
/**
* Entry point on the wasm module. `addGpuBufferAsImageToStream()` in this
* file calls it with the stream-name pointer, the width and height returned
* by `bindTextureToStream`, and the caller-supplied timestamp.
*/
_addBoundTextureAsImageToStream:
(streamNamePtr: number, width: number, height: number,
timestamp: number) => void;
}
/**
 * Mixin that extends a WasmMediaPipeLib-derived class with support for
 * binding GPU image data as `mediapipe::Image` instances. It is written as a
 * proper TS mixin so that it can be combined with other mixins for effective
 * multiple inheritance. Example usage:
 * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);`
 */
// tslint:disable-next-line:enforce-name-casing
export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
  return class extends Base {
    /**
     * Reads the relevant information from the HTML video or image element
     * and passes it into the WebGL-based graph, as a MediaPipe image, for
     * processing on the given stream at the given timestamp. Processing will
     * not occur until a blocking call (like processVideoGl or
     * finishProcessing) is made.
     * @param imageSource Reference to the video frame we wish to add into our
     *     graph.
     * @param streamName The name of the MediaPipe graph stream to add the
     *     frame to.
     * @param timestamp The timestamp of the input frame, in ms.
     */
    addGpuBufferAsImageToStream(
        imageSource: ImageSource, streamName: string, timestamp: number): void {
      this.wrapStringPtr(streamName, (namePtr: number) => {
        // Bind the texture first; the call yields the frame dimensions.
        const [frameWidth, frameHeight] =
            this.bindTextureToStream(imageSource, namePtr);
        const imageModule = this.wasmModule as unknown as WasmImageModule;
        imageModule._addBoundTextureAsImageToStream(
            namePtr, frameWidth, frameHeight, timestamp);
      });
    }
  };
}

File diff suppressed because it is too large Load Diff