Merge branch 'master' into image-embedder-python
This commit is contained in:
commit
ba1ee5b404
|
@ -936,6 +936,7 @@ cc_test(
|
||||||
"//mediapipe/framework/tool:simulation_clock",
|
"//mediapipe/framework/tool:simulation_clock",
|
||||||
"//mediapipe/framework/tool:simulation_clock_executor",
|
"//mediapipe/framework/tool:simulation_clock_executor",
|
||||||
"//mediapipe/framework/tool:sink",
|
"//mediapipe/framework/tool:sink",
|
||||||
|
"//mediapipe/util:packet_test_util",
|
||||||
"@com_google_absl//absl/time",
|
"@com_google_absl//absl/time",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
|
|
||||||
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
|
#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/util/header_util.h"
|
#include "mediapipe/util/header_util.h"
|
||||||
|
|
||||||
|
@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS";
|
||||||
// FlowLimiterCalculator provides limited support for multiple input streams.
|
// FlowLimiterCalculator provides limited support for multiple input streams.
|
||||||
// The first input stream is treated as the main input stream and successive
|
// The first input stream is treated as the main input stream and successive
|
||||||
// input streams are treated as auxiliary input streams. The auxiliary input
|
// input streams are treated as auxiliary input streams. The auxiliary input
|
||||||
// streams are limited to timestamps passed on the main input stream.
|
// streams are limited to timestamps allowed by the "ALLOW" stream.
|
||||||
//
|
//
|
||||||
class FlowLimiterCalculator : public CalculatorBase {
|
class FlowLimiterCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
|
@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
|
cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
|
||||||
}
|
}
|
||||||
input_queues_.resize(cc->Inputs().NumEntries(""));
|
input_queues_.resize(cc->Inputs().NumEntries(""));
|
||||||
|
allowed_[Timestamp::Unset()] = true;
|
||||||
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
|
RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if an additional frame can be released for processing.
|
|
||||||
// The "ALLOW" output stream indicates this condition at each input frame.
|
|
||||||
bool ProcessingAllowed() {
|
|
||||||
return frames_in_flight_.size() < options_.max_in_flight();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Outputs a packet indicating whether a frame was sent or dropped.
|
|
||||||
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
|
|
||||||
if (cc->Outputs().HasTag(kAllowTag)) {
|
|
||||||
cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets the timestamp bound or closes an output stream.
|
|
||||||
void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
|
|
||||||
if (bound > Timestamp::Max()) {
|
|
||||||
stream->Close();
|
|
||||||
} else {
|
|
||||||
stream->SetNextTimestampBound(bound);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if a certain timestamp is being processed.
|
|
||||||
bool IsInFlight(Timestamp timestamp) {
|
|
||||||
return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
|
|
||||||
timestamp) != frames_in_flight_.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Releases input packets up to the latest settled input timestamp.
|
|
||||||
void ProcessAuxiliaryInputs(CalculatorContext* cc) {
|
|
||||||
Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
|
|
||||||
for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
|
|
||||||
// Release settled frames from each input queue.
|
|
||||||
while (!input_queues_[i].empty() &&
|
|
||||||
input_queues_[i].front().Timestamp() < settled_bound) {
|
|
||||||
Packet packet = input_queues_[i].front();
|
|
||||||
input_queues_[i].pop_front();
|
|
||||||
if (IsInFlight(packet.Timestamp())) {
|
|
||||||
cc->Outputs().Get("", i).AddPacket(packet);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Propagate each input timestamp bound.
|
|
||||||
if (!input_queues_[i].empty()) {
|
|
||||||
Timestamp bound = input_queues_[i].front().Timestamp();
|
|
||||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
|
|
||||||
} else {
|
|
||||||
Timestamp bound =
|
|
||||||
cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
|
|
||||||
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Releases input packets allowed by the max_in_flight constraint.
|
// Releases input packets allowed by the max_in_flight constraint.
|
||||||
absl::Status Process(CalculatorContext* cc) final {
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
options_ = tool::RetrieveOptions(options_, cc->Inputs());
|
options_ = tool::RetrieveOptions(options_, cc->Inputs());
|
||||||
|
@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
ProcessAuxiliaryInputs(cc);
|
ProcessAuxiliaryInputs(cc);
|
||||||
|
|
||||||
|
// Discard old ALLOW ranges.
|
||||||
|
Timestamp input_bound = InputTimestampBound(cc);
|
||||||
|
auto first_range = std::prev(allowed_.upper_bound(input_bound));
|
||||||
|
allowed_.erase(allowed_.begin(), first_range);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int LedgerSize() {
|
||||||
|
int result = frames_in_flight_.size() + allowed_.size();
|
||||||
|
for (const auto& queue : input_queues_) {
|
||||||
|
result += queue.size();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns true if an additional frame can be released for processing.
|
||||||
|
// The "ALLOW" output stream indicates this condition at each input frame.
|
||||||
|
bool ProcessingAllowed() {
|
||||||
|
return frames_in_flight_.size() < options_.max_in_flight();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Outputs a packet indicating whether a frame was sent or dropped.
|
||||||
|
void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
|
||||||
|
if (cc->Outputs().HasTag(kAllowTag)) {
|
||||||
|
cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
|
||||||
|
}
|
||||||
|
allowed_[ts] = allow;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if a timestamp falls within a range of allowed timestamps.
|
||||||
|
bool IsAllowed(Timestamp timestamp) {
|
||||||
|
auto it = allowed_.upper_bound(timestamp);
|
||||||
|
return std::prev(it)->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets the timestamp bound or closes an output stream.
|
||||||
|
void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
|
||||||
|
if (bound > Timestamp::Max()) {
|
||||||
|
stream->Close();
|
||||||
|
} else {
|
||||||
|
stream->SetNextTimestampBound(bound);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the lowest unprocessed input Timestamp.
|
||||||
|
Timestamp InputTimestampBound(CalculatorContext* cc) {
|
||||||
|
Timestamp result = Timestamp::Done();
|
||||||
|
for (int i = 0; i < input_queues_.size(); ++i) {
|
||||||
|
auto& queue = input_queues_[i];
|
||||||
|
auto& stream = cc->Inputs().Get("", i);
|
||||||
|
Timestamp bound = queue.empty()
|
||||||
|
? stream.Value().Timestamp().NextAllowedInStream()
|
||||||
|
: queue.front().Timestamp();
|
||||||
|
result = std::min(result, bound);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Releases input packets up to the latest settled input timestamp.
|
||||||
|
void ProcessAuxiliaryInputs(CalculatorContext* cc) {
|
||||||
|
Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
|
||||||
|
for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
|
||||||
|
// Release settled frames from each input queue.
|
||||||
|
while (!input_queues_[i].empty() &&
|
||||||
|
input_queues_[i].front().Timestamp() < settled_bound) {
|
||||||
|
Packet packet = input_queues_[i].front();
|
||||||
|
input_queues_[i].pop_front();
|
||||||
|
if (IsAllowed(packet.Timestamp())) {
|
||||||
|
cc->Outputs().Get("", i).AddPacket(packet);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Propagate each input timestamp bound.
|
||||||
|
if (!input_queues_[i].empty()) {
|
||||||
|
Timestamp bound = input_queues_[i].front().Timestamp();
|
||||||
|
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
|
||||||
|
} else {
|
||||||
|
Timestamp bound =
|
||||||
|
cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
|
||||||
|
SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FlowLimiterCalculatorOptions options_;
|
FlowLimiterCalculatorOptions options_;
|
||||||
std::vector<std::deque<Packet>> input_queues_;
|
std::vector<std::deque<Packet>> input_queues_;
|
||||||
std::deque<Timestamp> frames_in_flight_;
|
std::deque<Timestamp> frames_in_flight_;
|
||||||
|
std::map<Timestamp, bool> allowed_;
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(FlowLimiterCalculator);
|
REGISTER_CALCULATOR(FlowLimiterCalculator);
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/time/clock.h"
|
#include "absl/time/clock.h"
|
||||||
|
@ -32,6 +33,7 @@
|
||||||
#include "mediapipe/framework/tool/simulation_clock.h"
|
#include "mediapipe/framework/tool/simulation_clock.h"
|
||||||
#include "mediapipe/framework/tool/simulation_clock_executor.h"
|
#include "mediapipe/framework/tool/simulation_clock_executor.h"
|
||||||
#include "mediapipe/framework/tool/sink.h"
|
#include "mediapipe/framework/tool/sink.h"
|
||||||
|
#include "mediapipe/util/packet_test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
@ -77,6 +79,77 @@ std::vector<T> PacketValues(const std::vector<Packet>& packets) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<Packet> MakePackets(std::vector<std::pair<Timestamp, T>> contents) {
|
||||||
|
std::vector<Packet> result;
|
||||||
|
for (auto& entry : contents) {
|
||||||
|
result.push_back(MakePacket<T>(entry.second).At(entry.first));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SourceString(Timestamp t) {
|
||||||
|
return (t.IsSpecialValue())
|
||||||
|
? t.DebugString()
|
||||||
|
: absl::StrCat("Timestamp(", t.DebugString(), ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PacketContainer, typename PacketContent>
|
||||||
|
class PacketsEqMatcher
|
||||||
|
: public ::testing::MatcherInterface<const PacketContainer&> {
|
||||||
|
public:
|
||||||
|
PacketsEqMatcher(PacketContainer packets) : packets_(packets) {}
|
||||||
|
void DescribeTo(::std::ostream* os) const override {
|
||||||
|
*os << "The expected packet contents: \n";
|
||||||
|
Print(packets_, os);
|
||||||
|
}
|
||||||
|
bool MatchAndExplain(
|
||||||
|
const PacketContainer& value,
|
||||||
|
::testing::MatchResultListener* listener) const override {
|
||||||
|
if (!Equals(packets_, value)) {
|
||||||
|
if (listener->IsInterested()) {
|
||||||
|
*listener << "The actual packet contents: \n";
|
||||||
|
Print(value, listener->stream());
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool Equals(const PacketContainer& c1, const PacketContainer& c2) const {
|
||||||
|
if (c1.size() != c2.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) {
|
||||||
|
Packet p1 = *i1, p2 = *i2;
|
||||||
|
if (p1.Timestamp() != p2.Timestamp() ||
|
||||||
|
p1.Get<PacketContent>() != p2.Get<PacketContent>()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
void Print(const PacketContainer& packets, ::std::ostream* os) const {
|
||||||
|
for (auto it = packets.begin(); it != packets.end(); ++it) {
|
||||||
|
const Packet& packet = *it;
|
||||||
|
*os << (it == packets.begin() ? "{" : "") << "{"
|
||||||
|
<< SourceString(packet.Timestamp()) << ", "
|
||||||
|
<< packet.Get<PacketContent>() << "}"
|
||||||
|
<< (std::next(it) == packets.end() ? "}" : ", ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const PacketContainer packets_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename PacketContainer, typename PacketContent>
|
||||||
|
::testing::Matcher<const PacketContainer&> PackestEq(
|
||||||
|
const PacketContainer& packets) {
|
||||||
|
return MakeMatcher(
|
||||||
|
new PacketsEqMatcher<PacketContainer, PacketContent>(packets));
|
||||||
|
}
|
||||||
|
|
||||||
// A Calculator::Process callback function.
|
// A Calculator::Process callback function.
|
||||||
typedef std::function<absl::Status(const InputStreamShardSet&,
|
typedef std::function<absl::Status(const InputStreamShardSet&,
|
||||||
OutputStreamShardSet*)>
|
OutputStreamShardSet*)>
|
||||||
|
@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
|
||||||
input_packets_[17], input_packets_[19], input_packets_[20],
|
input_packets_[17], input_packets_[19], input_packets_[20],
|
||||||
};
|
};
|
||||||
EXPECT_EQ(out_1_packets_, expected_output);
|
EXPECT_EQ(out_1_packets_, expected_output);
|
||||||
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
// The timestamps released by FlowLimiterCalculator for in_1_sampled,
|
||||||
|
// plus input_packets_[21].
|
||||||
std::vector<Packet> expected_output_2 = {
|
std::vector<Packet> expected_output_2 = {
|
||||||
input_packets_[0], input_packets_[2], input_packets_[4],
|
input_packets_[0], input_packets_[2], input_packets_[4],
|
||||||
input_packets_[14], input_packets_[17], input_packets_[19],
|
input_packets_[14], input_packets_[17], input_packets_[19],
|
||||||
input_packets_[20],
|
input_packets_[20], input_packets_[21],
|
||||||
};
|
};
|
||||||
EXPECT_EQ(out_2_packets, expected_output_2);
|
EXPECT_EQ(out_2_packets, expected_output_2);
|
||||||
}
|
}
|
||||||
|
@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
|
||||||
// The processing time "sleep_time" is reduced from 22ms to 12ms to create
|
// The processing time "sleep_time" is reduced from 22ms to 12ms to create
|
||||||
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
|
// the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
|
||||||
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||||
|
auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
|
||||||
|
auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
|
||||||
|
|
||||||
// Configure the test.
|
// Configure the test.
|
||||||
SetUpInputData();
|
SetUpInputData();
|
||||||
SetUpSimulationClock();
|
SetUpSimulationClock();
|
||||||
|
@ -699,10 +776,9 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||||
}
|
}
|
||||||
)pb");
|
)pb");
|
||||||
|
|
||||||
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
|
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
|
||||||
max_in_flight: 1
|
R"pb(
|
||||||
max_in_queue: 0
|
max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000 # 100 ms
|
||||||
in_flight_timeout: 100000 # 100 ms
|
|
||||||
)pb");
|
)pb");
|
||||||
std::map<std::string, Packet> side_packets = {
|
std::map<std::string, Packet> side_packets = {
|
||||||
{"limiter_options",
|
{"limiter_options",
|
||||||
|
@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
|
||||||
input_packets_[0], input_packets_[2], input_packets_[15],
|
input_packets_[0], input_packets_[2], input_packets_[15],
|
||||||
input_packets_[17], input_packets_[19],
|
input_packets_[17], input_packets_[19],
|
||||||
};
|
};
|
||||||
EXPECT_EQ(out_1_packets_, expected_output);
|
EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
|
||||||
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
// Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
|
||||||
std::vector<Packet> expected_output_2 = {
|
std::vector<Packet> expected_output_2 = {
|
||||||
input_packets_[0], input_packets_[2], input_packets_[4],
|
input_packets_[0], input_packets_[2], input_packets_[4],
|
||||||
input_packets_[15], input_packets_[17], input_packets_[19],
|
input_packets_[15], input_packets_[17], input_packets_[19],
|
||||||
};
|
};
|
||||||
EXPECT_EQ(out_2_packets, expected_output_2);
|
EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
|
||||||
|
|
||||||
|
// Validate the ALLOW stream output.
|
||||||
|
std::vector<Packet> expected_allow = MakePackets<bool>( //
|
||||||
|
{{Timestamp(0), true}, {Timestamp(10000), false},
|
||||||
|
{Timestamp(20000), true}, {Timestamp(30000), false},
|
||||||
|
{Timestamp(40000), true}, {Timestamp(50000), false},
|
||||||
|
{Timestamp(60000), false}, {Timestamp(70000), false},
|
||||||
|
{Timestamp(80000), false}, {Timestamp(90000), false},
|
||||||
|
{Timestamp(100000), false}, {Timestamp(110000), false},
|
||||||
|
{Timestamp(120000), false}, {Timestamp(130000), false},
|
||||||
|
{Timestamp(140000), false}, {Timestamp(150000), true},
|
||||||
|
{Timestamp(160000), false}, {Timestamp(170000), true},
|
||||||
|
{Timestamp(180000), false}, {Timestamp(190000), true},
|
||||||
|
{Timestamp(200000), false}});
|
||||||
|
EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shows how FlowLimiterCalculator releases auxiliary input packets.
|
||||||
|
// In this test, auxiliary input packets arrive at twice the primary rate.
|
||||||
|
TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
|
||||||
|
auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
|
||||||
|
auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
|
||||||
|
|
||||||
|
// Configure the test.
|
||||||
|
SetUpInputData();
|
||||||
|
SetUpSimulationClock();
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'in_1'
|
||||||
|
input_stream: 'in_2'
|
||||||
|
node {
|
||||||
|
calculator: 'FlowLimiterCalculator'
|
||||||
|
input_side_packet: 'OPTIONS:limiter_options'
|
||||||
|
input_stream: 'in_1'
|
||||||
|
input_stream: 'in_2'
|
||||||
|
input_stream: 'FINISHED:out_1'
|
||||||
|
input_stream_info: { tag_index: 'FINISHED' back_edge: true }
|
||||||
|
output_stream: 'in_1_sampled'
|
||||||
|
output_stream: 'in_2_sampled'
|
||||||
|
output_stream: 'ALLOW:allow'
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: 'SleepCalculator'
|
||||||
|
input_side_packet: 'WARMUP_TIME:warmup_time'
|
||||||
|
input_side_packet: 'SLEEP_TIME:sleep_time'
|
||||||
|
input_side_packet: 'CLOCK:clock'
|
||||||
|
input_stream: 'PACKET:in_1_sampled'
|
||||||
|
output_stream: 'PACKET:out_1'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
|
||||||
|
R"pb(
|
||||||
|
max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000 # 1s
|
||||||
|
)pb");
|
||||||
|
std::map<std::string, Packet> side_packets = {
|
||||||
|
{"limiter_options",
|
||||||
|
MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
|
||||||
|
{"warmup_time", MakePacket<int64>(22000)},
|
||||||
|
{"sleep_time", MakePacket<int64>(22000)},
|
||||||
|
{"clock", MakePacket<mediapipe::Clock*>(clock_)},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start the graph.
|
||||||
|
MP_ASSERT_OK(graph_.Initialize(graph_config));
|
||||||
|
MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
|
||||||
|
out_1_packets_.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}));
|
||||||
|
std::vector<Packet> out_2_packets;
|
||||||
|
MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
|
||||||
|
out_2_packets.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}));
|
||||||
|
MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
|
||||||
|
allow_packets_.push_back(p);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}));
|
||||||
|
simulation_clock_->ThreadStart();
|
||||||
|
MP_ASSERT_OK(graph_.StartRun(side_packets));
|
||||||
|
|
||||||
|
// Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2.
|
||||||
|
clock_->Sleep(absl::Microseconds(10000));
|
||||||
|
for (int i = 1; i < 10; ++i) {
|
||||||
|
if (i % 2 == 0) {
|
||||||
|
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
|
||||||
|
}
|
||||||
|
MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i]));
|
||||||
|
clock_->Sleep(absl::Microseconds(10000));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finish the graph run.
|
||||||
|
MP_EXPECT_OK(graph_.CloseAllPacketSources());
|
||||||
|
clock_->Sleep(absl::Microseconds(40000));
|
||||||
|
MP_EXPECT_OK(graph_.WaitUntilDone());
|
||||||
|
simulation_clock_->ThreadFinish();
|
||||||
|
|
||||||
|
// Validate the output.
|
||||||
|
// Input packets 4 and 8 are dropped due to max_in_flight.
|
||||||
|
std::vector<Packet> expected_output = {
|
||||||
|
input_packets_[2],
|
||||||
|
input_packets_[6],
|
||||||
|
};
|
||||||
|
EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
|
||||||
|
// Packets following input packets 2 and 6, and not input packets 4 and 8.
|
||||||
|
std::vector<Packet> expected_output_2 = {
|
||||||
|
input_packets_[1], input_packets_[2], input_packets_[3],
|
||||||
|
input_packets_[6], input_packets_[7],
|
||||||
|
};
|
||||||
|
EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
|
||||||
|
|
||||||
|
// Validate the ALLOW stream output.
|
||||||
|
std::vector<Packet> expected_allow =
|
||||||
|
MakePackets<bool>({{Timestamp(20000), 1},
|
||||||
|
{Timestamp(40000), 0},
|
||||||
|
{Timestamp(60000), 1},
|
||||||
|
{Timestamp(80000), 0}});
|
||||||
|
EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
|
@ -14,45 +14,43 @@
|
||||||
|
|
||||||
#include "mediapipe/framework/deps/status_builder.h"
|
#include "mediapipe/framework/deps/status_builder.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
StatusBuilder::StatusBuilder(const StatusBuilder& sb) {
|
StatusBuilder::StatusBuilder(const StatusBuilder& sb)
|
||||||
status_ = sb.status_;
|
: impl_(sb.impl_ ? std::make_unique<Impl>(*sb.impl_) : nullptr) {}
|
||||||
file_ = sb.file_;
|
|
||||||
line_ = sb.line_;
|
|
||||||
no_logging_ = sb.no_logging_;
|
|
||||||
stream_ = sb.stream_
|
|
||||||
? absl::make_unique<std::ostringstream>(sb.stream_->str())
|
|
||||||
: nullptr;
|
|
||||||
join_style_ = sb.join_style_;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
|
StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) {
|
||||||
status_ = sb.status_;
|
if (!sb.impl_) {
|
||||||
file_ = sb.file_;
|
impl_ = nullptr;
|
||||||
line_ = sb.line_;
|
return *this;
|
||||||
no_logging_ = sb.no_logging_;
|
}
|
||||||
stream_ = sb.stream_
|
if (impl_) {
|
||||||
? absl::make_unique<std::ostringstream>(sb.stream_->str())
|
*impl_ = *sb.impl_;
|
||||||
: nullptr;
|
return *this;
|
||||||
join_style_ = sb.join_style_;
|
}
|
||||||
|
impl_ = std::make_unique<Impl>(*sb.impl_);
|
||||||
|
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusBuilder& StatusBuilder::SetAppend() & {
|
StatusBuilder& StatusBuilder::SetAppend() & {
|
||||||
if (status_.ok()) return *this;
|
if (!impl_) return *this;
|
||||||
join_style_ = MessageJoinStyle::kAppend;
|
impl_->join_style = Impl::MessageJoinStyle::kAppend;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); }
|
StatusBuilder&& StatusBuilder::SetAppend() && { return std::move(SetAppend()); }
|
||||||
|
|
||||||
StatusBuilder& StatusBuilder::SetPrepend() & {
|
StatusBuilder& StatusBuilder::SetPrepend() & {
|
||||||
if (status_.ok()) return *this;
|
if (!impl_) return *this;
|
||||||
join_style_ = MessageJoinStyle::kPrepend;
|
impl_->join_style = Impl::MessageJoinStyle::kPrepend;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +59,8 @@ StatusBuilder&& StatusBuilder::SetPrepend() && {
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusBuilder& StatusBuilder::SetNoLogging() & {
|
StatusBuilder& StatusBuilder::SetNoLogging() & {
|
||||||
no_logging_ = true;
|
if (!impl_) return *this;
|
||||||
|
impl_->no_logging = true;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,34 +69,72 @@ StatusBuilder&& StatusBuilder::SetNoLogging() && {
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusBuilder::operator Status() const& {
|
StatusBuilder::operator Status() const& {
|
||||||
if (!stream_ || stream_->str().empty() || no_logging_) {
|
|
||||||
return status_;
|
|
||||||
}
|
|
||||||
return StatusBuilder(*this).JoinMessageToStatus();
|
return StatusBuilder(*this).JoinMessageToStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusBuilder::operator Status() && {
|
StatusBuilder::operator Status() && { return JoinMessageToStatus(); }
|
||||||
if (!stream_ || stream_->str().empty() || no_logging_) {
|
|
||||||
return status_;
|
|
||||||
}
|
|
||||||
return JoinMessageToStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status StatusBuilder::JoinMessageToStatus() {
|
absl::Status StatusBuilder::JoinMessageToStatus() {
|
||||||
if (!stream_) {
|
if (!impl_) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
std::string message;
|
return impl_->JoinMessageToStatus();
|
||||||
if (join_style_ == MessageJoinStyle::kAnnotate) {
|
}
|
||||||
if (!status_.ok()) {
|
|
||||||
message = absl::StrCat(status_.message(), "; ", stream_->str());
|
absl::Status StatusBuilder::Impl::JoinMessageToStatus() {
|
||||||
|
if (stream.str().empty() || no_logging) {
|
||||||
|
return status;
|
||||||
}
|
}
|
||||||
} else {
|
return absl::Status(status.code(), [this]() {
|
||||||
message = join_style_ == MessageJoinStyle::kPrepend
|
switch (join_style) {
|
||||||
? absl::StrCat(stream_->str(), status_.message())
|
case MessageJoinStyle::kAnnotate:
|
||||||
: absl::StrCat(status_.message(), stream_->str());
|
return absl::StrCat(status.message(), "; ", stream.str());
|
||||||
|
case MessageJoinStyle::kAppend:
|
||||||
|
return absl::StrCat(status.message(), stream.str());
|
||||||
|
case MessageJoinStyle::kPrepend:
|
||||||
|
return absl::StrCat(stream.str(), status.message());
|
||||||
}
|
}
|
||||||
return Status(status_.code(), message);
|
}());
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusBuilder::Impl::Impl(const absl::Status& status, const char* file,
|
||||||
|
int line)
|
||||||
|
: status(status), line(line), file(file), stream() {}
|
||||||
|
|
||||||
|
StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line)
|
||||||
|
: status(std::move(status)), line(line), file(file), stream() {}
|
||||||
|
|
||||||
|
StatusBuilder::Impl::Impl(const absl::Status& status,
|
||||||
|
mediapipe::source_location location)
|
||||||
|
: status(status),
|
||||||
|
line(location.line()),
|
||||||
|
file(location.file_name()),
|
||||||
|
stream() {}
|
||||||
|
|
||||||
|
StatusBuilder::Impl::Impl(absl::Status&& status,
|
||||||
|
mediapipe::source_location location)
|
||||||
|
: status(std::move(status)),
|
||||||
|
line(location.line()),
|
||||||
|
file(location.file_name()),
|
||||||
|
stream() {}
|
||||||
|
|
||||||
|
StatusBuilder::Impl::Impl(const Impl& other)
|
||||||
|
: status(other.status),
|
||||||
|
line(other.line),
|
||||||
|
file(other.file),
|
||||||
|
no_logging(other.no_logging),
|
||||||
|
stream(other.stream.str()),
|
||||||
|
join_style(other.join_style) {}
|
||||||
|
|
||||||
|
StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) {
|
||||||
|
status = other.status;
|
||||||
|
line = other.line;
|
||||||
|
file = other.file;
|
||||||
|
no_logging = other.no_logging;
|
||||||
|
stream = std::ostringstream(other.stream.str());
|
||||||
|
join_style = other.join_style;
|
||||||
|
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include "absl/base/attributes.h"
|
#include "absl/base/attributes.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/framework/deps/source_location.h"
|
#include "mediapipe/framework/deps/source_location.h"
|
||||||
#include "mediapipe/framework/deps/status.h"
|
#include "mediapipe/framework/deps/status.h"
|
||||||
|
@ -42,34 +41,37 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
|
||||||
// occurs. A typical user will call this with `MEDIAPIPE_LOC`.
|
// occurs. A typical user will call this with `MEDIAPIPE_LOC`.
|
||||||
StatusBuilder(const absl::Status& original_status,
|
StatusBuilder(const absl::Status& original_status,
|
||||||
mediapipe::source_location location)
|
mediapipe::source_location location)
|
||||||
: status_(original_status),
|
: impl_(original_status.ok()
|
||||||
line_(location.line()),
|
? nullptr
|
||||||
file_(location.file_name()),
|
: std::make_unique<Impl>(original_status, location)) {}
|
||||||
stream_(InitStream(status_)) {}
|
|
||||||
|
|
||||||
StatusBuilder(absl::Status&& original_status,
|
StatusBuilder(absl::Status&& original_status,
|
||||||
mediapipe::source_location location)
|
mediapipe::source_location location)
|
||||||
: status_(std::move(original_status)),
|
: impl_(original_status.ok()
|
||||||
line_(location.line()),
|
? nullptr
|
||||||
file_(location.file_name()),
|
: std::make_unique<Impl>(std::move(original_status),
|
||||||
stream_(InitStream(status_)) {}
|
location)) {}
|
||||||
|
|
||||||
// Creates a `StatusBuilder` from a mediapipe status code. If logging is
|
// Creates a `StatusBuilder` from a mediapipe status code. If logging is
|
||||||
// enabled, it will use `location` as the location from which the log message
|
// enabled, it will use `location` as the location from which the log message
|
||||||
// occurs. A typical user will call this with `MEDIAPIPE_LOC`.
|
// occurs. A typical user will call this with `MEDIAPIPE_LOC`.
|
||||||
StatusBuilder(absl::StatusCode code, mediapipe::source_location location)
|
StatusBuilder(absl::StatusCode code, mediapipe::source_location location)
|
||||||
: status_(code, ""),
|
: impl_(code == absl::StatusCode::kOk
|
||||||
line_(location.line()),
|
? nullptr
|
||||||
file_(location.file_name()),
|
: std::make_unique<Impl>(absl::Status(code, ""), location)) {}
|
||||||
stream_(InitStream(status_)) {}
|
|
||||||
|
|
||||||
StatusBuilder(const absl::Status& original_status, const char* file, int line)
|
StatusBuilder(const absl::Status& original_status, const char* file, int line)
|
||||||
: status_(original_status),
|
: impl_(original_status.ok()
|
||||||
line_(line),
|
? nullptr
|
||||||
file_(file),
|
: std::make_unique<Impl>(original_status, file, line)) {}
|
||||||
stream_(InitStream(status_)) {}
|
|
||||||
|
|
||||||
bool ok() const { return status_.ok(); }
|
StatusBuilder(absl::Status&& original_status, const char* file, int line)
|
||||||
|
: impl_(original_status.ok()
|
||||||
|
? nullptr
|
||||||
|
: std::make_unique<Impl>(std::move(original_status), file,
|
||||||
|
line)) {}
|
||||||
|
|
||||||
|
bool ok() const { return !impl_; }
|
||||||
|
|
||||||
StatusBuilder& SetAppend() &;
|
StatusBuilder& SetAppend() &;
|
||||||
StatusBuilder&& SetAppend() &&;
|
StatusBuilder&& SetAppend() &&;
|
||||||
|
@ -82,8 +84,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
StatusBuilder& operator<<(const T& msg) & {
|
StatusBuilder& operator<<(const T& msg) & {
|
||||||
if (!stream_) return *this;
|
if (!impl_) return *this;
|
||||||
*stream_ << msg;
|
impl_->stream << msg;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,6 +100,7 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
|
||||||
absl::Status JoinMessageToStatus();
|
absl::Status JoinMessageToStatus();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
struct Impl {
|
||||||
// Specifies how to join the error message in the original status and any
|
// Specifies how to join the error message in the original status and any
|
||||||
// additional message that has been streamed into the builder.
|
// additional message that has been streamed into the builder.
|
||||||
enum class MessageJoinStyle {
|
enum class MessageJoinStyle {
|
||||||
|
@ -106,27 +109,33 @@ class ABSL_MUST_USE_RESULT StatusBuilder {
|
||||||
kPrepend,
|
kPrepend,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Conditionally creates an ostringstream if the status is not ok.
|
Impl(const absl::Status& status, const char* file, int line);
|
||||||
static std::unique_ptr<std::ostringstream> InitStream(
|
Impl(absl::Status&& status, const char* file, int line);
|
||||||
const absl::Status status) {
|
Impl(const absl::Status& status, mediapipe::source_location location);
|
||||||
if (status.ok()) {
|
Impl(absl::Status&& status, mediapipe::source_location location);
|
||||||
return nullptr;
|
Impl(const Impl&);
|
||||||
}
|
Impl& operator=(const Impl&);
|
||||||
return absl::make_unique<std::ostringstream>();
|
|
||||||
}
|
absl::Status JoinMessageToStatus();
|
||||||
|
|
||||||
// The status that the result will be based on.
|
// The status that the result will be based on.
|
||||||
absl::Status status_;
|
absl::Status status;
|
||||||
// The line to record if this file is logged.
|
// The line to record if this file is logged.
|
||||||
int line_;
|
int line;
|
||||||
// Not-owned: The file to record if this status is logged.
|
// Not-owned: The file to record if this status is logged.
|
||||||
const char* file_;
|
const char* file;
|
||||||
bool no_logging_ = false;
|
// Logging disabled if true.
|
||||||
|
bool no_logging = false;
|
||||||
// The additional messages added with `<<`. This is nullptr when status_ is
|
// The additional messages added with `<<`. This is nullptr when status_ is
|
||||||
// ok.
|
// ok.
|
||||||
std::unique_ptr<std::ostringstream> stream_;
|
std::ostringstream stream;
|
||||||
// Specifies how to join the message in `status_` and `stream_`.
|
// Specifies how to join the message in `status_` and `stream_`.
|
||||||
MessageJoinStyle join_style_ = MessageJoinStyle::kAnnotate;
|
MessageJoinStyle join_style = MessageJoinStyle::kAnnotate;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Internal store of data for the class. An invariant of the class is that
|
||||||
|
// this is null when the original status is okay, and not-null otherwise.
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline StatusBuilder AlreadyExistsErrorBuilder(
|
inline StatusBuilder AlreadyExistsErrorBuilder(
|
||||||
|
|
|
@ -33,6 +33,21 @@ TEST(StatusBuilder, OkStatusRvalue) {
|
||||||
ASSERT_EQ(status, absl::OkStatus());
|
ASSERT_EQ(status, absl::OkStatus());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) {
|
||||||
|
absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234)
|
||||||
|
<< "annotated message1 "
|
||||||
|
<< "annotated message2";
|
||||||
|
ASSERT_EQ(status, absl::OkStatus());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) {
|
||||||
|
const auto original_status = absl::OkStatus();
|
||||||
|
absl::Status status = StatusBuilder(original_status, "hello.cc", 1234)
|
||||||
|
<< "annotated message1 "
|
||||||
|
<< "annotated message2";
|
||||||
|
ASSERT_EQ(status, absl::OkStatus());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StatusBuilder, AnnotateMode) {
|
TEST(StatusBuilder, AnnotateMode) {
|
||||||
absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
|
absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
|
||||||
"original message"),
|
"original message"),
|
||||||
|
@ -45,6 +60,30 @@ TEST(StatusBuilder, AnnotateMode) {
|
||||||
"original message; annotated message1 annotated message2");
|
"original message; annotated message1 annotated message2");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) {
|
||||||
|
absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound,
|
||||||
|
"original message"),
|
||||||
|
"hello.cc", 1234)
|
||||||
|
<< "annotated message1 "
|
||||||
|
<< "annotated message2";
|
||||||
|
ASSERT_FALSE(status.ok());
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
||||||
|
EXPECT_EQ(status.message(),
|
||||||
|
"original message; annotated message1 annotated message2");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) {
|
||||||
|
const auto original_status =
|
||||||
|
absl::Status(absl::StatusCode::kNotFound, "original message");
|
||||||
|
absl::Status status = StatusBuilder(original_status, "hello.cc", 1234)
|
||||||
|
<< "annotated message1 "
|
||||||
|
<< "annotated message2";
|
||||||
|
ASSERT_FALSE(status.ok());
|
||||||
|
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
||||||
|
EXPECT_EQ(status.message(),
|
||||||
|
"original message; annotated message1 annotated message2");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(StatusBuilder, PrependModeLvalue) {
|
TEST(StatusBuilder, PrependModeLvalue) {
|
||||||
StatusBuilder builder(
|
StatusBuilder builder(
|
||||||
absl::Status(absl::StatusCode::kInvalidArgument, "original message"),
|
absl::Status(absl::StatusCode::kInvalidArgument, "original message"),
|
||||||
|
|
|
@ -29,21 +29,16 @@ from mediapipe.model_maker.python.core.tasks import custom_model
|
||||||
class Classifier(custom_model.CustomModel):
|
class Classifier(custom_model.CustomModel):
|
||||||
"""An abstract base class that represents a TensorFlow classifier."""
|
"""An abstract base class that represents a TensorFlow classifier."""
|
||||||
|
|
||||||
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool,
|
def __init__(self, model_spec: Any, label_names: List[str], shuffle: bool):
|
||||||
full_train: bool):
|
|
||||||
"""Initilizes a classifier with its specifications.
|
"""Initilizes a classifier with its specifications.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_spec: Specification for the model.
|
model_spec: Specification for the model.
|
||||||
label_names: A list of label names for the classes.
|
label_names: A list of label names for the classes.
|
||||||
shuffle: Whether the dataset should be shuffled.
|
shuffle: Whether the dataset should be shuffled.
|
||||||
full_train: If true, train the model end-to-end including the backbone
|
|
||||||
and the classification layers on top. Otherwise, only train the top
|
|
||||||
classification layers.
|
|
||||||
"""
|
"""
|
||||||
super(Classifier, self).__init__(model_spec, shuffle)
|
super(Classifier, self).__init__(model_spec, shuffle)
|
||||||
self._label_names = label_names
|
self._label_names = label_names
|
||||||
self._full_train = full_train
|
|
||||||
self._num_classes = len(label_names)
|
self._num_classes = len(label_names)
|
||||||
|
|
||||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||||
|
|
|
@ -38,10 +38,7 @@ class ClassifierTest(tf.test.TestCase):
|
||||||
super(ClassifierTest, self).setUp()
|
super(ClassifierTest, self).setUp()
|
||||||
label_names = ['cat', 'dog']
|
label_names = ['cat', 'dog']
|
||||||
self.model = MockClassifier(
|
self.model = MockClassifier(
|
||||||
model_spec=None,
|
model_spec=None, label_names=label_names, shuffle=False)
|
||||||
label_names=label_names,
|
|
||||||
shuffle=False,
|
|
||||||
full_train=False)
|
|
||||||
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
|
||||||
|
|
||||||
def _check_nonempty_file(self, filepath):
|
def _check_nonempty_file(self, filepath):
|
||||||
|
|
|
@ -44,10 +44,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
hparams: The hyperparameters for training image classifier.
|
hparams: The hyperparameters for training image classifier.
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_spec=model_spec,
|
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
||||||
label_names=label_names,
|
|
||||||
shuffle=hparams.shuffle,
|
|
||||||
full_train=hparams.do_fine_tuning)
|
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
self._preprocess = image_preprocessing.Preprocessor(
|
self._preprocess = image_preprocessing.Preprocessor(
|
||||||
input_shape=self._model_spec.input_image_shape,
|
input_shape=self._model_spec.input_image_shape,
|
||||||
|
|
|
@ -93,6 +93,13 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||||
],
|
],
|
||||||
|
] + select({
|
||||||
|
# TODO: Build text_classifier_graph on Windows.
|
||||||
|
"//mediapipe:windows": [],
|
||||||
|
"//conditions:default": [
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
||||||
|
],
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
|
@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) {
|
||||||
// TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
|
// TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
|
||||||
// as the input argument. Investigate why bazel non-optimized mode
|
// as the input argument. Investigate why bazel non-optimized mode
|
||||||
// triggers a memory allocation bug in Eigen::internal::aligned_free().
|
// triggers a memory allocation bug in Eigen::internal::aligned_free().
|
||||||
[](const Eigen::MatrixXf& matrix) {
|
[](const Eigen::MatrixXf& matrix, bool transpose) {
|
||||||
// MakePacket copies the data.
|
// MakePacket copies the data.
|
||||||
|
if (transpose) {
|
||||||
|
return MakePacket<Matrix>(matrix.transpose());
|
||||||
|
}
|
||||||
return MakePacket<Matrix>(matrix);
|
return MakePacket<Matrix>(matrix);
|
||||||
},
|
},
|
||||||
R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
|
R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
|
||||||
|
@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) {
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
matrix: A 2d numpy float ndarray.
|
matrix: A 2d numpy float ndarray.
|
||||||
|
transpose: A boolean to indicate if the input matrix needs to be transposed.
|
||||||
|
Default to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A MediaPipe Matrix Packet.
|
A MediaPipe Matrix Packet.
|
||||||
|
@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) {
|
||||||
np.array([[.1, .2, .3], [.4, .5, .6]])
|
np.array([[.1, .2, .3], [.4, .5, .6]])
|
||||||
matrix = mp.packet_getter.get_matrix(packet)
|
matrix = mp.packet_getter.get_matrix(packet)
|
||||||
)doc",
|
)doc",
|
||||||
|
py::arg("matrix"), py::arg("transpose") = false,
|
||||||
py::return_value_policy::move);
|
py::return_value_policy::move);
|
||||||
} // NOLINT(readability/fn_size)
|
} // NOLINT(readability/fn_size)
|
||||||
|
|
||||||
|
|
|
@ -31,13 +31,13 @@ namespace core {
|
||||||
|
|
||||||
// Base options for MediaPipe C++ Tasks.
|
// Base options for MediaPipe C++ Tasks.
|
||||||
struct BaseOptions {
|
struct BaseOptions {
|
||||||
// The model asset file contents as as string.
|
// The model asset file contents as a string.
|
||||||
std::unique_ptr<std::string> model_asset_buffer;
|
std::unique_ptr<std::string> model_asset_buffer;
|
||||||
|
|
||||||
// The path to the model asset to open and mmap in memory.
|
// The path to the model asset to open and mmap in memory.
|
||||||
std::string model_asset_path = "";
|
std::string model_asset_path = "";
|
||||||
|
|
||||||
// The delegate to run MediaPipe. If the delegate is not set, default
|
// The delegate to run MediaPipe. If the delegate is not set, the default
|
||||||
// delegate CPU is used.
|
// delegate CPU is used.
|
||||||
enum Delegate {
|
enum Delegate {
|
||||||
CPU = 0,
|
CPU = 0,
|
||||||
|
|
|
@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
|
||||||
hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
|
hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
|
||||||
kHandGesturesTag)];
|
kHandGesturesTag)];
|
||||||
|
|
||||||
return {{.gesture = hand_gestures,
|
return GestureRecognizerOutputs{
|
||||||
.handedness = handedness,
|
/*gesture=*/hand_gestures,
|
||||||
.hand_landmarks = hand_landmarks,
|
/*handedness=*/handedness,
|
||||||
.hand_world_landmarks = hand_world_landmarks,
|
/*hand_landmarks=*/hand_landmarks,
|
||||||
.image = hand_landmarker_graph[Output<Image>(kImageTag)]}};
|
/*hand_world_landmarks=*/hand_world_landmarks,
|
||||||
|
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Populate normalized non rotated face bounding box
|
// Populate normalized non rotated face bounding box
|
||||||
return {.left = bounding_box_left,
|
return Rect{/*left=*/bounding_box_left,
|
||||||
.top = bounding_box_top,
|
/*top=*/bounding_box_top,
|
||||||
.right = bounding_box_right,
|
/*right=*/bounding_box_right,
|
||||||
.bottom = bounding_box_bottom};
|
/*bottom=*/bounding_box_bottom};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uses IoU and distance of some corresponding hand landmarks to detect
|
// Uses IoU and distance of some corresponding hand landmarks to detect
|
||||||
|
|
|
@ -48,7 +48,7 @@ public abstract class BaseOptions {
|
||||||
public abstract Builder setModelAssetBuffer(ByteBuffer value);
|
public abstract Builder setModelAssetBuffer(ByteBuffer value);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default
|
* Sets device delegate to run the MediaPipe pipeline. If the delegate is not set, the default
|
||||||
* delegate CPU is used.
|
* delegate CPU is used.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setDelegate(Delegate delegate);
|
public abstract Builder setDelegate(Delegate delegate);
|
||||||
|
|
|
@ -68,7 +68,11 @@ class AudioData(object):
|
||||||
ValueError: If the input array has an incorrect shape or if
|
ValueError: If the input array has an incorrect shape or if
|
||||||
`offset` + `size` exceeds the length of the `src` array.
|
`offset` + `size` exceeds the length of the `src` array.
|
||||||
"""
|
"""
|
||||||
if src.shape[1] != self._audio_format.num_channels:
|
if len(src.shape) == 1:
|
||||||
|
if self._audio_format.num_channels != 1:
|
||||||
|
raise ValueError(f"Input audio is mono, but the audio data is expected "
|
||||||
|
f"to have {self._audio_format.num_channels} channels.")
|
||||||
|
elif src.shape[1] != self._audio_format.num_channels:
|
||||||
raise ValueError(f"Input audio contains an invalid number of channels. "
|
raise ValueError(f"Input audio contains an invalid number of channels. "
|
||||||
f"Expect {self._audio_format.num_channels}.")
|
f"Expect {self._audio_format.num_channels}.")
|
||||||
|
|
||||||
|
@ -93,6 +97,28 @@ class AudioData(object):
|
||||||
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
self._buffer = np.roll(self._buffer, -shift, axis=0)
|
||||||
self._buffer[-shift:, :] = src[offset:offset + size].copy()
|
self._buffer[-shift:, :] = src[offset:offset + size].copy()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_array(cls,
|
||||||
|
src: np.ndarray,
|
||||||
|
sample_rate: Optional[float] = None) -> "AudioData":
|
||||||
|
"""Creates an `AudioData` object from a NumPy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: A NumPy source array contains the input audio.
|
||||||
|
sample_rate: the optional audio sample rate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An `AudioData` object that contains a copy of the NumPy source array as
|
||||||
|
the data.
|
||||||
|
"""
|
||||||
|
obj = cls(
|
||||||
|
buffer_length=src.shape[0],
|
||||||
|
audio_format=AudioFormat(
|
||||||
|
num_channels=1 if len(src.shape) == 1 else src.shape[1],
|
||||||
|
sample_rate=sample_rate))
|
||||||
|
obj.load_from_array(src)
|
||||||
|
return obj
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_format(self) -> AudioFormat:
|
def audio_format(self) -> AudioFormat:
|
||||||
"""Gets the audio format of the audio."""
|
"""Gets the audio format of the audio."""
|
||||||
|
|
|
@ -157,7 +157,7 @@ def _normalize_number_fields(pb):
|
||||||
descriptor.FieldDescriptor.TYPE_ENUM):
|
descriptor.FieldDescriptor.TYPE_ENUM):
|
||||||
normalized_values = [int(x) for x in values]
|
normalized_values = [int(x) for x in values]
|
||||||
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
|
elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
|
||||||
normalized_values = [round(x, 5) for x in values]
|
normalized_values = [round(x, 4) for x in values]
|
||||||
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
|
elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
|
||||||
normalized_values = [round(float(x), 6) for x in values]
|
normalized_values = [round(float(x), 6) for x in values]
|
||||||
|
|
||||||
|
|
36
mediapipe/tasks/python/test/text/BUILD
Normal file
36
mediapipe/tasks/python/test/text/BUILD
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "text_classifier_test",
|
||||||
|
srcs = ["text_classifier_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
|
||||||
|
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
|
"//mediapipe/tasks/python/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
"//mediapipe/tasks/python/text:text_classifier",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/tasks/python/test/text/__init__.py
Normal file
13
mediapipe/tasks/python/test/text/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
244
mediapipe/tasks/python/test/text/text_classifier_test.py
Normal file
244
mediapipe/tasks/python/test/text/text_classifier_test.py
Normal file
|
@ -0,0 +1,244 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for text classifier."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import os
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from mediapipe.tasks.python.components.containers import category
|
||||||
|
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
||||||
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
from mediapipe.tasks.python.text import text_classifier
|
||||||
|
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
|
_Category = category.Category
|
||||||
|
_ClassificationEntry = classifications_module.ClassificationEntry
|
||||||
|
_Classifications = classifications_module.Classifications
|
||||||
|
_TextClassifierResult = classifications_module.ClassificationResult
|
||||||
|
_TextClassifier = text_classifier.TextClassifier
|
||||||
|
_TextClassifierOptions = text_classifier.TextClassifierOptions
|
||||||
|
|
||||||
|
_BERT_MODEL_FILE = 'bert_text_classifier.tflite'
|
||||||
|
_REGEX_MODEL_FILE = 'test_model_text_classifier_with_regex_tokenizer.tflite'
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
||||||
|
|
||||||
|
_NEGATIVE_TEXT = 'What a waste of my time.'
|
||||||
|
_POSITIVE_TEXT = ('This is the best movie I’ve seen in recent years.'
|
||||||
|
'Strongly recommend it!')
|
||||||
|
|
||||||
|
_BERT_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
_Category(
|
||||||
|
index=0,
|
||||||
|
score=0.999479,
|
||||||
|
display_name='',
|
||||||
|
category_name='negative'),
|
||||||
|
_Category(
|
||||||
|
index=1,
|
||||||
|
score=0.00052154,
|
||||||
|
display_name='',
|
||||||
|
category_name='positive')
|
||||||
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
_BERT_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
_Category(
|
||||||
|
index=1,
|
||||||
|
score=0.999466,
|
||||||
|
display_name='',
|
||||||
|
category_name='positive'),
|
||||||
|
_Category(
|
||||||
|
index=0,
|
||||||
|
score=0.000533596,
|
||||||
|
display_name='',
|
||||||
|
category_name='negative')
|
||||||
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
_REGEX_NEGATIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
_Category(
|
||||||
|
index=0,
|
||||||
|
score=0.81313,
|
||||||
|
display_name='',
|
||||||
|
category_name='Negative'),
|
||||||
|
_Category(
|
||||||
|
index=1,
|
||||||
|
score=0.1868704,
|
||||||
|
display_name='',
|
||||||
|
category_name='Positive')
|
||||||
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
_REGEX_POSITIVE_RESULTS = _TextClassifierResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
_Category(
|
||||||
|
index=1,
|
||||||
|
score=0.5134273,
|
||||||
|
display_name='',
|
||||||
|
category_name='Positive'),
|
||||||
|
_Category(
|
||||||
|
index=0,
|
||||||
|
score=0.486573,
|
||||||
|
display_name='',
|
||||||
|
category_name='Negative')
|
||||||
|
],
|
||||||
|
timestamp_ms=0)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFileType(enum.Enum):
|
||||||
|
FILE_CONTENT = 1
|
||||||
|
FILE_NAME = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE))
|
||||||
|
|
||||||
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with default option and valid model file successfully.
|
||||||
|
with _TextClassifier.create_from_model_path(self.model_path) as classifier:
|
||||||
|
self.assertIsInstance(classifier, _TextClassifier)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with options containing model file successfully.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _TextClassifierOptions(base_options=base_options)
|
||||||
|
with _TextClassifier.create_from_options(options) as classifier:
|
||||||
|
self.assertIsInstance(classifier, _TextClassifier)
|
||||||
|
|
||||||
|
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||||
|
# Invalid empty model path.
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
r"ExternalFile must specify at least one of 'file_content', "
|
||||||
|
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||||
|
base_options = _BaseOptions(model_asset_path='')
|
||||||
|
options = _TextClassifierOptions(base_options=base_options)
|
||||||
|
_TextClassifier.create_from_options(options)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||||
|
# Creates with options containing model content successfully.
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||||
|
options = _TextClassifierOptions(base_options=base_options)
|
||||||
|
classifier = _TextClassifier.create_from_options(options)
|
||||||
|
self.assertIsInstance(classifier, _TextClassifier)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _NEGATIVE_TEXT,
|
||||||
|
_BERT_NEGATIVE_RESULTS), (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE,
|
||||||
|
_NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS),
|
||||||
|
(ModelFileType.FILE_NAME, _BERT_MODEL_FILE, _POSITIVE_TEXT,
|
||||||
|
_BERT_POSITIVE_RESULTS), (ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE,
|
||||||
|
_POSITIVE_TEXT, _BERT_POSITIVE_RESULTS),
|
||||||
|
(ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _NEGATIVE_TEXT,
|
||||||
|
_REGEX_NEGATIVE_RESULTS), (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE,
|
||||||
|
_NEGATIVE_TEXT, _REGEX_NEGATIVE_RESULTS),
|
||||||
|
(ModelFileType.FILE_NAME, _REGEX_MODEL_FILE, _POSITIVE_TEXT,
|
||||||
|
_REGEX_POSITIVE_RESULTS), (ModelFileType.FILE_CONTENT, _REGEX_MODEL_FILE,
|
||||||
|
_POSITIVE_TEXT, _REGEX_POSITIVE_RESULTS))
|
||||||
|
def test_classify(self, model_file_type, model_name, text,
|
||||||
|
expected_classification_result):
|
||||||
|
# Creates classifier.
|
||||||
|
model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, model_name))
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _TextClassifierOptions(base_options=base_options)
|
||||||
|
classifier = _TextClassifier.create_from_options(options)
|
||||||
|
|
||||||
|
# Performs text classification on the input.
|
||||||
|
text_result = classifier.classify(text)
|
||||||
|
# Comparing results.
|
||||||
|
test_utils.assert_proto_equals(self, text_result.to_pb2(),
|
||||||
|
expected_classification_result.to_pb2())
|
||||||
|
# Closes the classifier explicitly when the classifier is not used in
|
||||||
|
# a context.
|
||||||
|
classifier.close()
|
||||||
|
|
||||||
|
@parameterized.parameters((ModelFileType.FILE_NAME, _BERT_MODEL_FILE,
|
||||||
|
_NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS),
|
||||||
|
(ModelFileType.FILE_CONTENT, _BERT_MODEL_FILE,
|
||||||
|
_NEGATIVE_TEXT, _BERT_NEGATIVE_RESULTS))
|
||||||
|
def test_classify_in_context(self, model_file_type, model_name, text,
|
||||||
|
expected_classification_result):
|
||||||
|
# Creates classifier.
|
||||||
|
model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, model_name))
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _TextClassifierOptions(base_options=base_options)
|
||||||
|
|
||||||
|
with _TextClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs text classification on the input.
|
||||||
|
text_result = classifier.classify(text)
|
||||||
|
# Comparing results.
|
||||||
|
test_utils.assert_proto_equals(self, text_result.to_pb2(),
|
||||||
|
expected_classification_result.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
38
mediapipe/tasks/python/text/BUILD
Normal file
38
mediapipe/tasks/python/text/BUILD
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "text_classifier",
|
||||||
|
srcs = [
|
||||||
|
"text_classifier.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/python:packet_creator",
|
||||||
|
"//mediapipe/python:packet_getter",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
"//mediapipe/tasks/python/text/core:base_text_task_api",
|
||||||
|
],
|
||||||
|
)
|
13
mediapipe/tasks/python/text/__init__.py
Normal file
13
mediapipe/tasks/python/text/__init__.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
31
mediapipe/tasks/python/text/core/BUILD
Normal file
31
mediapipe/tasks/python/text/core/BUILD
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "base_text_task_api",
|
||||||
|
srcs = [
|
||||||
|
"base_text_task_api.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_py_pb2",
|
||||||
|
"//mediapipe/python:_framework_bindings",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
],
|
||||||
|
)
|
16
mediapipe/tasks/python/text/core/__init__.py
Normal file
16
mediapipe/tasks/python/text/core/__init__.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
"""Copyright 2022 The MediaPipe Authors.
|
||||||
|
|
||||||
|
All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
55
mediapipe/tasks/python/text/core/base_text_task_api.py
Normal file
55
mediapipe/tasks/python/text/core/base_text_task_api.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe text task base api."""
|
||||||
|
|
||||||
|
from mediapipe.framework import calculator_pb2
|
||||||
|
from mediapipe.python._framework_bindings import task_runner
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
_TaskRunner = task_runner.TaskRunner
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTextTaskApi(object):
|
||||||
|
"""The base class of the user-facing mediapipe text task api classes."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
graph_config: calculator_pb2.CalculatorGraphConfig) -> None:
|
||||||
|
"""Initializes the `BaseVisionTaskApi` object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
graph_config: The mediapipe text task graph config proto.
|
||||||
|
"""
|
||||||
|
self._runner = _TaskRunner.create(graph_config)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Shuts down the mediapipe text task instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the mediapipe text task failed to close.
|
||||||
|
"""
|
||||||
|
self._runner.close()
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def __enter__(self):
|
||||||
|
"""Returns `self` upon entering the runtime context."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
|
||||||
|
"""Shuts down the mediapipe text task instance on exit of the context manager.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If the mediapipe text task failed to close.
|
||||||
|
"""
|
||||||
|
self.close()
|
140
mediapipe/tasks/python/text/text_classifier.py
Normal file
140
mediapipe/tasks/python/text/text_classifier.py
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe text classifier task."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
from mediapipe.python import packet_creator
|
||||||
|
from mediapipe.python import packet_getter
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
|
from mediapipe.tasks.cc.text.text_classifier.proto import text_classifier_graph_options_pb2
|
||||||
|
from mediapipe.tasks.python.components.containers import classifications
|
||||||
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
from mediapipe.tasks.python.text.core import base_text_task_api
|
||||||
|
|
||||||
|
TextClassifierResult = classifications.ClassificationResult
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_TextClassifierGraphOptionsProto = text_classifier_graph_options_pb2.TextClassifierGraphOptions
|
||||||
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||||
|
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||||
|
_TEXT_IN_STREAM_NAME = 'text_in'
|
||||||
|
_TEXT_TAG = 'TEXT'
|
||||||
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class TextClassifierOptions:
|
||||||
|
"""Options for the text classifier task.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
base_options: Base options for the text classifier task.
|
||||||
|
classifier_options: Options for the text classification task.
|
||||||
|
"""
|
||||||
|
base_options: _BaseOptions
|
||||||
|
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _TextClassifierGraphOptionsProto:
|
||||||
|
"""Generates an TextClassifierOptions protobuf object."""
|
||||||
|
base_options_proto = self.base_options.to_pb2()
|
||||||
|
classifier_options_proto = self.classifier_options.to_pb2()
|
||||||
|
|
||||||
|
return _TextClassifierGraphOptionsProto(
|
||||||
|
base_options=base_options_proto,
|
||||||
|
classifier_options=classifier_options_proto)
|
||||||
|
|
||||||
|
|
||||||
|
class TextClassifier(base_text_task_api.BaseTextTaskApi):
|
||||||
|
"""Class that performs classification on text."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_model_path(cls, model_path: str) -> 'TextClassifier':
|
||||||
|
"""Creates an `TextClassifier` object from a TensorFlow Lite model and the default `TextClassifierOptions`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`TextClassifier` object that's created from the model file and the
|
||||||
|
default `TextClassifierOptions`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `TextClassifier` object from the provided
|
||||||
|
file such as invalid file path.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
options = TextClassifierOptions(base_options=base_options)
|
||||||
|
return cls.create_from_options(options)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_options(cls,
|
||||||
|
options: TextClassifierOptions) -> 'TextClassifier':
|
||||||
|
"""Creates the `TextClassifier` object from text classifier options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: Options for the text classifier task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`TextClassifier` object that's created from `options`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `TextClassifier` object from
|
||||||
|
`TextClassifierOptions` such as missing the model.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
task_info = _TaskInfo(
|
||||||
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
input_streams=[':'.join([_TEXT_TAG, _TEXT_IN_STREAM_NAME])],
|
||||||
|
output_streams=[
|
||||||
|
':'.join([
|
||||||
|
_CLASSIFICATION_RESULT_TAG,
|
||||||
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
|
||||||
|
])
|
||||||
|
],
|
||||||
|
task_options=options)
|
||||||
|
return cls(task_info.generate_graph_config())
|
||||||
|
|
||||||
|
def classify(self, text: str) -> TextClassifierResult:
|
||||||
|
"""Performs classification on the input `text`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The input text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `TextClassifierResult` object that contains a list of text
|
||||||
|
classifications.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If text classification failed to run.
|
||||||
|
"""
|
||||||
|
output_packets = self._runner.process(
|
||||||
|
{_TEXT_IN_STREAM_NAME: packet_creator.create_string(text)})
|
||||||
|
|
||||||
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
|
classification_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(
|
||||||
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
|
return TextClassifierResult([
|
||||||
|
classifications.Classifications.create_from_pb2(classification)
|
||||||
|
for classification in classification_result_proto.classifications
|
||||||
|
])
|
33
mediapipe/tasks/web/audio/audio_classifier/BUILD
Normal file
33
mediapipe/tasks/web/audio/audio_classifier/BUILD
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# This contains the MediaPipe Audio Classifier Task.
|
||||||
|
#
|
||||||
|
# This task takes audio data and outputs the classification result.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "audio_classifier",
|
||||||
|
srcs = [
|
||||||
|
"audio_classifier.ts",
|
||||||
|
"audio_classifier_options.ts",
|
||||||
|
"audio_classifier_result.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
|
"//mediapipe/tasks/web/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
|
||||||
|
],
|
||||||
|
)
|
215
mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
Normal file
215
mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb';
|
||||||
|
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||||
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
|
||||||
|
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {AudioClassifierOptions} from './audio_classifier_options';
|
||||||
|
import {Classifications} from './audio_classifier_result';
|
||||||
|
|
||||||
|
const MEDIAPIPE_GRAPH =
|
||||||
|
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
|
||||||
|
|
||||||
|
// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
|
||||||
|
// cannot be changed
|
||||||
|
// TODO: Change this to `audio_in` to match the name in the CC
|
||||||
|
// implementation
|
||||||
|
const AUDIO_STREAM = 'input_audio';
|
||||||
|
const SAMPLE_RATE_STREAM = 'sample_rate';
|
||||||
|
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/** Performs audio classification. */
|
||||||
|
export class AudioClassifier extends TaskRunner {
|
||||||
|
private classifications: Classifications[] = [];
|
||||||
|
private defaultSampleRate = 48000;
|
||||||
|
private readonly options = new AudioClassifierGraphOptions();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new audio classifier from the
|
||||||
|
* provided options.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param audioClassifierOptions The options for the audio classifier. Note
|
||||||
|
* that either a path to the model asset or a model buffer needs to be
|
||||||
|
* provided (via `baseOptions`).
|
||||||
|
*/
|
||||||
|
static async createFromOptions(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
audioClassifierOptions: AudioClassifierOptions):
|
||||||
|
Promise<AudioClassifier> {
|
||||||
|
// Create a file locator based on the loader options
|
||||||
|
const fileLocator: FileLocator = {
|
||||||
|
locateFile() {
|
||||||
|
// The only file loaded with this mechanism is the Wasm binary
|
||||||
|
return wasmLoaderOptions.wasmBinaryPath.toString();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const classifier = await createMediaPipeLib(
|
||||||
|
AudioClassifier, wasmLoaderOptions.wasmLoaderPath,
|
||||||
|
/* assetLoaderScript= */ undefined,
|
||||||
|
/* glCanvas= */ undefined, fileLocator);
|
||||||
|
await classifier.setOptions(audioClassifierOptions);
|
||||||
|
return classifier;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new audio classifier based on
|
||||||
|
* the provided model asset buffer.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetBuffer A binary representation of the model.
|
||||||
|
*/
|
||||||
|
static createFromModelBuffer(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
|
||||||
|
return AudioClassifier.createFromOptions(
|
||||||
|
wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new audio classifier based on
|
||||||
|
* the path to the model asset.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetPath The path to the model asset.
|
||||||
|
*/
|
||||||
|
static async createFromModelPath(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetPath: string): Promise<AudioClassifier> {
|
||||||
|
const response = await fetch(modelAssetPath.toString());
|
||||||
|
const graphData = await response.arrayBuffer();
|
||||||
|
return AudioClassifier.createFromModelBuffer(
|
||||||
|
wasmLoaderOptions, new Uint8Array(graphData));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets new options for the audio classifier.
|
||||||
|
*
|
||||||
|
* Calling `setOptions()` with a subset of options only affects those options.
|
||||||
|
* You can reset an option back to its default value by explicitly setting it
|
||||||
|
* to `undefined`.
|
||||||
|
*
|
||||||
|
* @param options The options for the audio classifier.
|
||||||
|
*/
|
||||||
|
async setOptions(options: AudioClassifierOptions): Promise<void> {
|
||||||
|
if (options.baseOptions) {
|
||||||
|
const baseOptionsProto =
|
||||||
|
await convertBaseOptionsToProto(options.baseOptions);
|
||||||
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
|
options, this.options.getClassifierOptions()));
|
||||||
|
this.refreshGraph();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the sample rate for all calls to `classify()` that omit an explicit
|
||||||
|
* sample rate. `48000` is used as a default if this method is not called.
|
||||||
|
*
|
||||||
|
* @param sampleRate A sample rate (e.g. `44100`).
|
||||||
|
*/
|
||||||
|
setDefaultSampleRate(sampleRate: number) {
|
||||||
|
this.defaultSampleRate = sampleRate;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs audio classification on the provided audio data and waits
|
||||||
|
* synchronously for the response.
|
||||||
|
*
|
||||||
|
* @param audioData An array of raw audio capture data, like
|
||||||
|
* from a call to getChannelData on an AudioBuffer.
|
||||||
|
* @param sampleRate The sample rate in Hz of the provided audio data. If not
|
||||||
|
* set, defaults to the sample rate set via `setDefaultSampleRate()` or
|
||||||
|
* `48000` if no custom default was set.
|
||||||
|
* @return The classification result of the audio datas
|
||||||
|
*/
|
||||||
|
classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
|
||||||
|
sampleRate = sampleRate ?? this.defaultSampleRate;
|
||||||
|
|
||||||
|
// Configures the number of samples in the WASM layer. We re-configure the
|
||||||
|
// number of samples and the sample rate for every frame, but ignore other
|
||||||
|
// side effects of this function (such as sending the input side packet and
|
||||||
|
// the input stream header).
|
||||||
|
this.configureAudio(
|
||||||
|
/* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
|
||||||
|
|
||||||
|
const timestamp = performance.now();
|
||||||
|
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
|
||||||
|
this.addAudioToStream(audioData, timestamp);
|
||||||
|
|
||||||
|
this.classifications = [];
|
||||||
|
this.finishProcessing();
|
||||||
|
return [...this.classifications];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Internal function for converting raw data into a classification, and
|
||||||
|
* adding it to our classfications list.
|
||||||
|
**/
|
||||||
|
private addJsAudioClassification(binaryProto: Uint8Array): void {
|
||||||
|
const classificationResult =
|
||||||
|
ClassificationResult.deserializeBinary(binaryProto);
|
||||||
|
this.classifications.push(
|
||||||
|
...convertFromClassificationResultProto(classificationResult));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
private refreshGraph(): void {
|
||||||
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
|
graphConfig.addInputStream(AUDIO_STREAM);
|
||||||
|
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||||
|
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
||||||
|
|
||||||
|
const calculatorOptions = new CalculatorOptions();
|
||||||
|
calculatorOptions.setExtension(
|
||||||
|
AudioClassifierGraphOptions.ext, this.options);
|
||||||
|
|
||||||
|
// Perform audio classification. Pre-processing and results post-processing
|
||||||
|
// are built-in.
|
||||||
|
const classifierNode = new CalculatorGraphConfig.Node();
|
||||||
|
classifierNode.setCalculator(MEDIAPIPE_GRAPH);
|
||||||
|
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
|
||||||
|
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
|
||||||
|
classifierNode.addOutputStream(
|
||||||
|
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
||||||
|
classifierNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
|
graphConfig.addNode(classifierNode);
|
||||||
|
|
||||||
|
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
||||||
|
this.addJsAudioClassification(binaryProto);
|
||||||
|
});
|
||||||
|
|
||||||
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
|
@ -0,0 +1,18 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
|
@ -29,31 +29,31 @@ export function convertClassifierOptionsToProto(
|
||||||
baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
|
baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
|
||||||
const classifierOptions =
|
const classifierOptions =
|
||||||
baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
|
baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
|
||||||
if (options.displayNamesLocale) {
|
if (options.displayNamesLocale !== undefined) {
|
||||||
classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
|
classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
|
||||||
} else if (options.displayNamesLocale === undefined) {
|
} else if (options.displayNamesLocale === undefined) {
|
||||||
classifierOptions.clearDisplayNamesLocale();
|
classifierOptions.clearDisplayNamesLocale();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.maxResults) {
|
if (options.maxResults !== undefined) {
|
||||||
classifierOptions.setMaxResults(options.maxResults);
|
classifierOptions.setMaxResults(options.maxResults);
|
||||||
} else if ('maxResults' in options) { // Check for undefined
|
} else if ('maxResults' in options) { // Check for undefined
|
||||||
classifierOptions.clearMaxResults();
|
classifierOptions.clearMaxResults();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.scoreThreshold) {
|
if (options.scoreThreshold !== undefined) {
|
||||||
classifierOptions.setScoreThreshold(options.scoreThreshold);
|
classifierOptions.setScoreThreshold(options.scoreThreshold);
|
||||||
} else if ('scoreThreshold' in options) { // Check for undefined
|
} else if ('scoreThreshold' in options) { // Check for undefined
|
||||||
classifierOptions.clearScoreThreshold();
|
classifierOptions.clearScoreThreshold();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.categoryAllowlist) {
|
if (options.categoryAllowlist !== undefined) {
|
||||||
classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
|
classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
|
||||||
} else if ('categoryAllowlist' in options) { // Check for undefined
|
} else if ('categoryAllowlist' in options) { // Check for undefined
|
||||||
classifierOptions.clearCategoryAllowlistList();
|
classifierOptions.clearCategoryAllowlistList();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.categoryDenylist) {
|
if (options.categoryDenylist !== undefined) {
|
||||||
classifierOptions.setCategoryDenylistList(options.categoryDenylist);
|
classifierOptions.setCategoryDenylistList(options.categoryDenylist);
|
||||||
} else if ('categoryDenylist' in options) { // Check for undefined
|
} else if ('categoryDenylist' in options) { // Check for undefined
|
||||||
classifierOptions.clearCategoryDenylistList();
|
classifierOptions.clearCategoryDenylistList();
|
||||||
|
|
|
@ -12,6 +12,18 @@ mediapipe_ts_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "task_runner",
|
||||||
|
srcs = [
|
||||||
|
"task_runner.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
|
||||||
|
"//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts",
|
||||||
|
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "classifier_options",
|
name = "classifier_options",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
83
mediapipe/tasks/web/core/task_runner.ts
Normal file
83
mediapipe/tasks/web/core/task_runner.ts
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
|
||||||
|
import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib';
|
||||||
|
import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib';
|
||||||
|
|
||||||
|
// tslint:disable-next-line:enforce-name-casing
|
||||||
|
const WasmMediaPipeImageLib =
|
||||||
|
SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib));
|
||||||
|
|
||||||
|
/** Base class for all MediaPipe Tasks. */
|
||||||
|
export abstract class TaskRunner extends WasmMediaPipeImageLib {
|
||||||
|
private processingErrors: Error[] = [];
|
||||||
|
|
||||||
|
constructor(wasmModule: WasmModule) {
|
||||||
|
super(wasmModule);
|
||||||
|
|
||||||
|
// Disables the automatic render-to-screen code, which allows for pure
|
||||||
|
// CPU processing.
|
||||||
|
this.setAutoRenderToScreen(false);
|
||||||
|
|
||||||
|
// Enables use of our model resource caching graph service.
|
||||||
|
this.registerModelResourcesGraphService();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
||||||
|
* over the video stream. Will replace the previously running MediaPipe graph,
|
||||||
|
* if there is one.
|
||||||
|
* @param graphData The raw MediaPipe graph data, either in binary
|
||||||
|
* protobuffer format (.binarypb), or else in raw text format (.pbtxt or
|
||||||
|
* .textproto).
|
||||||
|
* @param isBinary This should be set to true if the graph is in
|
||||||
|
* binary format, and false if it is in human-readable text format.
|
||||||
|
*/
|
||||||
|
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
|
||||||
|
this.attachErrorListener((code, message) => {
|
||||||
|
this.processingErrors.push(new Error(message));
|
||||||
|
});
|
||||||
|
super.setGraph(graphData, isBinary);
|
||||||
|
this.handleErrors();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Forces all queued-up packets to be pushed through the MediaPipe graph as
|
||||||
|
* far as possible, performing all processing until no more processing can be
|
||||||
|
* done.
|
||||||
|
*/
|
||||||
|
override finishProcessing(): void {
|
||||||
|
super.finishProcessing();
|
||||||
|
this.handleErrors();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Throws the error from the error listener if an error was raised. */
|
||||||
|
private handleErrors() {
|
||||||
|
const errorCount = this.processingErrors.length;
|
||||||
|
if (errorCount === 1) {
|
||||||
|
// Re-throw error to get a more meaningful stacktrace
|
||||||
|
throw new Error(this.processingErrors[0].message);
|
||||||
|
} else if (errorCount > 1) {
|
||||||
|
throw new Error(
|
||||||
|
'Encountered multiple errors: ' +
|
||||||
|
this.processingErrors.map(e => e.message).join(', '));
|
||||||
|
}
|
||||||
|
this.processingErrors = [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
34
mediapipe/tasks/web/text/text_classifier/BUILD
Normal file
34
mediapipe/tasks/web/text/text_classifier/BUILD
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
# This contains the MediaPipe Text Classifier Task.
|
||||||
|
#
|
||||||
|
# This task takes text input performs Natural Language classification (including
|
||||||
|
# BERT-based text classification).
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "text_classifier",
|
||||||
|
srcs = [
|
||||||
|
"text_classifier.ts",
|
||||||
|
"text_classifier_options.d.ts",
|
||||||
|
"text_classifier_result.d.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
|
"//mediapipe/tasks/web/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
|
||||||
|
],
|
||||||
|
)
|
180
mediapipe/tasks/web/text/text_classifier/text_classifier.ts
Normal file
180
mediapipe/tasks/web/text/text_classifier/text_classifier.ts
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
|
||||||
|
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||||
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
|
||||||
|
import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {TextClassifierOptions} from './text_classifier_options';
|
||||||
|
import {Classifications} from './text_classifier_result';
|
||||||
|
|
||||||
|
const INPUT_STREAM = 'text_in';
|
||||||
|
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
|
||||||
|
const TEXT_CLASSIFIER_GRAPH =
|
||||||
|
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/** Performs Natural Language classification. */
|
||||||
|
export class TextClassifier extends TaskRunner {
|
||||||
|
private classifications: Classifications[] = [];
|
||||||
|
private readonly options = new TextClassifierGraphOptions();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new text classifier from the
|
||||||
|
* provided options.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param textClassifierOptions The options for the text classifier. Note that
|
||||||
|
* either a path to the TFLite model or the model itself needs to be
|
||||||
|
* provided (via `baseOptions`).
|
||||||
|
*/
|
||||||
|
static async createFromOptions(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
|
||||||
|
// Create a file locator based on the loader options
|
||||||
|
const fileLocator: FileLocator = {
|
||||||
|
locateFile() {
|
||||||
|
// The only file we load is the Wasm binary
|
||||||
|
return wasmLoaderOptions.wasmBinaryPath.toString();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const classifier = await createMediaPipeLib(
|
||||||
|
TextClassifier, wasmLoaderOptions.wasmLoaderPath,
|
||||||
|
/* assetLoaderScript= */ undefined,
|
||||||
|
/* glCanvas= */ undefined, fileLocator);
|
||||||
|
await classifier.setOptions(textClassifierOptions);
|
||||||
|
return classifier;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new text classifier based on the
|
||||||
|
* provided model asset buffer.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetBuffer A binary representation of the model.
|
||||||
|
*/
|
||||||
|
static createFromModelBuffer(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
|
||||||
|
return TextClassifier.createFromOptions(
|
||||||
|
wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new text classifier based on the
|
||||||
|
* path to the model asset.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetPath The path to the model asset.
|
||||||
|
*/
|
||||||
|
static async createFromModelPath(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetPath: string): Promise<TextClassifier> {
|
||||||
|
const response = await fetch(modelAssetPath.toString());
|
||||||
|
const graphData = await response.arrayBuffer();
|
||||||
|
return TextClassifier.createFromModelBuffer(
|
||||||
|
wasmLoaderOptions, new Uint8Array(graphData));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets new options for the text classifier.
|
||||||
|
*
|
||||||
|
* Calling `setOptions()` with a subset of options only affects those options.
|
||||||
|
* You can reset an option back to its default value by explicitly setting it
|
||||||
|
* to `undefined`.
|
||||||
|
*
|
||||||
|
* @param options The options for the text classifier.
|
||||||
|
*/
|
||||||
|
async setOptions(options: TextClassifierOptions): Promise<void> {
|
||||||
|
if (options.baseOptions) {
|
||||||
|
const baseOptionsProto =
|
||||||
|
await convertBaseOptionsToProto(options.baseOptions);
|
||||||
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
|
options, this.options.getClassifierOptions()));
|
||||||
|
this.refreshGraph();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs Natural Language classification on the provided text and waits
|
||||||
|
* synchronously for the response.
|
||||||
|
*
|
||||||
|
* @param text The text to process.
|
||||||
|
* @return The classification result of the text
|
||||||
|
*/
|
||||||
|
classify(text: string): Classifications[] {
|
||||||
|
// Get classification classes by running our MediaPipe graph.
|
||||||
|
this.classifications = [];
|
||||||
|
this.addStringToStream(
|
||||||
|
text, INPUT_STREAM, /* timestamp= */ performance.now());
|
||||||
|
this.finishProcessing();
|
||||||
|
return [...this.classifications];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal function for converting raw data into a classification, and
|
||||||
|
// adding it to our classifications list.
|
||||||
|
private addJsTextClassification(binaryProto: Uint8Array): void {
|
||||||
|
const classificationResult =
|
||||||
|
ClassificationResult.deserializeBinary(binaryProto);
|
||||||
|
console.log(classificationResult.toObject());
|
||||||
|
this.classifications.push(
|
||||||
|
...convertFromClassificationResultProto(classificationResult));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
private refreshGraph(): void {
|
||||||
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
|
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
||||||
|
|
||||||
|
const calculatorOptions = new CalculatorOptions();
|
||||||
|
calculatorOptions.setExtension(
|
||||||
|
TextClassifierGraphOptions.ext, this.options);
|
||||||
|
|
||||||
|
const classifierNode = new CalculatorGraphConfig.Node();
|
||||||
|
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
|
||||||
|
classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
|
||||||
|
classifierNode.addOutputStream(
|
||||||
|
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
||||||
|
classifierNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
|
graphConfig.addNode(classifierNode);
|
||||||
|
|
||||||
|
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
||||||
|
this.addJsTextClassification(binaryProto);
|
||||||
|
});
|
||||||
|
|
||||||
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
17
mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts
vendored
Normal file
17
mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts
vendored
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
18
mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts
vendored
Normal file
18
mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts
vendored
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
40
mediapipe/tasks/web/vision/gesture_recognizer/BUILD
Normal file
40
mediapipe/tasks/web/vision/gesture_recognizer/BUILD
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# This contains the MediaPipe Gesture Recognizer Task.
|
||||||
|
#
|
||||||
|
# This task takes video frames and outputs synchronized frames along with
|
||||||
|
# the detection results for one or more gesture categories, using Gesture Recognizer.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "gesture_recognizer",
|
||||||
|
srcs = [
|
||||||
|
"gesture_recognizer.ts",
|
||||||
|
"gesture_recognizer_options.d.ts",
|
||||||
|
"gesture_recognizer_result.d.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
|
"//mediapipe/framework/formats:classification_jspb_proto",
|
||||||
|
"//mediapipe/framework/formats:landmark_jspb_proto",
|
||||||
|
"//mediapipe/framework/formats:rect_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
|
"//mediapipe/tasks/web/components/containers:landmark",
|
||||||
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,374 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {ClassificationList} from '../../../../framework/formats/classification_pb';
|
||||||
|
import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
|
||||||
|
import {NormalizedRect} from '../../../../framework/formats/rect_pb';
|
||||||
|
import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb';
|
||||||
|
import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb';
|
||||||
|
import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb';
|
||||||
|
import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb';
|
||||||
|
import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb';
|
||||||
|
import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
|
||||||
|
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
import {Landmark} from '../../../../tasks/web/components/containers/landmark';
|
||||||
|
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||||
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
|
||||||
|
import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {GestureRecognizerOptions} from './gesture_recognizer_options';
|
||||||
|
import {GestureRecognitionResult} from './gesture_recognizer_result';
|
||||||
|
|
||||||
|
export {ImageSource};
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
const IMAGE_STREAM = 'image_in';
|
||||||
|
const NORM_RECT_STREAM = 'norm_rect';
|
||||||
|
const HAND_GESTURES_STREAM = 'hand_gestures';
|
||||||
|
const LANDMARKS_STREAM = 'hand_landmarks';
|
||||||
|
const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks';
|
||||||
|
const HANDEDNESS_STREAM = 'handedness';
|
||||||
|
const GESTURE_RECOGNIZER_GRAPH =
|
||||||
|
'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
|
||||||
|
|
||||||
|
const DEFAULT_NUM_HANDS = 1;
|
||||||
|
const DEFAULT_SCORE_THRESHOLD = 0.5;
|
||||||
|
const DEFAULT_CATEGORY_INDEX = -1;
|
||||||
|
|
||||||
|
const FULL_IMAGE_RECT = new NormalizedRect();
|
||||||
|
FULL_IMAGE_RECT.setXCenter(0.5);
|
||||||
|
FULL_IMAGE_RECT.setYCenter(0.5);
|
||||||
|
FULL_IMAGE_RECT.setWidth(1);
|
||||||
|
FULL_IMAGE_RECT.setHeight(1);
|
||||||
|
|
||||||
|
/** Performs hand gesture recognition on images. */
|
||||||
|
export class GestureRecognizer extends TaskRunner {
|
||||||
|
private gestures: Category[][] = [];
|
||||||
|
private landmarks: Landmark[][] = [];
|
||||||
|
private worldLandmarks: Landmark[][] = [];
|
||||||
|
private handednesses: Category[][] = [];
|
||||||
|
|
||||||
|
private readonly options: GestureRecognizerGraphOptions;
|
||||||
|
private readonly handLandmarkerGraphOptions: HandLandmarkerGraphOptions;
|
||||||
|
private readonly handLandmarksDetectorGraphOptions:
|
||||||
|
HandLandmarksDetectorGraphOptions;
|
||||||
|
private readonly handDetectorGraphOptions: HandDetectorGraphOptions;
|
||||||
|
private readonly handGestureRecognizerGraphOptions:
|
||||||
|
HandGestureRecognizerGraphOptions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new gesture recognizer from the
|
||||||
|
* provided options.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param gestureRecognizerOptions The options for the gesture recognizer.
|
||||||
|
* Note that either a path to the model asset or a model buffer needs to
|
||||||
|
* be provided (via `baseOptions`).
|
||||||
|
*/
|
||||||
|
static async createFromOptions(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
gestureRecognizerOptions: GestureRecognizerOptions):
|
||||||
|
Promise<GestureRecognizer> {
|
||||||
|
// Create a file locator based on the loader options
|
||||||
|
const fileLocator: FileLocator = {
|
||||||
|
locateFile() {
|
||||||
|
// The only file we load via this mechanism is the Wasm binary
|
||||||
|
return wasmLoaderOptions.wasmBinaryPath.toString();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const recognizer = await createMediaPipeLib(
|
||||||
|
GestureRecognizer, wasmLoaderOptions.wasmLoaderPath,
|
||||||
|
/* assetLoaderScript= */ undefined,
|
||||||
|
/* glCanvas= */ undefined, fileLocator);
|
||||||
|
await recognizer.setOptions(gestureRecognizerOptions);
|
||||||
|
return recognizer;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new gesture recognizer based on
|
||||||
|
* the provided model asset buffer.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetBuffer A binary representation of the model.
|
||||||
|
*/
|
||||||
|
static createFromModelBuffer(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
|
||||||
|
return GestureRecognizer.createFromOptions(
|
||||||
|
wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new gesture recognizer based on
|
||||||
|
* the path to the model asset.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetPath The path to the model asset.
|
||||||
|
*/
|
||||||
|
static async createFromModelPath(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetPath: string): Promise<GestureRecognizer> {
|
||||||
|
const response = await fetch(modelAssetPath.toString());
|
||||||
|
const graphData = await response.arrayBuffer();
|
||||||
|
return GestureRecognizer.createFromModelBuffer(
|
||||||
|
wasmLoaderOptions, new Uint8Array(graphData));
|
||||||
|
}
|
||||||
|
|
||||||
|
constructor(wasmModule: WasmModule) {
|
||||||
|
super(wasmModule);
|
||||||
|
|
||||||
|
this.options = new GestureRecognizerGraphOptions();
|
||||||
|
this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions();
|
||||||
|
this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions);
|
||||||
|
this.handLandmarksDetectorGraphOptions =
|
||||||
|
new HandLandmarksDetectorGraphOptions();
|
||||||
|
this.handLandmarkerGraphOptions.setHandLandmarksDetectorGraphOptions(
|
||||||
|
this.handLandmarksDetectorGraphOptions);
|
||||||
|
this.handDetectorGraphOptions = new HandDetectorGraphOptions();
|
||||||
|
this.handLandmarkerGraphOptions.setHandDetectorGraphOptions(
|
||||||
|
this.handDetectorGraphOptions);
|
||||||
|
this.handGestureRecognizerGraphOptions =
|
||||||
|
new HandGestureRecognizerGraphOptions();
|
||||||
|
this.options.setHandGestureRecognizerGraphOptions(
|
||||||
|
this.handGestureRecognizerGraphOptions);
|
||||||
|
|
||||||
|
this.initDefaults();
|
||||||
|
|
||||||
|
// Disables the automatic render-to-screen code, which allows for pure
|
||||||
|
// CPU processing.
|
||||||
|
this.setAutoRenderToScreen(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets new options for the gesture recognizer.
|
||||||
|
*
|
||||||
|
* Calling `setOptions()` with a subset of options only affects those options.
|
||||||
|
* You can reset an option back to its default value by explicitly setting it
|
||||||
|
* to `undefined`.
|
||||||
|
*
|
||||||
|
* @param options The options for the gesture recognizer.
|
||||||
|
*/
|
||||||
|
async setOptions(options: GestureRecognizerOptions): Promise<void> {
|
||||||
|
if (options.baseOptions) {
|
||||||
|
const baseOptionsProto =
|
||||||
|
await convertBaseOptionsToProto(options.baseOptions);
|
||||||
|
this.options.setBaseOptions(baseOptionsProto);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ('numHands' in options) {
|
||||||
|
this.handDetectorGraphOptions.setNumHands(
|
||||||
|
options.numHands ?? DEFAULT_NUM_HANDS);
|
||||||
|
}
|
||||||
|
if ('minHandDetectionConfidence' in options) {
|
||||||
|
this.handDetectorGraphOptions.setMinDetectionConfidence(
|
||||||
|
options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
||||||
|
}
|
||||||
|
if ('minHandPresenceConfidence' in options) {
|
||||||
|
this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
|
||||||
|
options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
||||||
|
}
|
||||||
|
if ('minTrackingConfidence' in options) {
|
||||||
|
this.handLandmarkerGraphOptions.setMinTrackingConfidence(
|
||||||
|
options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.cannedGesturesClassifierOptions) {
|
||||||
|
// Note that we have to support both JSPB and ProtobufJS and cannot
|
||||||
|
// use JSPB's getMutableX() APIs.
|
||||||
|
const graphOptions = new GestureClassifierGraphOptions();
|
||||||
|
graphOptions.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
|
options.cannedGesturesClassifierOptions,
|
||||||
|
this.handGestureRecognizerGraphOptions
|
||||||
|
.getCannedGestureClassifierGraphOptions()
|
||||||
|
?.getClassifierOptions()));
|
||||||
|
this.handGestureRecognizerGraphOptions
|
||||||
|
.setCannedGestureClassifierGraphOptions(graphOptions);
|
||||||
|
} else if (options.cannedGesturesClassifierOptions === undefined) {
|
||||||
|
this.handGestureRecognizerGraphOptions
|
||||||
|
.getCannedGestureClassifierGraphOptions()
|
||||||
|
?.clearClassifierOptions();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.customGesturesClassifierOptions) {
|
||||||
|
const graphOptions = new GestureClassifierGraphOptions();
|
||||||
|
graphOptions.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
|
options.customGesturesClassifierOptions,
|
||||||
|
this.handGestureRecognizerGraphOptions
|
||||||
|
.getCustomGestureClassifierGraphOptions()
|
||||||
|
?.getClassifierOptions()));
|
||||||
|
this.handGestureRecognizerGraphOptions
|
||||||
|
.setCustomGestureClassifierGraphOptions(graphOptions);
|
||||||
|
} else if (options.customGesturesClassifierOptions === undefined) {
|
||||||
|
this.handGestureRecognizerGraphOptions
|
||||||
|
.getCustomGestureClassifierGraphOptions()
|
||||||
|
?.clearClassifierOptions();
|
||||||
|
}
|
||||||
|
|
||||||
|
this.refreshGraph();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs gesture recognition on the provided single image and waits
|
||||||
|
* synchronously for the response.
|
||||||
|
* @param imageSource An image source to process.
|
||||||
|
* @param timestamp The timestamp of the current frame, in ms. If not
|
||||||
|
* provided, defaults to `performance.now()`.
|
||||||
|
* @return The detected gestures.
|
||||||
|
*/
|
||||||
|
  recognize(imageSource: ImageSource, timestamp: number = performance.now()):
      GestureRecognitionResult {
    // Reset the accumulators; the stream listeners attached in
    // `refreshGraph()` push results into these while the graph runs.
    this.gestures = [];
    this.landmarks = [];
    this.worldLandmarks = [];
    this.handednesses = [];

    // Feed the frame plus a full-image region of interest into the graph,
    // then wait synchronously until all listeners have been invoked.
    this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
    this.addProtoToStream(
        FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
        NORM_RECT_STREAM, timestamp);
    this.finishProcessing();

    // Package the accumulated per-hand results into the public result type.
    return {
      gestures: this.gestures,
      landmarks: this.landmarks,
      worldLandmarks: this.worldLandmarks,
      handednesses: this.handednesses
    };
  }
|
||||||
|
|
||||||
|
/** Sets the default values for the graph. */
|
||||||
|
  /** Sets the default values for the graph. */
  private initDefaults(): void {
    // Apply the shared default constants so a freshly created recognizer
    // behaves like one configured with an empty options object.
    this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS);
    this.handDetectorGraphOptions.setMinDetectionConfidence(
        DEFAULT_SCORE_THRESHOLD);
    this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
        DEFAULT_SCORE_THRESHOLD);
    this.handLandmarkerGraphOptions.setMinTrackingConfidence(
        DEFAULT_SCORE_THRESHOLD);
  }
|
||||||
|
|
||||||
|
/** Converts the proto data to a Category[][] structure. */
|
||||||
|
private toJsCategories(data: Uint8Array[]): Category[][] {
|
||||||
|
const result: Category[][] = [];
|
||||||
|
for (const binaryProto of data) {
|
||||||
|
const inputList = ClassificationList.deserializeBinary(binaryProto);
|
||||||
|
const outputList: Category[] = [];
|
||||||
|
for (const classification of inputList.getClassificationList()) {
|
||||||
|
outputList.push({
|
||||||
|
score: classification.getScore() ?? 0,
|
||||||
|
index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX,
|
||||||
|
categoryName: classification.getLabel() ?? '',
|
||||||
|
displayName: classification.getDisplayName() ?? '',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
result.push(outputList);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Converts raw data into a landmark, and adds it to our landmarks list. */
|
||||||
|
private addJsLandmarks(data: Uint8Array[]): void {
|
||||||
|
for (const binaryProto of data) {
|
||||||
|
const handLandmarksProto =
|
||||||
|
NormalizedLandmarkList.deserializeBinary(binaryProto);
|
||||||
|
const landmarks: Landmark[] = [];
|
||||||
|
for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) {
|
||||||
|
landmarks.push({
|
||||||
|
x: handLandmarkProto.getX() ?? 0,
|
||||||
|
y: handLandmarkProto.getY() ?? 0,
|
||||||
|
z: handLandmarkProto.getZ() ?? 0,
|
||||||
|
normalized: true
|
||||||
|
});
|
||||||
|
}
|
||||||
|
this.landmarks.push(landmarks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts raw data into a landmark, and adds it to our worldLandmarks
|
||||||
|
* list.
|
||||||
|
*/
|
||||||
|
private adddJsWorldLandmarks(data: Uint8Array[]): void {
|
||||||
|
for (const binaryProto of data) {
|
||||||
|
const handWorldLandmarksProto =
|
||||||
|
LandmarkList.deserializeBinary(binaryProto);
|
||||||
|
const worldLandmarks: Landmark[] = [];
|
||||||
|
for (const handWorldLandmarkProto of
|
||||||
|
handWorldLandmarksProto.getLandmarkList()) {
|
||||||
|
worldLandmarks.push({
|
||||||
|
x: handWorldLandmarkProto.getX() ?? 0,
|
||||||
|
y: handWorldLandmarkProto.getY() ?? 0,
|
||||||
|
z: handWorldLandmarkProto.getZ() ?? 0,
|
||||||
|
normalized: false
|
||||||
|
});
|
||||||
|
}
|
||||||
|
this.worldLandmarks.push(worldLandmarks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
  private refreshGraph(): void {
    // Declare the graph-level input/output streams.
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(IMAGE_STREAM);
    graphConfig.addInputStream(NORM_RECT_STREAM);
    graphConfig.addOutputStream(HAND_GESTURES_STREAM);
    graphConfig.addOutputStream(LANDMARKS_STREAM);
    graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM);
    graphConfig.addOutputStream(HANDEDNESS_STREAM);

    // Attach the accumulated task options as a calculator-options extension.
    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        GestureRecognizerGraphOptions.ext, this.options);

    // Single-node graph: wire every stream to the corresponding tagged
    // port of the GestureRecognizerGraph calculator.
    const recognizerNode = new CalculatorGraphConfig.Node();
    recognizerNode.setCalculator(GESTURE_RECOGNIZER_GRAPH);
    recognizerNode.addInputStream('IMAGE:' + IMAGE_STREAM);
    recognizerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
    recognizerNode.addOutputStream('HAND_GESTURES:' + HAND_GESTURES_STREAM);
    recognizerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM);
    recognizerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM);
    recognizerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM);
    recognizerNode.setOptions(calculatorOptions);

    graphConfig.addNode(recognizerNode);

    // Listeners convert the serialized protos emitted on each output
    // stream and accumulate them into the per-call result fields that
    // `recognize()` resets and returns.
    this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
      this.addJsLandmarks(binaryProto);
    });
    this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
      this.adddJsWorldLandmarks(binaryProto);
    });
    this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => {
      this.gestures.push(...this.toJsCategories(binaryProto));
    });
    this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
      this.handednesses.push(...this.toJsCategories(binaryProto));
    });

    // Install the serialized graph config into the Wasm runtime.
    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
65
mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts
vendored
Normal file
65
mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts
vendored
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||||
|
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
||||||
|
|
||||||
|
/** Options to configure the MediaPipe Gesture Recognizer Task */
|
||||||
|
export interface GestureRecognizerOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: BaseOptions;

  /**
   * The maximum number of hands that can be detected by the
   * GestureRecognizer. Defaults to 1.
   */
  numHands?: number|undefined;

  /**
   * The minimum confidence score for the hand detection to be considered
   * successful. Defaults to 0.5.
   */
  minHandDetectionConfidence?: number|undefined;

  /**
   * The minimum confidence score of hand presence score in the hand landmark
   * detection. Defaults to 0.5.
   */
  minHandPresenceConfidence?: number|undefined;

  /**
   * The minimum confidence score for the hand tracking to be considered
   * successful. Defaults to 0.5.
   */
  minTrackingConfidence?: number|undefined;

  /**
   * Sets the optional `ClassifierOptions` controlling the canned gestures
   * classifier, such as score threshold, allow list and deny list of gestures.
   * The categories for canned gesture
   * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up",
   * "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"]
   */
  // TODO: Note this option is subject to change
  cannedGesturesClassifierOptions?: ClassifierOptions|undefined;

  /**
   * Options for configuring the custom gestures classifier, such as score
   * threshold, allow list and deny list of gestures.
   */
  // TODO(b/251816640): Note this option is subject to change.
  customGesturesClassifierOptions?: ClassifierOptions|undefined;
}
|
35
mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
vendored
Normal file
35
mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
vendored
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
import {Landmark} from '../../../../tasks/web/components/containers/landmark';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the gesture recognition results generated by `GestureRecognizer`.
|
||||||
|
*/
|
||||||
|
export interface GestureRecognitionResult {
  /** Hand landmarks of detected hands. */
  landmarks: Landmark[][];

  /** Hand landmarks in world coordinates of detected hands. */
  worldLandmarks: Landmark[][];

  /** Handedness of detected hands. */
  handednesses: Category[][];

  /** Recognized hand gestures of detected hands. */
  gestures: Category[][];
}
|
33
mediapipe/tasks/web/vision/image_classifier/BUILD
Normal file
33
mediapipe/tasks/web/vision/image_classifier/BUILD
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
# This contains the MediaPipe Image Classifier Task.
|
||||||
|
#
|
||||||
|
# This task takes video or image frames and outputs the classification result.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TypeScript library for the MediaPipe Image Classifier web task (public
# API, options, and result types).
mediapipe_ts_library(
    name = "image_classifier",
    srcs = [
        "image_classifier.ts",
        "image_classifier_options.ts",
        "image_classifier_result.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:classifier_options",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)
|
186
mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
Normal file
186
mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb';
|
||||||
|
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||||
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
|
||||||
|
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {ImageClassifierOptions} from './image_classifier_options';
|
||||||
|
import {Classifications} from './image_classifier_result';
|
||||||
|
|
||||||
|
const IMAGE_CLASSIFIER_GRAPH =
|
||||||
|
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
|
||||||
|
const INPUT_STREAM = 'input_image';
|
||||||
|
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
|
||||||
|
|
||||||
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/** Performs classification on images. */
|
||||||
|
export class ImageClassifier extends TaskRunner {
  // Accumulator that the CLASSIFICATION_RESULT_STREAM listener (attached in
  // `refreshGraph()`) pushes into during a `classify()` call.
  private classifications: Classifications[] = [];
  // Task configuration, serialized into the graph on every `refreshGraph()`.
  private readonly options = new ImageClassifierGraphOptions();

  /**
   * Initializes the Wasm runtime and creates a new image classifier from the
   * provided options.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param imageClassifierOptions The options for the image classifier. Note
   *     that either a path to the model asset or a model buffer needs to be
   *     provided (via `baseOptions`).
   */
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      imageClassifierOptions: ImageClassifierOptions):
      Promise<ImageClassifier> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    const classifier = await createMediaPipeLib(
        ImageClassifier, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await classifier.setOptions(imageClassifierOptions);
    return classifier;
  }

  /**
   * Initializes the Wasm runtime and creates a new image classifier based on
   * the provided model asset buffer.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetBuffer A binary representation of the model.
   */
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
    return ImageClassifier.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }

  /**
   * Initializes the Wasm runtime and creates a new image classifier based on
   * the path to the model asset.
   * @param wasmLoaderOptions A configuration object that provides the location
   *     of the Wasm binary and its loader.
   * @param modelAssetPath The path to the model asset.
   */
  static async createFromModelPath(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetPath: string): Promise<ImageClassifier> {
    // Fetch the model and delegate to the buffer-based factory.
    const response = await fetch(modelAssetPath.toString());
    const graphData = await response.arrayBuffer();
    return ImageClassifier.createFromModelBuffer(
        wasmLoaderOptions, new Uint8Array(graphData));
  }

  /**
   * Sets new options for the image classifier.
   *
   * Calling `setOptions()` with a subset of options only affects those options.
   * You can reset an option back to its default value by explicitly setting it
   * to `undefined`.
   *
   * @param options The options for the image classifier.
   */
  async setOptions(options: ImageClassifierOptions): Promise<void> {
    if (options.baseOptions) {
      const baseOptionsProto =
          await convertBaseOptionsToProto(options.baseOptions);
      this.options.setBaseOptions(baseOptionsProto);
    }

    // Merge the classifier options (score threshold, allow/deny lists, ...)
    // into the existing proto, then re-serialize the graph.
    this.options.setClassifierOptions(convertClassifierOptionsToProto(
        options, this.options.getClassifierOptions()));
    this.refreshGraph();
  }

  /**
   * Performs image classification on the provided image and waits synchronously
   * for the response.
   *
   * @param imageSource An image source to process.
   * @param timestamp The timestamp of the current frame, in ms. If not
   *     provided, defaults to `performance.now()`.
   * @return The classification result of the image
   */
  classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
    this.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
    this.finishProcessing();
    // Return a copy so later calls cannot mutate the caller's result.
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   **/
  private addJsImageClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    // Declare the graph-level input/output streams.
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);

    // Attach the accumulated task options as a calculator-options extension.
    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        ImageClassifierGraphOptions.ext, this.options);

    // Perform image classification. Pre-processing and results post-processing
    // are built-in.
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    // Accumulate results into `this.classifications` for `classify()`.
    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsImageClassification(binaryProto);
    });

    // Install the serialized graph config into the Wasm runtime.
    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
|
@ -0,0 +1,18 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
30
mediapipe/tasks/web/vision/object_detector/BUILD
Normal file
30
mediapipe/tasks/web/vision/object_detector/BUILD
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# This contains the MediaPipe Object Detector Task.
|
||||||
|
#
|
||||||
|
# This task takes video frames and outputs synchronized frames along with
|
||||||
|
# the detection results for one or more object categories, using Object Detector.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TypeScript library for the MediaPipe Object Detector web task (public
# API, options, and result types).
mediapipe_ts_library(
    name = "object_detector",
    srcs = [
        "object_detector.ts",
        "object_detector_options.d.ts",
        "object_detector_result.d.ts",
    ],
    deps = [
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework:calculator_options_jspb_proto",
        "//mediapipe/framework/formats:detection_jspb_proto",
        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
    ],
)
|
233
mediapipe/tasks/web/vision/object_detector/object_detector.ts
Normal file
233
mediapipe/tasks/web/vision/object_detector/object_detector.ts
Normal file
|
@ -0,0 +1,233 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
|
||||||
|
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
|
||||||
|
import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
|
||||||
|
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
|
||||||
|
import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {ObjectDetectorOptions} from './object_detector_options';
|
||||||
|
import {Detection} from './object_detector_result';
|
||||||
|
|
||||||
|
const INPUT_STREAM = 'input_frame_gpu';
|
||||||
|
const DETECTIONS_STREAM = 'detections';
|
||||||
|
const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph';
|
||||||
|
|
||||||
|
const DEFAULT_CATEGORY_INDEX = -1;
|
||||||
|
|
||||||
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/** Performs object detection on images. */
|
||||||
|
export class ObjectDetector extends TaskRunner {
|
||||||
|
private detections: Detection[] = [];
|
||||||
|
private readonly options = new ObjectDetectorOptionsProto();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new object detector from the
|
||||||
|
* provided options.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param objectDetectorOptions The options for the Object Detector. Note that
|
||||||
|
* either a path to the model asset or a model buffer needs to be
|
||||||
|
* provided (via `baseOptions`).
|
||||||
|
*/
|
||||||
|
  static async createFromOptions(
      wasmLoaderOptions: WasmLoaderOptions,
      objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
    // Create a file locator based on the loader options
    const fileLocator: FileLocator = {
      locateFile() {
        // The only file we load is the Wasm binary
        return wasmLoaderOptions.wasmBinaryPath.toString();
      }
    };

    // Instantiate the Wasm runtime with this class as the graph wrapper,
    // then apply the user-supplied options before returning.
    const detector = await createMediaPipeLib(
        ObjectDetector, wasmLoaderOptions.wasmLoaderPath,
        /* assetLoaderScript= */ undefined,
        /* glCanvas= */ undefined, fileLocator);
    await detector.setOptions(objectDetectorOptions);
    return detector;
  }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new object detector based on the
|
||||||
|
* provided model asset buffer.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetBuffer A binary representation of the model.
|
||||||
|
*/
|
||||||
|
  static createFromModelBuffer(
      wasmLoaderOptions: WasmLoaderOptions,
      modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
    // Convenience wrapper: wrap the buffer in `baseOptions` and delegate to
    // the options-based factory.
    return ObjectDetector.createFromOptions(
        wasmLoaderOptions, {baseOptions: {modelAssetBuffer}});
  }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new object detector based on the
|
||||||
|
* path to the model asset.
|
||||||
|
* @param wasmLoaderOptions A configuration object that provides the location
|
||||||
|
* of the Wasm binary and its loader.
|
||||||
|
* @param modelAssetPath The path to the model asset.
|
||||||
|
*/
|
||||||
|
static async createFromModelPath(
|
||||||
|
wasmLoaderOptions: WasmLoaderOptions,
|
||||||
|
modelAssetPath: string): Promise<ObjectDetector> {
|
||||||
|
const response = await fetch(modelAssetPath.toString());
|
||||||
|
const graphData = await response.arrayBuffer();
|
||||||
|
return ObjectDetector.createFromModelBuffer(
|
||||||
|
wasmLoaderOptions, new Uint8Array(graphData));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Sets new options for the object detector.
 *
 * Calling `setOptions()` with a subset of options only affects those options.
 * You can reset an option back to its default value by explicitly setting it
 * to `undefined`.
 *
 * @param options The options for the object detector.
 */
async setOptions(options: ObjectDetectorOptions): Promise<void> {
  // Base options (e.g. the model asset) may require async conversion, which
  // is why this method is async.
  if (options.baseOptions) {
    const baseOptionsProto =
        await convertBaseOptionsToProto(options.baseOptions);
    this.options.setBaseOptions(baseOptionsProto);
  }

  // Note that we have to support both JSPB and ProtobufJS, hence we
  // have to explicitly clear the values instead of setting them to
  // `undefined`.
  // For each option: a defined value overwrites the proto field; an
  // explicitly-set `undefined` (detected via the `in` operator) clears the
  // field; an absent key leaves the field untouched.
  if (options.displayNamesLocale !== undefined) {
    this.options.setDisplayNamesLocale(options.displayNamesLocale);
  } else if ('displayNamesLocale' in options) {  // Check for undefined
    this.options.clearDisplayNamesLocale();
  }

  if (options.maxResults !== undefined) {
    this.options.setMaxResults(options.maxResults);
  } else if ('maxResults' in options) {  // Check for undefined
    this.options.clearMaxResults();
  }

  if (options.scoreThreshold !== undefined) {
    this.options.setScoreThreshold(options.scoreThreshold);
  } else if ('scoreThreshold' in options) {  // Check for undefined
    this.options.clearScoreThreshold();
  }

  if (options.categoryAllowlist !== undefined) {
    this.options.setCategoryAllowlistList(options.categoryAllowlist);
  } else if ('categoryAllowlist' in options) {  // Check for undefined
    this.options.clearCategoryAllowlistList();
  }

  if (options.categoryDenylist !== undefined) {
    this.options.setCategoryDenylistList(options.categoryDenylist);
  } else if ('categoryDenylist' in options) {  // Check for undefined
    this.options.clearCategoryDenylistList();
  }

  // Rebuild the MediaPipe graph so the updated options take effect.
  this.refreshGraph();
}
|
||||||
|
|
||||||
|
/**
 * Performs object detection on the provided single image and waits
 * synchronously for the response.
 * @param imageSource An image source to process.
 * @param timestamp The timestamp of the current frame, in ms. If not
 *     provided, defaults to `performance.now()`.
 * @return The list of detected objects
 */
detect(imageSource: ImageSource, timestamp?: number): Detection[] {
  // Reset the accumulator; the proto-vector listener installed by
  // refreshGraph() appends results into `this.detections` while the graph
  // runs.
  this.detections = [];
  const frameTimestamp = timestamp ?? performance.now();
  this.addGpuBufferAsImageToStream(imageSource, INPUT_STREAM, frameTimestamp);
  this.finishProcessing();
  // Hand back a defensive copy so callers cannot mutate internal state.
  return this.detections.slice();
}
|
||||||
|
|
||||||
|
/** Converts raw data into a Detection, and adds it to our detection list. */
|
||||||
|
private addJsObjectDetections(data: Uint8Array[]): void {
|
||||||
|
for (const binaryProto of data) {
|
||||||
|
const detectionProto = DetectionProto.deserializeBinary(binaryProto);
|
||||||
|
const scores = detectionProto.getScoreList();
|
||||||
|
const indexes = detectionProto.getLabelIdList();
|
||||||
|
const labels = detectionProto.getLabelList();
|
||||||
|
const displayNames = detectionProto.getDisplayNameList();
|
||||||
|
|
||||||
|
const detection: Detection = {categories: []};
|
||||||
|
for (let i = 0; i < scores.length; i++) {
|
||||||
|
detection.categories.push({
|
||||||
|
score: scores[i],
|
||||||
|
index: indexes[i] ?? DEFAULT_CATEGORY_INDEX,
|
||||||
|
categoryName: labels[i] ?? '',
|
||||||
|
displayName: displayNames[i] ?? '',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const boundingBox = detectionProto.getLocationData()?.getBoundingBox();
|
||||||
|
if (boundingBox) {
|
||||||
|
detection.boundingBox = {
|
||||||
|
originX: boundingBox.getXmin() ?? 0,
|
||||||
|
originY: boundingBox.getYmin() ?? 0,
|
||||||
|
width: boundingBox.getWidth() ?? 0,
|
||||||
|
height: boundingBox.getHeight() ?? 0
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
this.detections.push(detection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
private refreshGraph(): void {
|
||||||
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
|
graphConfig.addOutputStream(DETECTIONS_STREAM);
|
||||||
|
|
||||||
|
const calculatorOptions = new CalculatorOptions();
|
||||||
|
calculatorOptions.setExtension(
|
||||||
|
ObjectDetectorOptionsProto.ext, this.options);
|
||||||
|
|
||||||
|
const detectorNode = new CalculatorGraphConfig.Node();
|
||||||
|
detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH);
|
||||||
|
detectorNode.addInputStream('IMAGE:' + INPUT_STREAM);
|
||||||
|
detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
|
||||||
|
detectorNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
|
graphConfig.addNode(detectorNode);
|
||||||
|
|
||||||
|
this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
|
||||||
|
this.addJsObjectDetections(binaryProto);
|
||||||
|
});
|
||||||
|
|
||||||
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
52
mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts
vendored
Normal file
52
mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts
vendored
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {BaseOptions} from '../../../../tasks/web/core/base_options';
|
||||||
|
|
||||||
|
/** Options to configure the MediaPipe Object Detector Task */
export interface ObjectDetectorOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: BaseOptions;

  /**
   * The locale to use for display names specified through the TFLite Model
   * Metadata, if any. Defaults to English.
   */
  displayNamesLocale?: string|undefined;

  /** The maximum number of top-scored detection results to return. */
  maxResults?: number|undefined;

  /**
   * Overrides the value provided in the model metadata. Results below this
   * value are rejected.
   */
  scoreThreshold?: number|undefined;

  /**
   * Allowlist of category names. If non-empty, detection results whose category
   * name is not in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryDenylist`.
   */
  categoryAllowlist?: string[]|undefined;

  /**
   * Denylist of category names. If non-empty, detection results whose category
   * name is in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryAllowlist`.
   */
  categoryDenylist?: string[]|undefined;
}
|
38
mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
vendored
Normal file
38
mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
vendored
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
|
||||||
|
/** An integer bounding box, axis aligned. */
export interface BoundingBox {
  /** The X coordinate of the top-left corner, in pixels. */
  originX: number;
  /** The Y coordinate of the top-left corner, in pixels. */
  originY: number;
  /** The width of the bounding box, in pixels. */
  width: number;
  /** The height of the bounding box, in pixels. */
  height: number;
}
|
||||||
|
|
||||||
|
/** Represents one object detected by the `ObjectDetector`. */
export interface Detection {
  /** A list of `Category` objects, one per classification of the object. */
  categories: Category[];

  /** The bounding box of the detected object, if reported by the model. */
  boundingBox?: BoundingBox;
}
|
41
mediapipe/web/graph_runner/BUILD
Normal file
41
mediapipe/web/graph_runner/BUILD
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
# The TypeScript graph runner used by all MediaPipe Web tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = [
    ":internal",
    "//mediapipe/tasks:internal",
])

package_group(
    name = "internal",
    packages = [
        "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...",
    ],
)

# Core graph runner: loads the Wasm module and drives MediaPipe graphs.
mediapipe_ts_library(
    name = "wasm_mediapipe_lib_ts",
    srcs = [
        ":wasm_mediapipe_lib.ts",
    ],
    allow_unoptimized_namespaces = True,
)

# Mixin adding support for binding GPU image data as `mediapipe::Image`.
mediapipe_ts_library(
    name = "wasm_mediapipe_image_lib_ts",
    srcs = [
        ":wasm_mediapipe_image_lib.ts",
    ],
    allow_unoptimized_namespaces = True,
    deps = [":wasm_mediapipe_lib_ts"],
)

# Mixin enabling the model resources graph service on the graph runner.
mediapipe_ts_library(
    name = "register_model_resources_graph_service_ts",
    srcs = [
        ":register_model_resources_graph_service.ts",
    ],
    allow_unoptimized_namespaces = True,
    deps = [":wasm_mediapipe_lib_ts"],
)
|
|
@ -0,0 +1,41 @@
|
||||||
|
import {WasmMediaPipeLib} from './wasm_mediapipe_lib';

/**
 * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
 * access to the wasmModule, among other things. The `any` type is required for
 * mixin constructors.
 */
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;

/**
 * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
 * doesn't break our JS/C++ bridge.
 */
export declare interface WasmModuleRegisterModelResources {
  // C++-exported entry point; presumably registers the model resources graph
  // service with the underlying MediaPipe runtime — see the C++ bindings.
  _registerModelResourcesGraphService: () => void;
}

/**
 * An implementation of WasmMediaPipeLib that supports registering model
 * resources to a cache, in the form of a GraphService C++-side. We implement as
 * a proper TS mixin, to allow for effective multiple inheritance. Sample usage:
 * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
 *     WasmMediaPipeLib);`
 */
// tslint:disable:enforce-name-casing
export function SupportModelResourcesGraphService<TBase extends LibConstructor>(
    Base: TBase) {
  return class extends Base {
    // tslint:enable:enforce-name-casing
    /**
     * Instructs the graph runner to use the model resource caching graph
     * service for both graph expansion/initialization, as well as for graph
     * run.
     */
    registerModelResourcesGraphService(): void {
      // The Emscripten module type does not declare this export, so cast to
      // the bridge interface declared above.
      (this.wasmModule as unknown as WasmModuleRegisterModelResources)
          ._registerModelResourcesGraphService();
    }
  };
}
|
52
mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts
Normal file
52
mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib';

/**
 * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
 * access to the wasmModule, among other things. The `any` type is required for
 * mixin constructors.
 */
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;

/**
 * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
 * doesn't break our JS/C++ bridge.
 */
export declare interface WasmImageModule {
  // C++-exported entry point that packages the currently bound texture as a
  // `mediapipe::Image` packet on the named stream at the given timestamp.
  _addBoundTextureAsImageToStream:
      (streamNamePtr: number, width: number, height: number,
       timestamp: number) => void;
}

/**
 * An implementation of WasmMediaPipeLib that supports binding GPU image data as
 * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for
 * effective multiple inheritance. Example usage:
 * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);`
 */
// tslint:disable-next-line:enforce-name-casing
export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
  return class extends Base {
    /**
     * Takes the relevant information from the HTML video or image element, and
     * passes it into the WebGL-based graph for processing on the given stream
     * at the given timestamp as a MediaPipe image. Processing will not occur
     * until a blocking call (like processVideoGl or finishProcessing) is made.
     * @param imageSource Reference to the video frame we wish to add into our
     *     graph.
     * @param streamName The name of the MediaPipe graph stream to add the frame
     *     to.
     * @param timestamp The timestamp of the input frame, in ms.
     */
    addGpuBufferAsImageToStream(
        imageSource: ImageSource, streamName: string, timestamp: number): void {
      // The stream name must be marshalled to a C string for the wasm call;
      // wrapStringPtr manages that pointer's lifetime around the callback.
      this.wrapStringPtr(streamName, (streamNamePtr: number) => {
        // Binding the texture uploads the frame and yields its dimensions.
        const [width, height] =
            this.bindTextureToStream(imageSource, streamNamePtr);
        (this.wasmModule as unknown as WasmImageModule)
            ._addBoundTextureAsImageToStream(
                streamNamePtr, width, height, timestamp);
      });
    }
  };
}
|
1044
mediapipe/web/graph_runner/wasm_mediapipe_lib.ts
Normal file
1044
mediapipe/web/graph_runner/wasm_mediapipe_lib.ts
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user