Merge branch 'google:master' into text-classifier-python

Commit 740d2e47b5
Author: Kinar R, 2022-11-05 11:03:24 +05:30; committed by GitHub
Signature: GPG key ID 4AEE18F83AFDEB23 (no known key found for this signature in database)

99 changed files with 3589 additions and 843 deletions

.gitignore (vendored)

@@ -2,5 +2,6 @@ bazel-*
 mediapipe/MediaPipe.xcodeproj
 mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user
 mediapipe/provisioning_profile.mobileprovision
+node_modules/
 .configure.bazelrc
 .user.bazelrc

@@ -1,4 +1,4 @@
-# Copyright 2019 The MediaPipe Authors.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,4 +14,9 @@
 licenses(["notice"])

-exports_files(["LICENSE"])
+exports_files([
+    "LICENSE",
+    "tsconfig.json",
+    "package.json",
+    "yarn.lock",
+])

@@ -501,5 +501,48 @@ libedgetpu_dependencies()
 load("@coral_crosstool//:configure.bzl", "cc_crosstool")
 cc_crosstool(name = "crosstool")

+# Node dependencies
+http_archive(
+    name = "build_bazel_rules_nodejs",
+    sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
+    urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
+)
+
+load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
+build_bazel_rules_nodejs_dependencies()
+
+# fetches nodejs, npm, and yarn
+load("@build_bazel_rules_nodejs//:index.bzl", "node_repositories", "yarn_install")
+node_repositories()
+yarn_install(
+    name = "npm",
+    package_json = "//:package.json",
+    yarn_lock = "//:yarn.lock",
+)
+
+# Protobuf for Node dependencies
+http_archive(
+    name = "rules_proto_grpc",
+    sha256 = "bbe4db93499f5c9414926e46f9e35016999a4e9f6e3522482d3760dc61011070",
+    strip_prefix = "rules_proto_grpc-4.2.0",
+    urls = ["https://github.com/rules-proto-grpc/rules_proto_grpc/archive/4.2.0.tar.gz"],
+)
+
+http_archive(
+    name = "com_google_protobuf_javascript",
+    sha256 = "35bca1729532b0a77280bf28ab5937438e3dcccd6b31a282d9ae84c896b6f6e3",
+    strip_prefix = "protobuf-javascript-3.21.2",
+    urls = ["https://github.com/protocolbuffers/protobuf-javascript/archive/refs/tags/v3.21.2.tar.gz"],
+)
+
+load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains", "rules_proto_grpc_repos")
+rules_proto_grpc_toolchains()
+rules_proto_grpc_repos()
+
+load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
+rules_proto_dependencies()
+rules_proto_toolchains()
+
 load("//third_party:external_files.bzl", "external_files")
 external_files()

@@ -141,3 +141,4 @@ Nvidia Jetson and Raspberry Pi, please read
    ```bash
    (mp_env)mediapipe$ python3 setup.py bdist_wheel
    ```
+7. Exit from the MediaPipe repo directory and launch the Python interpreter.
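Reviewer note (illustration only, not part of this diff): a quick sanity check for the new step 7, run outside the repo checkout so the local source tree does not shadow the installed wheel:

```python
# Run from any directory other than the mediapipe checkout.
import mediapipe as mp

print(mp.__version__)  # confirms the locally built wheel is importable
```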

@@ -936,6 +936,7 @@ cc_test(
         "//mediapipe/framework/tool:simulation_clock",
         "//mediapipe/framework/tool:simulation_clock_executor",
         "//mediapipe/framework/tool:sink",
+        "//mediapipe/util:packet_test_util",
         "@com_google_absl//absl/time",
     ],
 )

@@ -18,7 +18,6 @@
 #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
-#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/util/header_util.h"
@@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS";
 // FlowLimiterCalculator provides limited support for multiple input streams.
 // The first input stream is treated as the main input stream and successive
 // input streams are treated as auxiliary input streams. The auxiliary input
-// streams are limited to timestamps passed on the main input stream.
+// streams are limited to timestamps allowed by the "ALLOW" stream.
 //
 class FlowLimiterCalculator : public CalculatorBase {
  public:
@@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase {
           cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
     }
     input_queues_.resize(cc->Inputs().NumEntries(""));
+    allowed_[Timestamp::Unset()] = true;
     RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
     return absl::OkStatus();
   }

-  // Returns true if an additional frame can be released for processing.
-  // The "ALLOW" output stream indicates this condition at each input frame.
-  bool ProcessingAllowed() {
-    return frames_in_flight_.size() < options_.max_in_flight();
-  }
-
-  // Outputs a packet indicating whether a frame was sent or dropped.
-  void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
-    if (cc->Outputs().HasTag(kAllowTag)) {
-      cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
-    }
-  }
-
-  // Sets the timestamp bound or closes an output stream.
-  void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
-    if (bound > Timestamp::Max()) {
-      stream->Close();
-    } else {
-      stream->SetNextTimestampBound(bound);
-    }
-  }
-
-  // Returns true if a certain timestamp is being processed.
-  bool IsInFlight(Timestamp timestamp) {
-    return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
-                     timestamp) != frames_in_flight_.end();
-  }
-
-  // Releases input packets up to the latest settled input timestamp.
-  void ProcessAuxiliaryInputs(CalculatorContext* cc) {
-    Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
-    for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
-      // Release settled frames from each input queue.
-      while (!input_queues_[i].empty() &&
-             input_queues_[i].front().Timestamp() < settled_bound) {
-        Packet packet = input_queues_[i].front();
-        input_queues_[i].pop_front();
-        if (IsInFlight(packet.Timestamp())) {
-          cc->Outputs().Get("", i).AddPacket(packet);
-        }
-      }
-      // Propagate each input timestamp bound.
-      if (!input_queues_[i].empty()) {
-        Timestamp bound = input_queues_[i].front().Timestamp();
-        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
-      } else {
-        Timestamp bound =
-            cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
-        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
-      }
-    }
-  }
-
   // Releases input packets allowed by the max_in_flight constraint.
   absl::Status Process(CalculatorContext* cc) final {
     options_ = tool::RetrieveOptions(options_, cc->Inputs());
@@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase {
     }
     ProcessAuxiliaryInputs(cc);
+
+    // Discard old ALLOW ranges.
+    Timestamp input_bound = InputTimestampBound(cc);
+    auto first_range = std::prev(allowed_.upper_bound(input_bound));
+    allowed_.erase(allowed_.begin(), first_range);
     return absl::OkStatus();
   }

+  int LedgerSize() {
+    int result = frames_in_flight_.size() + allowed_.size();
+    for (const auto& queue : input_queues_) {
+      result += queue.size();
+    }
+    return result;
+  }
+
+ private:
+  // Returns true if an additional frame can be released for processing.
+  // The "ALLOW" output stream indicates this condition at each input frame.
+  bool ProcessingAllowed() {
+    return frames_in_flight_.size() < options_.max_in_flight();
+  }
+
+  // Outputs a packet indicating whether a frame was sent or dropped.
+  void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
+    if (cc->Outputs().HasTag(kAllowTag)) {
+      cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
+    }
+    allowed_[ts] = allow;
+  }
+
+  // Returns true if a timestamp falls within a range of allowed timestamps.
+  bool IsAllowed(Timestamp timestamp) {
+    auto it = allowed_.upper_bound(timestamp);
+    return std::prev(it)->second;
+  }
+
+  // Sets the timestamp bound or closes an output stream.
+  void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
+    if (bound > Timestamp::Max()) {
+      stream->Close();
+    } else {
+      stream->SetNextTimestampBound(bound);
+    }
+  }
+
+  // Returns the lowest unprocessed input Timestamp.
+  Timestamp InputTimestampBound(CalculatorContext* cc) {
+    Timestamp result = Timestamp::Done();
+    for (int i = 0; i < input_queues_.size(); ++i) {
+      auto& queue = input_queues_[i];
+      auto& stream = cc->Inputs().Get("", i);
+      Timestamp bound = queue.empty()
+                            ? stream.Value().Timestamp().NextAllowedInStream()
+                            : queue.front().Timestamp();
+      result = std::min(result, bound);
+    }
+    return result;
+  }
+
+  // Releases input packets up to the latest settled input timestamp.
+  void ProcessAuxiliaryInputs(CalculatorContext* cc) {
+    Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
+    for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
+      // Release settled frames from each input queue.
+      while (!input_queues_[i].empty() &&
+             input_queues_[i].front().Timestamp() < settled_bound) {
+        Packet packet = input_queues_[i].front();
+        input_queues_[i].pop_front();
+        if (IsAllowed(packet.Timestamp())) {
+          cc->Outputs().Get("", i).AddPacket(packet);
+        }
+      }
+      // Propagate each input timestamp bound.
+      if (!input_queues_[i].empty()) {
+        Timestamp bound = input_queues_[i].front().Timestamp();
+        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
+      } else {
+        Timestamp bound =
+            cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
+        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
+      }
+    }
+  }
+
  private:
   FlowLimiterCalculatorOptions options_;
   std::vector<std::deque<Packet>> input_queues_;
   std::deque<Timestamp> frames_in_flight_;
+  std::map<Timestamp, bool> allowed_;
 };
 REGISTER_CALCULATOR(FlowLimiterCalculator);
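Reviewer note (illustration only, not part of this diff): the new `allowed_` map is a range ledger. Each entry marks the ALLOW decision for every timestamp from its key up to the next key, and `IsAllowed()` is just `std::prev(allowed_.upper_bound(ts))->second`. A minimal Python sketch of that lookup, with plain integers standing in for `Timestamp`:

```python
import bisect

# Each (timestamp, allow) entry covers [timestamp, next_timestamp).
# The first entry plays the role of allowed_[Timestamp::Unset()] = true,
# guarding lookups below the first recorded decision.
ledger = [(-1, True), (0, True), (10000, False), (20000, True)]
keys = [k for k, _ in ledger]  # kept sorted, like std::map

def is_allowed(ts: int) -> bool:
    """std::prev(allowed_.upper_bound(ts))->second, in bisect terms."""
    i = bisect.bisect_right(keys, ts) - 1  # upper_bound, then step back
    return ledger[i][1]

assert is_allowed(15000) is False  # inside the [10000, 20000) "drop" range
assert is_allowed(20000) is True
```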

@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>

 #include "absl/time/clock.h"
@@ -32,6 +33,7 @@
 #include "mediapipe/framework/tool/simulation_clock.h"
 #include "mediapipe/framework/tool/simulation_clock_executor.h"
 #include "mediapipe/framework/tool/sink.h"
+#include "mediapipe/util/packet_test_util.h"

 namespace mediapipe {
@@ -77,6 +79,77 @@ std::vector<T> PacketValues(const std::vector<Packet>& packets) {
   return result;
 }

+template <typename T>
+std::vector<Packet> MakePackets(std::vector<std::pair<Timestamp, T>> contents) {
+  std::vector<Packet> result;
+  for (auto& entry : contents) {
+    result.push_back(MakePacket<T>(entry.second).At(entry.first));
+  }
+  return result;
+}
+
+std::string SourceString(Timestamp t) {
+  return (t.IsSpecialValue())
+             ? t.DebugString()
+             : absl::StrCat("Timestamp(", t.DebugString(), ")");
+}
+
+template <typename PacketContainer, typename PacketContent>
+class PacketsEqMatcher
+    : public ::testing::MatcherInterface<const PacketContainer&> {
+ public:
+  PacketsEqMatcher(PacketContainer packets) : packets_(packets) {}
+  void DescribeTo(::std::ostream* os) const override {
+    *os << "The expected packet contents: \n";
+    Print(packets_, os);
+  }
+  bool MatchAndExplain(
+      const PacketContainer& value,
+      ::testing::MatchResultListener* listener) const override {
+    if (!Equals(packets_, value)) {
+      if (listener->IsInterested()) {
+        *listener << "The actual packet contents: \n";
+        Print(value, listener->stream());
+      }
+      return false;
+    }
+    return true;
+  }
+
+ private:
+  bool Equals(const PacketContainer& c1, const PacketContainer& c2) const {
+    if (c1.size() != c2.size()) {
+      return false;
+    }
+    for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) {
+      Packet p1 = *i1, p2 = *i2;
+      if (p1.Timestamp() != p2.Timestamp() ||
+          p1.Get<PacketContent>() != p2.Get<PacketContent>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+  void Print(const PacketContainer& packets, ::std::ostream* os) const {
+    for (auto it = packets.begin(); it != packets.end(); ++it) {
+      const Packet& packet = *it;
+      *os << (it == packets.begin() ? "{" : "") << "{"
+          << SourceString(packet.Timestamp()) << ", "
+          << packet.Get<PacketContent>() << "}"
+          << (std::next(it) == packets.end() ? "}" : ", ");
+    }
+  }
+
+  const PacketContainer packets_;
+};
+
+template <typename PacketContainer, typename PacketContent>
+::testing::Matcher<const PacketContainer&> PacketsEq(
+    const PacketContainer& packets) {
+  return MakeMatcher(
+      new PacketsEqMatcher<PacketContainer, PacketContent>(packets));
+}
+
 // A Calculator::Process callback function.
 typedef std::function<absl::Status(const InputStreamShardSet&,
                                    OutputStreamShardSet*)>
@@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
       input_packets_[17], input_packets_[19], input_packets_[20],
   };
   EXPECT_EQ(out_1_packets_, expected_output);
-  // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
+  // The timestamps released by FlowLimiterCalculator for in_1_sampled,
+  // plus input_packets_[21].
   std::vector<Packet> expected_output_2 = {
       input_packets_[0],  input_packets_[2],  input_packets_[4],
       input_packets_[14], input_packets_[17], input_packets_[19],
-      input_packets_[20],
+      input_packets_[20], input_packets_[21],
   };
   EXPECT_EQ(out_2_packets, expected_output_2);
 }
@@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
 // The processing time "sleep_time" is reduced from 22ms to 12ms to create
 // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
 TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
+  auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>;
+  auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>;
+
   // Configure the test.
   SetUpInputData();
   SetUpSimulationClock();
@@ -699,11 +776,10 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
     }
   )pb");
-  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
-    max_in_flight: 1
-    max_in_queue: 0
-    in_flight_timeout: 100000  # 100 ms
-  )pb");
+  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
+      R"pb(
+        max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000  # 100 ms
+      )pb");
   std::map<std::string, Packet> side_packets = {
       {"limiter_options",
        MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
@@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
       input_packets_[0],  input_packets_[2],  input_packets_[15],
       input_packets_[17], input_packets_[19],
   };
-  EXPECT_EQ(out_1_packets_, expected_output);
+  EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output));
   // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
   std::vector<Packet> expected_output_2 = {
       input_packets_[0],  input_packets_[2],  input_packets_[4],
      input_packets_[15], input_packets_[17], input_packets_[19],
   };
-  EXPECT_EQ(out_2_packets, expected_output_2);
+  EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2));
+
+  // Validate the ALLOW stream output.
+  std::vector<Packet> expected_allow = MakePackets<bool>(  //
+      {{Timestamp(0), true},       {Timestamp(10000), false},
+       {Timestamp(20000), true},   {Timestamp(30000), false},
+       {Timestamp(40000), true},   {Timestamp(50000), false},
+       {Timestamp(60000), false},  {Timestamp(70000), false},
+       {Timestamp(80000), false},  {Timestamp(90000), false},
+       {Timestamp(100000), false}, {Timestamp(110000), false},
+       {Timestamp(120000), false}, {Timestamp(130000), false},
+       {Timestamp(140000), false}, {Timestamp(150000), true},
+       {Timestamp(160000), false}, {Timestamp(170000), true},
+       {Timestamp(180000), false}, {Timestamp(190000), true},
+       {Timestamp(200000), false}});
+  EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow));
+}
+
+// Shows how FlowLimiterCalculator releases auxiliary input packets.
+// In this test, auxiliary input packets arrive at twice the primary rate.
+TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
+  auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>;
+  auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>;
+
+  // Configure the test.
+  SetUpInputData();
+  SetUpSimulationClock();
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: 'in_1'
+        input_stream: 'in_2'
+        node {
+          calculator: 'FlowLimiterCalculator'
+          input_side_packet: 'OPTIONS:limiter_options'
+          input_stream: 'in_1'
+          input_stream: 'in_2'
+          input_stream: 'FINISHED:out_1'
+          input_stream_info: { tag_index: 'FINISHED' back_edge: true }
+          output_stream: 'in_1_sampled'
+          output_stream: 'in_2_sampled'
+          output_stream: 'ALLOW:allow'
+        }
+        node {
+          calculator: 'SleepCalculator'
+          input_side_packet: 'WARMUP_TIME:warmup_time'
+          input_side_packet: 'SLEEP_TIME:sleep_time'
+          input_side_packet: 'CLOCK:clock'
+          input_stream: 'PACKET:in_1_sampled'
+          output_stream: 'PACKET:out_1'
+        }
+      )pb");
+  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
+      R"pb(
+        max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000  # 1s
+      )pb");
+  std::map<std::string, Packet> side_packets = {
+      {"limiter_options",
+       MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
+      {"warmup_time", MakePacket<int64>(22000)},
+      {"sleep_time", MakePacket<int64>(22000)},
+      {"clock", MakePacket<mediapipe::Clock*>(clock_)},
+  };
+
+  // Start the graph.
+  MP_ASSERT_OK(graph_.Initialize(graph_config));
+  MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
+    out_1_packets_.push_back(p);
+    return absl::OkStatus();
+  }));
+  std::vector<Packet> out_2_packets;
+  MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
+    out_2_packets.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
+    allow_packets_.push_back(p);
+    return absl::OkStatus();
+  }));
+  simulation_clock_->ThreadStart();
+  MP_ASSERT_OK(graph_.StartRun(side_packets));
+
+  // Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2.
+  clock_->Sleep(absl::Microseconds(10000));
+  for (int i = 1; i < 10; ++i) {
+    if (i % 2 == 0) {
+      MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
+    }
+    MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i]));
+    clock_->Sleep(absl::Microseconds(10000));
+  }
+
+  // Finish the graph run.
+  MP_EXPECT_OK(graph_.CloseAllPacketSources());
+  clock_->Sleep(absl::Microseconds(40000));
+  MP_EXPECT_OK(graph_.WaitUntilDone());
+  simulation_clock_->ThreadFinish();
+
+  // Validate the output.
+  // Input packets 4 and 8 are dropped due to max_in_flight.
+  std::vector<Packet> expected_output = {
+      input_packets_[2],
+      input_packets_[6],
+  };
+  EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output));
+  // Packets following input packets 2 and 6, and not input packets 4 and 8.
+  std::vector<Packet> expected_output_2 = {
+      input_packets_[1], input_packets_[2], input_packets_[3],
+      input_packets_[6], input_packets_[7],
+  };
+  EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2));
+
+  // Validate the ALLOW stream output.
+  std::vector<Packet> expected_allow =
+      MakePackets<bool>({{Timestamp(20000), true},
+                         {Timestamp(40000), false},
+                         {Timestamp(60000), true},
+                         {Timestamp(80000), false}});
+  EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow));
 }

 }  // anonymous namespace

@@ -25,12 +25,12 @@
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "mediapipe/util/tflite/tflite_gpu_runner.h"

-#if defined(MEDIAPIPE_ANDROID)
+#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/util/android/file/base/file.h"
 #include "mediapipe/util/android/file/base/filesystem.h"
 #include "mediapipe/util/android/file/base/helpers.h"
-#endif  // MEDIAPIPE_ANDROID
+#endif  // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

 #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
@@ -231,7 +231,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
   return tflite_gpu_runner_->Build();
 }

-#if defined(MEDIAPIPE_ANDROID)
+#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
 absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
     const mediapipe::InferenceCalculatorOptions& options,
     const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
@@ -318,7 +318,7 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
     tflite::gpu::TFLiteGPURunner* gpu_runner) const {
   return absl::OkStatus();
 }
-#endif  // MEDIAPIPE_ANDROID
+#endif  // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

 absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
     CalculatorContract* cc) {

@@ -15,7 +15,7 @@
 # Description:
 #   The dependencies of mediapipe.

-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
 load("@bazel_skylib//:bzl_library.bzl", "bzl_library")

 licenses(["notice"])
@@ -38,7 +38,7 @@ bzl_library(
     visibility = ["//mediapipe/framework:__subpackages__"],
 )

-proto_library(
+mediapipe_proto_library(
     name = "proto_descriptor_proto",
     srcs = ["proto_descriptor.proto"],
     visibility = [
@@ -47,13 +47,6 @@ proto_library(
     ],
 )

-mediapipe_cc_proto_library(
-    name = "proto_descriptor_cc_proto",
-    srcs = ["proto_descriptor.proto"],
-    visibility = ["//mediapipe/framework:__subpackages__"],
-    deps = [":proto_descriptor_proto"],
-)
-
 cc_library(
     name = "aligned_malloc_and_free",
     hdrs = ["aligned_malloc_and_free.h"],

@@ -4,6 +4,9 @@
 """.bzl file for mediapipe open source build configs."""

 load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library")
+load("@npm//@bazel/typescript:index.bzl", "ts_project")
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@rules_proto_grpc//js:defs.bzl", "js_proto_library")
 load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library")

 def provided_args(**kwargs):
@@ -71,7 +74,7 @@ def mediapipe_proto_library(
       def_jspb_proto: define the jspb_proto_library target
       def_options_lib: define the mediapipe_options_library target
     """
-    _ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps]
+    _ignore = [def_portable_proto, def_objc_proto, def_java_proto, portable_deps]  # buildifier: disable=unused-variable

     # The proto_library targets for the compiled ".proto" source files.
     proto_deps = [":" + name]
@@ -119,6 +122,24 @@ def mediapipe_proto_library(
             compatible_with = compatible_with,
         ))

+    if def_jspb_proto:
+        js_deps = replace_deps(deps, "_proto", "_jspb_proto", False)
+        proto_library(
+            name = replace_suffix(name, "_proto", "_lib_proto"),
+            srcs = srcs,
+            deps = deps,
+        )
+        js_proto_library(
+            name = replace_suffix(name, "_proto", "_jspb_proto"),
+            protos = [replace_suffix(name, "_proto", "_lib_proto")],
+            output_mode = "NO_PREFIX_FLAT",
+            # Need to specify this to work around bug in js_proto_library()
+            # https://github.com/bazelbuild/rules_nodejs/issues/3503
+            legacy_path = "unused",
+            deps = js_deps,
+            visibility = visibility,
+        )
+
     if def_options_lib:
         cc_deps = replace_deps(deps, "_proto", "_cc_proto")
         mediapipe_options_library(**provided_args(
@@ -182,3 +203,35 @@ def mediapipe_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps
         default_runtime = "@com_google_protobuf//:protobuf",
         alwayslink = 1,
     ))
+
+def mediapipe_ts_library(
+        name,
+        srcs,
+        visibility = None,
+        deps = [],
+        testonly = 0,
+        allow_unoptimized_namespaces = False):
+    """Generate ts_project for MediaPipe open source version.
+
+    Args:
+      name: the name of the mediapipe_ts_library.
+      srcs: the TypeScript source files.
+      visibility: visibility of this target.
+      deps: a list of dependency labels for Bazel use.
+      testonly: test only or not.
+      allow_unoptimized_namespaces: ignored, used only internally
+    """
+    _ignore = [allow_unoptimized_namespaces]  # buildifier: disable=unused-variable
+    ts_project(**provided_args(
+        name = name,
+        srcs = srcs,
+        visibility = visibility,
+        deps = deps + [
+            "@npm//@types/offscreencanvas",
+            "@npm//@types/google-protobuf",
+        ],
+        testonly = testonly,
+        declaration = True,
+        tsconfig = "//:tsconfig.json",
+    ))

@@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
     """Prints a summary of the model."""
     self._model.summary()

+  # TODO: Remove this method when all tasks use Metadata writer
   def export_tflite(
       self,
       export_dir: str,
@@ -62,7 +63,7 @@ class CustomModel(abc.ABC):
     Args:
       export_dir: The directory to save exported files.
-      tflite_filename: File name to save tflite model. The full export path is
+      tflite_filename: File name to save TFLite model. The full export path is
         {export_dir}/{tflite_filename}.
       quantization_config: The configuration for model quantization.
       preprocess: A callable to preprocess the representative dataset for
@@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
     tf.io.gfile.makedirs(export_dir)
     tflite_filepath = os.path.join(export_dir, tflite_filename)
-    # TODO: Populate metadata to the exported TFLite model.
-    model_util.export_tflite(
+    tflite_model = model_util.convert_to_tflite(
         model=self._model,
-        tflite_filepath=tflite_filepath,
         quantization_config=quantization_config,
         preprocess=preprocess)
+    model_util.save_tflite(
+        tflite_model=tflite_model, tflite_file=tflite_filepath)
     tf.compat.v1.logging.info(
         'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

@@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
   return len(train_data) // batch_size

-def export_tflite(
+def convert_to_tflite(
     model: tf.keras.Model,
-    tflite_filepath: str,
     quantization_config: Optional[quantization.QuantizationConfig] = None,
     supported_ops: Tuple[tf.lite.OpsSet,
                          ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
-    preprocess: Optional[Callable[..., bool]] = None):
-  """Converts the model to tflite format and saves it.
+    preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
+  """Converts the input Keras model to TFLite format.

   Args:
-    model: model to be converted to tflite.
-    tflite_filepath: File path to save tflite model.
+    model: Keras model to be converted to TFLite.
     quantization_config: Configuration for post-training quantization.
     supported_ops: A list of supported ops in the converted TFLite file.
     preprocess: A callable to preprocess the representative dataset for
       quantization. The callable takes three arguments in order: feature, label,
      and is_training.
-  """
-  if tflite_filepath is None:
-    raise ValueError(
-        "TFLite filepath couldn't be None when exporting to tflite.")

+  Returns:
+    bytearray of TFLite model
+  """
   with tempfile.TemporaryDirectory() as temp_dir:
     save_path = os.path.join(temp_dir, 'saved_model')
     model.save(save_path, include_optimizer=False, save_format='tf')
@@ -122,9 +119,22 @@ def convert_to_tflite(
     converter.target_spec.supported_ops = supported_ops
     tflite_model = converter.convert()
+    return tflite_model

-  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
-    f.write(tflite_model)

+def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
+  """Saves TFLite file to tflite_file.
+
+  Args:
+    tflite_model: A valid flatbuffer representing the TFLite model.
+    tflite_file: File path to save TFLite model.
+  """
+  if tflite_file is None:
+    raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
+
+  with tf.io.gfile.GFile(tflite_file, 'wb') as f:
+    f.write(tflite_model)
+  tf.compat.v1.logging.info(
+      'TensorFlow Lite model exported successfully to: %s' % tflite_file)

 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 class LiteRunner(object):
   """A runner to do inference with the TFLite model."""

-  def __init__(self, tflite_filepath: str):
-    """Initializes Lite runner with tflite model file.
+  def __init__(self, tflite_model: bytearray):
+    """Initializes Lite runner from TFLite model buffer.

     Args:
-      tflite_filepath: File path to the TFLite model.
+      tflite_model: A valid flatbuffer representing the TFLite model.
     """
-    with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
-      tflite_model = f.read()
     self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
     self.interpreter.allocate_tensors()
     self.input_details = self.interpreter.get_input_details()
@@ -250,9 +258,9 @@ class LiteRunner(object):
   return output_tensors

-def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
-  """Returns a `LiteRunner` from file path to TFLite model."""
-  lite_runner = LiteRunner(tflite_filepath)
+def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
+  """Returns a `LiteRunner` from flatbuffer of the TFLite model."""
+  lite_runner = LiteRunner(tflite_buffer)
   return lite_runner
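Reviewer note (illustration only, not part of this diff): splitting conversion from saving means callers can keep the flatbuffer in memory. A hedged usage sketch of the new API, with a throwaway Keras model; module paths are taken from this PR's layout:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

# A trivial Keras model, just to have something to convert.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(2, activation='softmax', input_shape=(4,))])

tflite_model = model_util.convert_to_tflite(model=model)  # flatbuffer bytes
model_util.save_tflite(tflite_model=tflite_model,
                       tflite_file='/tmp/model.tflite')

# Inference can now run straight from the in-memory buffer, no file round trip:
lite_runner = model_util.get_lite_runner(tflite_model)
output = lite_runner.run(tf.random.uniform((1, 4)))
```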

@@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
           'name': 'test'
       })

-  def test_export_tflite(self):
+  def test_convert_to_tflite(self):
     input_dim = 4
     model = test_util.build_model(input_shape=[input_dim], num_classes=2)
-    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
-    model_util.export_tflite(model, tflite_file)
+    tflite_model = model_util.convert_to_tflite(model)
     test_util.test_tflite(
-        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
+        keras_model=model, tflite_model=tflite_model, size=[1, input_dim])

   @parameterized.named_parameters(
       dict(
@@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
           testcase_name='float16_quantize',
           config=quantization.QuantizationConfig.for_float16(),
           model_size=1468))
-  def test_export_tflite_quantized(self, config, model_size):
+  def test_convert_to_tflite_quantized(self, config, model_size):
     input_dim = 16
     num_classes = 2
     max_input_value = 5
     model = test_util.build_model(
         input_shape=[input_dim], num_classes=num_classes)
-    tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
-    model_util.export_tflite(
-        model=model, tflite_filepath=tflite_file, quantization_config=config)
+
+    tflite_model = model_util.convert_to_tflite(
+        model=model, quantization_config=config)
     self.assertTrue(
         test_util.test_tflite(
             keras_model=model,
-            tflite_file=tflite_file,
+            tflite_model=tflite_model,
             size=[1, input_dim],
             high=max_input_value,
             atol=1e-00))
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
+    self.assertNear(len(tflite_model), model_size, 300)
+
+  def test_save_tflite(self):
+    input_dim = 4
+    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
+    tflite_model = model_util.convert_to_tflite(model)
+    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
+    model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
+    test_util.test_tflite_file(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])

 if __name__ == '__main__':
   tf.test.main()

@@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
   return model

-def is_same_output(tflite_file: str,
+def is_same_output(tflite_model: bytearray,
                    keras_model: tf.keras.Model,
                    input_tensors: Union[List[tf.Tensor], tf.Tensor],
                    atol: float = 1e-04) -> bool:
   """Returns if the output of TFLite model and keras model are identical."""
   # Gets output from lite model.
-  lite_runner = model_util.get_lite_runner(tflite_file)
+  lite_runner = model_util.get_lite_runner(tflite_model)
   lite_output = lite_runner.run(input_tensors)

   # Gets output from keras model.
@@ -95,12 +95,41 @@ def is_same_output(tflite_file: str,

 def test_tflite(keras_model: tf.keras.Model,
-                tflite_file: str,
+                tflite_model: bytearray,
                 size: Union[int, List[int]],
                 high: float = 1,
                 atol: float = 1e-04) -> bool:
   """Verifies if the output of TFLite model and TF Keras model are identical.

+  Args:
+    keras_model: Input TensorFlow Keras model.
+    tflite_model: Input TFLite model flatbuffer.
+    size: Size of the input tensor.
+    high: Higher boundary of the values in input tensors.
+    atol: Absolute tolerance of the difference between the outputs of Keras
+      model and TFLite model.
+
+  Returns:
+    True if the output of TFLite model and TF Keras model are identical.
+    Otherwise, False.
+  """
+  random_input = create_random_sample(size=size, high=high)
+  random_input = tf.convert_to_tensor(random_input)
+
+  return is_same_output(
+      tflite_model=tflite_model,
+      keras_model=keras_model,
+      input_tensors=random_input,
+      atol=atol)
+
+
+def test_tflite_file(keras_model: tf.keras.Model,
+                     tflite_file: str,
+                     size: Union[int, List[int]],
+                     high: float = 1,
+                     atol: float = 1e-04) -> bool:
+  """Verifies if the output of TFLite model and TF Keras model are identical.
+
   Args:
     keras_model: Input TensorFlow Keras model.
     tflite_file: Input TFLite model file.
@@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
     True if the output of TFLite model and TF Keras model are identical.
     Otherwise, False.
   """
-  random_input = create_random_sample(size=size, high=high)
-  random_input = tf.convert_to_tensor(random_input)
-
-  return is_same_output(
-      tflite_file=tflite_file,
-      keras_model=keras_model,
-      input_tensors=random_input,
-      atol=atol)
+  with tf.io.gfile.GFile(tflite_file, "rb") as f:
+    tflite_model = f.read()
+  return test_tflite(keras_model, tflite_model, size, high, atol)

@@ -1,4 +0,0 @@
-# MediaPipe Model Maker Internal Library
-
-This directory contains model maker library for internal users and experimental
-purposes.

@@ -1 +0,0 @@
-"""Model maker internal library."""

@@ -81,6 +81,8 @@ py_library(
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
         "//mediapipe/model_maker/python/vision/core:image_preprocessing",
+        "//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
+        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
     ],
 )

@@ -88,7 +90,11 @@ py_library(
     name = "image_classifier_test_lib",
     testonly = 1,
     srcs = ["image_classifier_test.py"],
-    deps = [":image_classifier_import"],
+    data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
+    deps = [
+        ":image_classifier_import",
+        "//mediapipe/tasks/python/test:test_utils",
+    ],
 )

 py_test(

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """APIs to train image classifier model."""

+import os
 from typing import List, Optional
@@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
 from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
 from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
+from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
+from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

 class ImageClassifier(classifier.Classifier):
@@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
       self,
       model_name: str = 'model.tflite',
       quantization_config: Optional[quantization.QuantizationConfig] = None):
-    """Converts the model to the requested formats and exports to a file.
+    """Converts and saves the model to a TFLite file with metadata included.
+
+    Note that only the TFLite file is needed for deployment. This function also
+    saves a metadata.json file to the same directory as the TFLite file which
+    can be used to interpret the metadata content in the TFLite file.

     Args:
-      model_name: File name to save tflite model. The full export path is
-        {export_dir}/{tflite_filename}.
+      model_name: File name to save TFLite model with metadata. The full export
+        path is {self._hparams.model_dir}/{model_name}.
       quantization_config: The configuration for model quantization.
     """
-    super().export_tflite(
-        self._hparams.model_dir,
-        model_name,
-        quantization_config,
+    if not tf.io.gfile.exists(self._hparams.model_dir):
+      tf.io.gfile.makedirs(self._hparams.model_dir)
+    tflite_file = os.path.join(self._hparams.model_dir, model_name)
+    metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
+
+    tflite_model = model_util.convert_to_tflite(
+        model=self._model,
+        quantization_config=quantization_config,
         preprocess=self._preprocess)
+    writer = image_classifier_writer.MetadataWriter.create(
+        tflite_model,
+        self._model_spec.mean_rgb,
+        self._model_spec.stddev_rgb,
+        labels=metadata_writer.Labels().add(self._label_names))
+    tflite_model_with_metadata, metadata_json = writer.populate()
+    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
+    with open(metadata_file, 'w') as f:
+      f.write(metadata_json)
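Reviewer note (illustrative sketch, not part of this diff): end to end, the new export path looks roughly like this. `train_data` and `validation_data` are assumed to be prepared datasets, and the exact `SupportedModels` entry name is an assumption:

```python
from mediapipe.model_maker.python.vision import image_classifier

hparams = image_classifier.HParams(train_epochs=1, batch_size=1)
model = image_classifier.ImageClassifier.create(
    model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,  # assumed name
    train_data=train_data,            # assumed: a prepared dataset
    validation_data=validation_data,  # assumed: a prepared dataset
    hparams=hparams)

# Writes {hparams.model_dir}/model.tflite (metadata embedded) and a
# human-readable {hparams.model_dir}/metadata.json next to it.
model.export_model()
```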

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import filecmp
 import os

 from absl.testing import parameterized
@@ -19,6 +20,7 @@ import numpy as np
 import tensorflow as tf

 from mediapipe.model_maker.python.vision import image_classifier
+from mediapipe.tasks.python.test import test_utils

 def _fill_image(rgb, image_size):
@@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
         validation_data=self.test_data)
     self._test_accuracy(model)

-  def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
+  def test_efficientnetlite0_model_train_and_export(self):
     hparams = image_classifier.HParams(
         train_epochs=1, batch_size=1, shuffle=True)
     model = image_classifier.ImageClassifier.create(
@@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
         validation_data=self.test_data)
     self._test_accuracy(model)

+    # Test export_model
+    model.export_model()
+    output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
+    output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
+    expected_metadata_file = test_utils.get_test_data_path('metadata.json')
+
+    self.assertTrue(os.path.exists(output_tflite_file))
+    self.assertGreater(os.path.getsize(output_tflite_file), 0)
+
+    self.assertTrue(os.path.exists(output_metadata_file))
+    self.assertGreater(os.path.getsize(output_metadata_file), 0)
+    self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
+
   def _test_accuracy(self, model, threshold=0.0):
     _, accuracy = model.evaluate(self.test_data)
     self.assertGreaterEqual(accuracy, threshold)

@@ -0,0 +1,23 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "testdata",
srcs = ["metadata.json"],
)

@@ -0,0 +1,68 @@
{
"name": "ImageClassifier",
"description": "Identify the most prominent object in the image from a known set of categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "image",
"description": "Input image to be processed.",
"content": {
"content_properties_type": "ImageProperties",
"content_properties": {
"color_space": "RGB"
}
},
"process_units": [
{
"options_type": "NormalizationOptions",
"options": {
"mean": [
0.0
],
"std": [
255.0
]
}
}
],
"stats": {
"max": [
1.0
],
"min": [
0.0
]
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
]
}
],
"min_parser_version": "1.0.0"
}

@@ -87,6 +87,7 @@ cc_library(
 cc_library(
     name = "builtin_task_graphs",
     deps = [
+        "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
        "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",

@@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) {
       // TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
       // as the input argument. Investigate why bazel non-optimized mode
       // triggers a memory allocation bug in Eigen::internal::aligned_free().
-      [](const Eigen::MatrixXf& matrix) {
+      [](const Eigen::MatrixXf& matrix, bool transpose) {
         // MakePacket copies the data.
+        if (transpose) {
+          return MakePacket<Matrix>(matrix.transpose());
+        }
         return MakePacket<Matrix>(matrix);
       },
       R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
@@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) {
   Args:
     matrix: A 2d numpy float ndarray.
+    transpose: A boolean to indicate if the input matrix needs to be transposed.
+      Default to False.

   Returns:
     A MediaPipe Matrix Packet.
@@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) {
     np.array([[.1, .2, .3], [.4, .5, .6]])
     matrix = mp.packet_getter.get_matrix(packet)
   )doc",
+      py::arg("matrix"), py::arg("transpose") = false,
       py::return_value_policy::move);
 }  // NOLINT(readability/fn_size)
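Reviewer note (illustration only): a short usage sketch for the new keyword, assuming the binding is exposed as `mp.packet_creator.create_matrix`. The option is useful because numpy arrays are row-major while `mediapipe::Matrix` is a column-major Eigen matrix:

```python
import numpy as np
import mediapipe as mp

data = np.array([[.1, .2, .3], [.4, .5, .6]], dtype=np.float32)

# With transpose=True the packet stores the 3x2 transpose of `data`.
packet = mp.packet_creator.create_matrix(data, transpose=True)
matrix = mp.packet_getter.get_matrix(packet)
print(matrix.shape)  # (3, 2)
```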

@@ -65,6 +65,7 @@ cc_library(
         "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
         "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
         "//mediapipe/tasks/cc/audio/core:running_mode",
+        "//mediapipe/tasks/cc/components/containers:classification_result",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

View File

@ -18,12 +18,14 @@ limitations under the License.
#include <map> #include <map>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@ -38,12 +40,16 @@ namespace audio_classifier {
namespace { namespace {
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO"; constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultStreamName[] = "classification_result_out"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsName[] = "classifications_out";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications_out";
constexpr char kSampleRateName[] = "sample_rate_in"; constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig(
} }
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap( subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
options_proto.get()); options_proto.get());
subgraph.Out(kClassificationResultTag) subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
.SetName(kClassificationResultStreamName) >> graph.Out(kClassificationsTag);
graph.Out(kClassificationResultTag); subgraph.Out(kTimestampedClassificationsTag)
.SetName(kTimestampedClassificationsName) >>
graph.Out(kTimestampedClassificationsTag);
return graph.GetConfig(); return graph.GetConfig();
} }
@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
return options_proto; return options_proto;
} }
absl::StatusOr<ClassificationResult> ConvertOutputPackets( absl::StatusOr<std::vector<AudioClassifierResult>> ConvertOutputPackets(
absl::StatusOr<tasks::core::PacketMap> status_or_packets) { absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) { if (!status_or_packets.ok()) {
return status_or_packets.status(); return status_or_packets.status();
} }
return status_or_packets.value()[kClassificationResultStreamName] auto classification_results =
.Get<ClassificationResult>(); status_or_packets.value()[kTimestampedClassificationsName]
.Get<std::vector<ClassificationResult>>();
std::vector<AudioClassifierResult> results;
results.reserve(classification_results.size());
for (const auto& classification_result : classification_results) {
results.emplace_back(ConvertToClassificationResult(classification_result));
}
return results;
}
absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
if (!status_or_packets.ok()) {
return status_or_packets.status();
}
return ConvertToClassificationResult(
status_or_packets.value()[kClassificationsName]
.Get<ClassificationResult>());
} }
} // namespace } // namespace
@ -118,7 +143,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
auto result_callback = options->result_callback; auto result_callback = options->result_callback;
packets_callback = packets_callback =
[=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) { [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
result_callback(ConvertOutputPackets(status_or_packets)); result_callback(ConvertAsyncOutputPackets(status_or_packets));
}; };
} }
return core::AudioTaskApiFactory::Create<AudioClassifier, return core::AudioTaskApiFactory::Create<AudioClassifier,
@ -128,7 +153,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<ClassificationResult> AudioClassifier::Classify( absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
Matrix audio_clip, double audio_sample_rate) { Matrix audio_clip, double audio_sample_rate) {
return ConvertOutputPackets(ProcessAudioClip( return ConvertOutputPackets(ProcessAudioClip(
{{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))}, {{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},
@ -18,12 +18,13 @@ limitations under the License.
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
@ -32,6 +33,10 @@ namespace tasks {
namespace audio { namespace audio {
namespace audio_classifier { namespace audio_classifier {
// Alias the shared ClassificationResult struct as the result type.
using AudioClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a mediapipe audio classifier task. // The options for configuring a mediapipe audio classifier task.
struct AudioClassifierOptions { struct AudioClassifierOptions {
// Base options for configuring Task library, such as specifying the TfLite // Base options for configuring Task library, such as specifying the TfLite
@ -59,9 +64,8 @@ struct AudioClassifierOptions {
// The user-defined result callback for processing audio stream data. // The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM. // to RunningMode::AUDIO_STREAM.
std::function<void( std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
absl::StatusOr<components::containers::proto::ClassificationResult>)> nullptr;
result_callback = nullptr;
}; };
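A minimal sketch of wiring these options together for clip-mode classification (the header path and model path are assumptions, and error handling is elided):

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h"

using ::mediapipe::tasks::audio::audio_classifier::AudioClassifier;
using ::mediapipe::tasks::audio::audio_classifier::AudioClassifierOptions;

absl::StatusOr<std::unique_ptr<AudioClassifier>> CreateClipClassifier() {
  auto options = std::make_unique<AudioClassifierOptions>();
  // Placeholder path; point this at a real TFLite model with metadata.
  options->base_options.model_asset_path = "/path/to/model.tflite";
  options->classifier_options.max_results = 3;
  // No running_mode is set, so the task runs in clips mode as in the
  // ClassifyTest cases below.
  return AudioClassifier::Create(std::move(options));
}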
// Performs audio classification on audio clips or audio stream. // Performs audio classification on audio clips or audio stream.
@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// required to provide the corresponding audio sample rate along with the // required to provide the corresponding audio sample rate along with the
// input audio clips. // input audio clips.
// //
// For each audio clip, the output classifications are grouped in a // The input audio clip may be longer than what the model is able to process
// ClassificationResult object that has three dimensions: // in a single inference. When this occurs, the input audio clip is split into
// Classification head: // multiple chunks starting at different timestamps. For this reason, this
// The prediction heads targeting different audio classification tasks // function returns a vector of ClassificationResult objects, each associated
// such as audio event classification and bird sound classification. // with a timestamp corresponding to the start (in milliseconds) of the chunk
//   Classification timestamp: // data that was classified, e.g.:
// The start time (in milliseconds) of each audio clip that is sent to the //
// model for audio classification. As the audio classification models take // ClassificationResult #0 (first chunk of data):
// a fixed number of audio samples, long audio clips will be framed to // timestamp_ms: 0 (starts at 0ms)
// multiple buffers (with the desired number of audio samples) during // classifications #0 (single head model):
// preprocessing. // category #0:
// Classification category: // category_name: "Speech"
// The list of the classification categories that model predicts per // score: 0.6
// framed audio clip. // category #1:
// category_name: "Music"
// score: 0.2
// ClassificationResult #1 (second chunk of data):
// timestamp_ms: 800 (starts at 800ms)
// classifications #0 (single head model):
// category #0:
// category_name: "Speech"
// score: 0.5
// category #1:
// category_name: "Silence"
// score: 0.1
// ...
//
// TODO: Use `sample_rate` in AudioClassifierOptions by default // TODO: Use `sample_rate` in AudioClassifierOptions by default
// and makes `audio_sample_rate` optional. // and make `audio_sample_rate` optional.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify( absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
mediapipe::Matrix audio_clip, double audio_sample_rate); mediapipe::Matrix audio_clip, double audio_sample_rate);
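A hedged usage sketch for Classify(), reusing the classifier from the sketch above and assuming a 16 kHz clip already loaded into a Matrix (requires <iostream>):

auto results_or = classifier->Classify(std::move(audio_clip),
                                       /*audio_sample_rate=*/16000);
if (results_or.ok()) {
  for (const auto& result : *results_or) {
    // timestamp_ms marks the chunk start; categories may be empty when
    // score thresholds or denylists filter everything out.
    if (result.classifications[0].categories.empty()) continue;
    const auto& top = result.classifications[0].categories[0];
    std::cout << result.timestamp_ms.value_or(0) << "ms: "
              << top.category_name << " (" << top.score << ")\n";
  }
}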
// Sends audio data (a block in a continuous audio stream) to perform audio // Sends audio data (a block in a continuous audio stream) to perform audio
@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// milliseconds) to indicate the start time of the input audio block. The // milliseconds) to indicate the start time of the input audio block. The
// timestamps must be monotonically increasing. // timestamps must be monotonically increasing.
// //
// The output classifications are grouped in a ClassificationResult object // The input audio block may be longer than what the model is able to process
// that has three dimensions: // in a single inference. When this occurs, the input audio block is split
// Classification head: // into multiple chunks. For this reason, the callback may be called multiple
// The prediction heads targeting different audio classification tasks // times (once per chunk) for each call to this function.
// such as audio event classification and bird sound classification.
// Classification timestamp :
// The start time (in milliseconds) of the framed audio block that is sent
// to the model for audio classification.
// Classification category:
// The list of the classification categories that model predicts per
// framed audio clip.
absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms); absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);
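A matching sketch for the streaming mode, with the same placeholder model path; namespace qualifiers are abbreviated as in the tests below:

auto options = std::make_unique<AudioClassifierOptions>();
options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = 16000;
options->result_callback =
    [](absl::StatusOr<AudioClassifierResult> result) {
      // May fire several times per ClassifyAsync() call, once per chunk.
    };
auto classifier = AudioClassifier::Create(std::move(options)).value();
// Timestamps must be monotonically increasing across calls.
absl::Status status =
    classifier->ClassifyAsync(std::move(audio_block), /*timestamp_ms=*/0);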
// Shuts down the AudioClassifier when all work is done. // Shuts down the AudioClassifier when all work is done.
@ -16,6 +16,7 @@ limitations under the License.
#include <stdint.h> #include <stdint.h>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
constexpr char kAudioTag[] = "AUDIO"; constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kPacketTag[] = "PACKET"; constexpr char kPacketTag[] = "PACKET";
constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Struct holding the different output streams produced by the audio classifier
// graph.
struct AudioClassifierOutputStreams {
Source<ClassificationResult> classifications;
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
absl::Status SanityCheckOptions( absl::Status SanityCheckOptions(
const proto::AudioClassifierGraphOptions& options) { const proto::AudioClassifierGraphOptions& options) {
if (options.base_options().use_stream_mode() && if (options.base_options().use_stream_mode() &&
@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator(
// series stream header with sample rate info. // series stream header with sample rate info.
// //
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The aggregated classification result object that has 3 dimensions: // The classification results aggregated by head. Only produces results if
// (classification head, classification timestamp, classification category). // the 'use_stream_mode' option is true.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Only
// produces results if the 'use_stream_mode' option is false.
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph" // calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
// input_stream: "AUDIO:audio_in" // input_stream: "AUDIO:audio_in"
// input_stream: "SAMPLE_RATE:sample_rate_in" // input_stream: "SAMPLE_RATE:sample_rate_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATIONS:classifications"
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options { // options {
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext] // [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
// { // {
@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
.base_options() .base_options()
.use_stream_mode(); .use_stream_mode();
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto classification_result_out, auto output_streams,
BuildAudioClassificationTask( BuildAudioClassificationTask(
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources, sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)], graph[Input<Matrix>(kAudioTag)],
@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
? absl::nullopt ? absl::nullopt
: absl::make_optional(graph[Input<double>(kSampleRateTag)]), : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
graph)); graph));
classification_result_out >> output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationResultTag)]; graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
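Since both output streams are marked @Optional, downstream configs can bind only the one they need; a hedged sketch in the ParseTextProtoOrDie style used by the tests in this change (stream names are illustrative):

// Clip-mode consumers typically bind only the timestamped output.
auto node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
  calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
  input_stream: "AUDIO:audio_in"
  input_stream: "SAMPLE_RATE:sample_rate_in"
  output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"
)pb");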
@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// audio_in: (mediapipe::Matrix) stream to run audio classification on. // audio_in: (mediapipe::Matrix) stream to run audio classification on.
// sample_rate_in: (double) optional stream of the input audio sample rate. // sample_rate_in: (double) optional stream of the input audio sample rate.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask( absl::StatusOr<AudioClassifierOutputStreams> BuildAudioClassificationTask(
const proto::AudioClassifierGraphOptions& task_options, const proto::AudioClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in, const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) { absl::optional<Source<double>> sample_rate_in, Graph& graph) {
@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio classification on // Time aggregation is only needed for performing audio classification on
// audio files. Disables time aggregration by not connecting the // audio files. Disables timestamp aggregation by not connecting the
// "TIMESTAMPS" streams. // "TIMESTAMPS" streams.
if (!use_stream_mode) { if (!use_stream_mode) {
audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag); audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
} }
// Outputs the aggregated classification result as the subgraph output // Outputs both streams as graph output streams.
// stream. return AudioClassifierOutputStreams{
return postprocessing[Output<ClassificationResult>( /*classifications=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)]; kClassificationsTag)],
/*timestamped_classifications=*/
postprocessing[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)],
};
} }
}; };
@ -32,13 +32,11 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe { namespace mediapipe {
@ -49,7 +47,6 @@ namespace {
using ::absl::StatusOr; using ::absl::StatusOr;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) {
return matrix_mapping.matrix(); return matrix_mapping.matrix();
} }
void CheckSpeechClassificationResult(const ClassificationResult& result) { void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
EXPECT_THAT(result.classifications_size(), testing::Eq(1)); int expected_num_categories = 521) {
EXPECT_EQ(result.classifications(0).head_name(), "scores"); EXPECT_EQ(result.size(), 5);
EXPECT_EQ(result.classifications(0).head_index(), 0); // Ignore the last result, which operates on a chunk too small to return
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5)); // relevant results.
std::vector<int64> timestamps_ms = {0, 975, 1950, 2925}; std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
for (int i = 0; i < timestamps_ms.size(); i++) { for (int i = 0; i < timestamps_ms.size(); i++) {
EXPECT_THAT(result.classifications(0).entries(0).categories_size(), EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
testing::Eq(521)); EXPECT_EQ(result[i].classifications.size(), 1);
const auto* top_category = auto classifications = result[i].classifications[0];
&result.classifications(0).entries(0).categories(0); EXPECT_EQ(classifications.head_index, 0);
EXPECT_THAT(top_category->category_name(), testing::Eq("Speech")); EXPECT_EQ(classifications.head_name, "scores");
EXPECT_GT(top_category->score(), 0.9f); EXPECT_EQ(classifications.categories.size(), expected_num_categories);
EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(), auto category = classifications.categories[0];
timestamps_ms[i]); EXPECT_EQ(category.index, 0);
EXPECT_EQ(category.category_name, "Speech");
EXPECT_GT(category.score, 0.9f);
} }
} }
void CheckTwoHeadsClassificationResult(const ClassificationResult& result) { void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
EXPECT_THAT(result.classifications_size(), testing::Eq(2)); EXPECT_GE(result.size(), 1);
// Checks classification head #1. EXPECT_LE(result.size(), 2);
EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification"); // Check first result.
EXPECT_EQ(result.classifications(0).head_index(), 0); EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_THAT(result.classifications(0).entries(0).categories_size(), EXPECT_EQ(result[0].classifications.size(), 2);
testing::Eq(521)); // Check first head.
const auto* top_category = EXPECT_EQ(result[0].classifications[0].head_index, 0);
&result.classifications(0).entries(0).categories(0); EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
EXPECT_THAT(top_category->category_name(), EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
testing::Eq("Environmental noise")); EXPECT_EQ(result[0].classifications[0].categories[0].index, 508);
EXPECT_GT(top_category->score(), 0.5f); EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0); "Environmental noise");
if (result.classifications(0).entries_size() == 2) { EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
top_category = &result.classifications(0).entries(1).categories(0); // Check second head.
EXPECT_THAT(top_category->category_name(), testing::Eq("Silence")); EXPECT_EQ(result[0].classifications[1].head_index, 1);
EXPECT_GT(top_category->score(), 0.99f); EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975); EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
"Chestnut-crowned Antpitta");
EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
// Check second result, if present.
if (result.size() == 2) {
EXPECT_EQ(result[1].timestamp_ms, 975);
EXPECT_EQ(result[1].classifications.size(), 2);
// Check first head.
EXPECT_EQ(result[1].classifications[0].head_index, 0);
EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
EXPECT_EQ(result[1].classifications[0].categories[0].index, 494);
EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
"Silence");
EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
// Check second head.
EXPECT_EQ(result[1].classifications[1].head_index, 1);
EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
EXPECT_EQ(result[1].classifications[1].categories[0].index, 1);
EXPECT_EQ(result[1].classifications[1].categories[0].category_name,
"White-breasted Wood-Wren");
EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f);
} }
// Checks classification head #2.
EXPECT_EQ(result.classifications(1).head_name(), "bird_classification");
EXPECT_EQ(result.classifications(1).head_index(), 1);
EXPECT_THAT(result.classifications(1).entries(0).categories_size(),
testing::Eq(5));
top_category = &result.classifications(1).entries(0).categories(0);
EXPECT_THAT(top_category->category_name(),
testing::Eq("Chestnut-crowned Antpitta"));
EXPECT_GT(top_category->score(), 0.9f);
EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0);
} }
ClassificationResult GenerateSpeechClassificationResult() { void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
return ParseTextProtoOrDie<ClassificationResult>( EXPECT_EQ(outputs.size(), 5);
R"pb(classifications { // Ignore last result, which operates on a too small chunk to return relevant
head_index: 0 // results.
head_name: "scores" for (int i = 0; i < outputs.size() - 1; i++) {
entries { EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
categories { index: 0 score: 0.94140625 category_name: "Speech" } EXPECT_EQ(outputs[i].classifications.size(), 1);
timestamp_ms: 0 EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
} EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
entries { EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1);
categories { index: 0 score: 0.9921875 category_name: "Speech" } EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0);
timestamp_ms: 975 EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name,
} "Speech");
entries { EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f);
categories { index: 0 score: 0.98828125 category_name: "Speech" }
timestamp_ms: 1950
}
entries {
categories { index: 0 score: 0.99609375 category_name: "Speech" }
timestamp_ms: 2925
}
entries {
# categories are filtered out due to the low scores.
timestamp_ms: 3900
}
})pb");
}
void CheckStreamingModeClassificationResult(
std::vector<ClassificationResult> outputs) {
ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6);
auto expected_results = GenerateSpeechClassificationResult();
for (int i = 0; i < outputs.size() - 1; ++i) {
EXPECT_THAT(outputs[i].classifications(0).entries(0),
EqualsProto(expected_results.classifications(0).entries(i)));
} }
int last_elem_index = outputs.size() - 1;
EXPECT_EQ(
mediapipe::Timestamp::Done().Value() / 1000,
outputs[last_elem_index].classifications(0).entries(0).timestamp_ms());
} }
class CreateFromOptionsTest : public tflite_shims::testing::Test {}; class CreateFromOptionsTest : public tflite_shims::testing::Test {};
@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->result_callback = options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {}; [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options)); AudioClassifier::Create(std::move(options));
@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->result_callback = options->result_callback =
[](absl::StatusOr<ClassificationResult> status_or_result) {}; [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
AudioClassifier::Create(std::move(options)); AudioClassifier::Create(std::move(options));
@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result); CheckSpeechResult(result);
} }
TEST_F(ClassifyTest, SucceedsWithResampling) { TEST_F(ClassifyTest, SucceedsWithResampling) {
@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result); CheckSpeechResult(result);
} }
TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
auto result_16k_hz, auto result_16k_hz,
audio_classifier->Classify(std::move(audio_buffer_16k_hz), audio_classifier->Classify(std::move(audio_buffer_16k_hz),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
CheckSpeechClassificationResult(result_16k_hz); CheckSpeechResult(result_16k_hz);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result_48k_hz, auto result_48k_hz,
audio_classifier->Classify(std::move(audio_buffer_48k_hz), audio_classifier->Classify(std::move(audio_buffer_48k_hz),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckSpeechClassificationResult(result_48k_hz); CheckSpeechResult(result_48k_hz);
} }
TEST_F(ClassifyTest, SucceedsWithInsufficientData) { TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_classifier->Classify(std::move(zero_matrix), 16000)); auto result, audio_classifier->Classify(std::move(zero_matrix), 16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result.classifications_size(), testing::Eq(1)); EXPECT_EQ(result.size(), 1);
EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1)); EXPECT_EQ(result[0].timestamp_ms, 0);
EXPECT_THAT(result.classifications(0).entries(0).categories_size(), EXPECT_EQ(result[0].classifications.size(), 1);
testing::Eq(521)); EXPECT_EQ(result[0].classifications[0].head_index, 0);
EXPECT_THAT( EXPECT_EQ(result[0].classifications[0].head_name, "scores");
result.classifications(0).entries(0).categories(0).category_name(), EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
testing::Eq("Silence")); EXPECT_EQ(result[0].classifications[0].categories[0].index, 494);
EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(), EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
testing::FloatEq(.800781f)); "Silence");
EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f);
} }
TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result); CheckTwoHeadsResult(result);
} }
TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/44100)); /*audio_sample_rate=*/44100));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result); CheckTwoHeadsResult(result);
} }
TEST_F(ClassifyTest, TEST_F(ClassifyTest,
@ -413,13 +402,13 @@ TEST_F(ClassifyTest,
auto result_44k_hz, auto result_44k_hz,
audio_classifier->Classify(std::move(audio_buffer_44k_hz), audio_classifier->Classify(std::move(audio_buffer_44k_hz),
/*audio_sample_rate=*/44100)); /*audio_sample_rate=*/44100));
CheckTwoHeadsClassificationResult(result_44k_hz); CheckTwoHeadsResult(result_44k_hz);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result_16k_hz, auto result_16k_hz,
audio_classifier->Classify(std::move(audio_buffer_16k_hz), audio_classifier->Classify(std::move(audio_buffer_16k_hz),
/*audio_sample_rate=*/16000)); /*audio_sample_rate=*/16000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckTwoHeadsClassificationResult(result_16k_hz); CheckTwoHeadsResult(result_16k_hz);
} }
TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kModelWithMetadata); JoinPath("./", kTestDataDirectory, kModelWithMetadata);
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
options->classifier_options.score_threshold = 0.35f;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
AudioClassifier::Create(std::move(options))); AudioClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); CheckSpeechResult(result, /*expected_num_categories=*/1);
} }
TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); CheckSpeechResult(result, /*expected_num_categories=*/1);
} }
TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult())); CheckSpeechResult(result, /*expected_num_categories=*/1);
} }
TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
auto result, audio_classifier->Classify(std::move(audio_buffer), auto result, audio_classifier->Classify(std::move(audio_buffer),
/*audio_sample_rate=*/48000)); /*audio_sample_rate=*/48000));
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
// All categroies with the "Speech" label are filtered out. // All categories with the "Speech" label are filtered out.
EXPECT_THAT(result, EqualsProto(R"pb(classifications { std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
head_index: 0 for (int i = 0; i < timestamps_ms.size(); i++) {
head_name: "scores" EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
entries { timestamp_ms: 0 } EXPECT_EQ(result[i].classifications.size(), 1);
entries { timestamp_ms: 975 } auto classifications = result[i].classifications[0];
entries { timestamp_ms: 1950 } EXPECT_EQ(classifications.head_index, 0);
entries { timestamp_ms: 2925 } EXPECT_EQ(classifications.head_name, "scores");
entries { timestamp_ms: 3900 } EXPECT_TRUE(classifications.categories.empty());
})pb")); }
} }
class ClassifyAsyncTest : public tflite_shims::testing::Test {}; class ClassifyAsyncTest : public tflite_shims::testing::Test {};
@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
options->classifier_options.score_threshold = 0.3f; options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz; options->sample_rate = kSampleRateHz;
std::vector<ClassificationResult> outputs; std::vector<AudioClassifierResult> outputs;
options->result_callback = options->result_callback =
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) { [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result); MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
}; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
start_col += kYamnetNumOfAudioSamples * 3; start_col += kYamnetNumOfAudioSamples * 3;
} }
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckStreamingModeClassificationResult(outputs); CheckStreamingModeResults(outputs);
} }
TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
options->classifier_options.score_threshold = 0.3f; options->classifier_options.score_threshold = 0.3f;
options->running_mode = core::RunningMode::AUDIO_STREAM; options->running_mode = core::RunningMode::AUDIO_STREAM;
options->sample_rate = kSampleRateHz; options->sample_rate = kSampleRateHz;
std::vector<ClassificationResult> outputs; std::vector<AudioClassifierResult> outputs;
options->result_callback = options->result_callback =
[&outputs](absl::StatusOr<ClassificationResult> status_or_result) { [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result); MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
}; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
start_col += num_samples; start_col += num_samples;
} }
MP_ASSERT_OK(audio_classifier->Close()); MP_ASSERT_OK(audio_classifier->Close());
CheckStreamingModeClassificationResult(outputs); CheckStreamingModeResults(outputs);
} }
} // namespace } // namespace
@ -40,6 +40,7 @@ cc_library(
"//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
@ -60,41 +61,6 @@ cc_library(
# TODO: Enable this test # TODO: Enable this test
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
# TODO: Investigate rewriting the build rule to only link # TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed. # the Bert Preprocessor if it's needed.
cc_library( cc_library(
@ -163,7 +163,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
], ],
) )
@ -178,7 +178,7 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
@ -26,14 +26,14 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::proto::Embedding;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in // Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
class TensorsToEmbeddingsCalculator : public Node { class TensorsToEmbeddingsCalculator : public Node {
public: public:
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"}; static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"}; static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut); MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
bool quantize_; bool quantize_;
std::vector<std::string> head_names_; std::vector<std::string> head_names_;
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
}; };
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
for (int i = 0; i < tensors.size(); ++i) { for (int i = 0; i < tensors.size(); ++i) {
const auto& tensor = tensors[i]; const auto& tensor = tensors[i];
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
auto* embeddings = result.add_embeddings(); auto* embedding = result.add_embeddings();
embeddings->set_head_index(i); embedding->set_head_index(i);
if (!head_names_.empty()) { if (!head_names_.empty()) {
embeddings->set_head_name(head_names_[i]); embedding->set_head_name(head_names_[i]);
} }
if (quantize_) { if (quantize_) {
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries()); FillQuantizedEmbedding(tensor, embedding);
} else { } else {
FillFloatEmbeddingEntry(tensor, embeddings->add_entries()); FillFloatEmbedding(tensor, embedding);
} }
} }
kEmbeddingsOut(cc).Send(result); kEmbeddingsOut(cc).Send(result);
return absl::OkStatus(); return absl::OkStatus();
} }
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry( void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
const Tensor& tensor, EmbeddingEntry* entry) { Embedding* embedding) {
int size = tensor.shape().num_elements(); int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView(); auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>(); const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm = float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* float_embedding = entry->mutable_float_embedding(); auto* float_embedding = embedding->mutable_float_embedding();
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm); float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
} }
} }
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry( void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
const Tensor& tensor, EmbeddingEntry* entry) { const Tensor& tensor, Embedding* embedding) {
int size = tensor.shape().num_elements(); int size = tensor.shape().num_elements();
auto tensor_view = tensor.GetCpuReadView(); auto tensor_view = tensor.GetCpuReadView();
const float* tensor_buffer = tensor_view.buffer<float>(); const float* tensor_buffer = tensor_view.buffer<float>();
float inv_l2_norm = float inv_l2_norm =
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
auto* values = entry->mutable_quantized_embedding()->mutable_values(); auto* values = embedding->mutable_quantized_embedding()->mutable_values();
values->resize(size); values->resize(size);
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
// Normalize. // Normalize.
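The hunk cuts off just before the quantization loop. The test expectations further down (e.g. {0.1, 0.2} quantizing to "\x0d\x1a", i.e. 13 and 26) are consistent with scaling each optionally L2-normalized value by 128 and clamping to the int8 range; a hedged sketch of that mapping, not the verbatim implementation:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helper: maps one normalized float to a signed byte.
// 0.1f * 128 = 12.8 -> 13, matching the "\x0d" byte in the tests.
int8_t ScalarQuantize(float value, float inv_l2_norm) {
  float scaled = std::roundf(value * inv_l2_norm * 128.0f);
  return static_cast<int8_t>(std::clamp(scaled, -128.0f, 127.0f));
}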
@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
message TensorsToEmbeddingsCalculatorOptions { message TensorsToEmbeddingsCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {
// The embedder options defining whether to L2-normalize or scalar-quantize // The embedder options defining whether to L2-normalize or scalar-quantize
// the outputs. // the outputs.
optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options = optional mediapipe.tasks.components.processors.proto.EmbedderOptions
1; embedder_options = 1;
// The embedder head names. // The embedder head names.
repeated string head_names = 2; repeated string head_names = 2;
@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" } [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
} }
@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false } embedder_options { l2_normalize: false quantize: false }
@ -84,28 +84,24 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0] EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
.Get<EmbeddingResult>(); R"pb(embeddings {
EXPECT_THAT( float_embedding { values: 0.1 values: 0.2 }
result, head_index: 0
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( }
R"pb(embeddings { embeddings {
entries { float_embedding { values: 0.1 values: 0.2 } } float_embedding { values: -0.2 values: -0.3 }
head_index: 0 head_index: 1
} })pb")));
embeddings {
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
})pb")));
} }
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: false } embedder_options { l2_normalize: false quantize: false }
@ -118,30 +114,26 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0] EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
.Get<EmbeddingResult>(); R"pb(embeddings {
EXPECT_THAT( float_embedding { values: 0.1 values: 0.2 }
result, head_index: 0
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( head_name: "foo"
R"pb(embeddings { }
entries { float_embedding { values: 0.1 values: 0.2 } } embeddings {
head_index: 0 float_embedding { values: -0.2 values: -0.3 }
head_name: "foo" head_index: 1
} head_name: "bar"
embeddings { })pb")));
entries { float_embedding { values: -0.2 values: -0.3 } }
head_index: 1
head_name: "bar"
})pb")));
} }
TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: false } embedder_options { l2_normalize: true quantize: false }
@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT( EXPECT_THAT(
result, result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { R"pb(embeddings {
entries { float_embedding { values: 0.44721356 values: 0.8944271 }
float_embedding { values: 0.44721356 values: 0.8944271 }
}
head_index: 0 head_index: 0
} }
embeddings { embeddings {
entries { float_embedding { values: -0.5547002 values: -0.8320503 }
float_embedding { values: -0.5547002 values: -0.8320503 }
}
head_index: 1 head_index: 1
})pb"))); })pb")));
} }
@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: false quantize: true } embedder_options { l2_normalize: false quantize: true }
@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0]
.Get<EmbeddingResult>();
EXPECT_THAT(result, EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { R"pb(embeddings {
entries { quantized_embedding { values: "\x0d\x1a" } # 13,26
quantized_embedding { values: "\x0d\x1a" } # 13,26
}
head_index: 0 head_index: 0
} }
embeddings { embeddings {
entries { quantized_embedding { values: "\xe6\xda" } # -26,-38
quantized_embedding { values: "\xe6\xda" } # -26,-38
}
head_index: 1 head_index: 1
})pb"))); })pb")));
} }
@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToEmbeddingsCalculator" calculator: "TensorsToEmbeddingsCalculator"
input_stream: "TENSORS:tensors" input_stream: "TENSORS:tensors"
output_stream: "EMBEDDING_RESULT:embeddings" output_stream: "EMBEDDINGS:embeddings"
options { options {
[mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
embedder_options { l2_normalize: true quantize: true } embedder_options { l2_normalize: true quantize: true }
@ -224,25 +204,18 @@ TEST(TensorsToEmbeddingsCalculatorTest,
BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const EmbeddingResult& result = runner.Outputs() const EmbeddingResult& result =
.Get("EMBEDDING_RESULT", 0) runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
.packets[0] EXPECT_THAT(result,
.Get<EmbeddingResult>(); EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
EXPECT_THAT( R"pb(embeddings {
result, quantized_embedding { values: "\x39\x72" } # 57,114
EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( head_index: 0
R"pb(embeddings { }
entries { embeddings {
quantized_embedding { values: "\x39\x72" } # 57,114 quantized_embedding { values: "\xb9\x95" } # -71,-107
} head_index: 1
head_index: 0 })pb")));
}
embeddings {
entries {
quantized_embedding { values: "\xb9\x95" } # -71,-107
}
head_index: 1
})pb")));
} }
} // namespace } // namespace
@ -49,3 +49,12 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
], ],
) )
cc_library(
name = "embedding_result",
srcs = ["embedding_result.cc"],
hdrs = ["embedding_result.h"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
],
)
@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include <iterator>
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
Embedding ConvertToEmbedding(const proto::Embedding& proto) {
Embedding embedding;
if (proto.has_float_embedding()) {
embedding.float_embedding = {
std::make_move_iterator(proto.float_embedding().values().begin()),
std::make_move_iterator(proto.float_embedding().values().end())};
} else {
embedding.quantized_embedding = {
std::make_move_iterator(proto.quantized_embedding().values().begin()),
std::make_move_iterator(proto.quantized_embedding().values().end())};
}
embedding.head_index = proto.head_index();
if (proto.has_head_name()) {
embedding.head_name = proto.head_name();
}
return embedding;
}
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
EmbeddingResult embedding_result;
embedding_result.embeddings.reserve(proto.embeddings_size());
for (const auto& embedding : proto.embeddings()) {
embedding_result.embeddings.push_back(ConvertToEmbedding(embedding));
}
if (proto.has_timestamp_ms()) {
embedding_result.timestamp_ms = proto.timestamp_ms();
}
return embedding_result;
}
} // namespace mediapipe::tasks::components::containers
@ -0,0 +1,72 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe::tasks::components::containers {
// Embedding result for a given embedder head.
//
// Exactly one of 'float_embedding' and 'quantized_embedding' will contain
// data, based on whether the embedder was configured to perform
// scalar quantization.
struct Embedding {
// Floating-point embedding. Empty if the embedder was configured to perform
// scalar quantization.
std::vector<float> float_embedding;
// Scalar-quantized embedding. Empty if the embedder was not configured to
// perform scalar quantization.
std::string quantized_embedding;
// The index of the embedder head (i.e. output tensor) this embedding comes
// from. This is useful for multi-head models.
int head_index;
// The optional name of the embedder head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
std::optional<std::string> head_name = std::nullopt;
};
// Defines embedding results of a model.
struct EmbeddingResult {
// The embedding results for each head of the model.
std::vector<Embedding> embeddings;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Embedding proto to Embedding struct.
Embedding ConvertToEmbedding(const proto::Embedding& proto);
// Utility function to convert from EmbeddingResult proto to EmbeddingResult
// struct.
EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
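And a short consumption sketch: exactly one of the two vectors is populated per head, so callers can branch on emptiness (the stream output is purely illustrative):

```cpp
#include <iostream>

#include "mediapipe/tasks/cc/components/containers/embedding_result.h"

namespace containers = mediapipe::tasks::components::containers;

// Prints a one-line summary per embedder head.
void PrintResult(const containers::EmbeddingResult& result) {
  for (const auto& embedding : result.embeddings) {
    std::cout << "head " << embedding.head_index;
    if (embedding.head_name.has_value()) {
      std::cout << " (" << *embedding.head_name << ")";
    }
    if (!embedding.float_embedding.empty()) {
      std::cout << ": " << embedding.float_embedding.size()
                << " float values\n";
    } else {
      // Quantized: each byte of the string holds one signed 8-bit value.
      std::cout << ": " << embedding.quantized_embedding.size()
                << " int8 values\n";
    }
  }
}
```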

View File

@ -30,30 +30,31 @@ message QuantizedEmbedding {
optional bytes values = 1; optional bytes values = 1;
} }
// Floating-point or scalar-quantized embedding with an optional timestamp. // Embedding result for a given embedder head.
message EmbeddingEntry { message Embedding {
// The actual embedding, either floating-point or scalar-quantized. // The actual embedding, either floating-point or quantized.
oneof embedding { oneof embedding {
FloatEmbedding float_embedding = 1; FloatEmbedding float_embedding = 1;
QuantizedEmbedding quantized_embedding = 2; QuantizedEmbedding quantized_embedding = 2;
} }
// The optional timestamp (in milliseconds) associated to the embedding entry.
// This is useful for time series use cases, e.g. audio embedding.
optional int64 timestamp_ms = 3;
}
// Embeddings for a given embedder head.
message Embeddings {
repeated EmbeddingEntry entries = 1;
// The index of the embedder head that produced this embedding. This is useful // The index of the embedder head that produced this embedding. This is useful
// for multi-head models. // for multi-head models.
optional int32 head_index = 2; optional int32 head_index = 3;
// The name of the embedder head, which is the corresponding tensor metadata // The name of the embedder head, which is the corresponding tensor metadata
// name (if any). This is useful for multi-head models. // name (if any). This is useful for multi-head models.
optional string head_name = 3; optional string head_name = 4;
} }
// Contains one set of results per embedder head. // Embedding results for a given embedder model.
message EmbeddingResult { message EmbeddingResult {
repeated Embeddings embeddings = 1; // The embedding results for each model head, i.e. one for each output tensor.
repeated Embedding embeddings = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for embedding extraction on time series (e.g. audio
// embedding). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
} }
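The net effect of this schema change: the per-entry wrapper is gone (each head contributes a single Embedding) and the timestamp moved onto the result. A hedged illustration of the flattened form, using the same ParseTextProtoOrDie helper as the test at the top of this commit:

```cpp
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

// Sketch only: one float head plus the result-level timestamp.
auto result = mediapipe::ParseTextProtoOrDie<
    mediapipe::tasks::components::containers::proto::EmbeddingResult>(
    R"pb(embeddings {
           float_embedding { values: 0.1 values: 0.2 }
           head_index: 0
           head_name: "feature"
         }
         timestamp_ms: 100)pb");
```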

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"

View File

@ -62,3 +62,38 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],
hdrs = ["embedder_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"],
)
cc_library(
name = "embedding_postprocessing_graph",
srcs = ["embedding_postprocessing_graph.cc"],
hdrs = ["embedding_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

View File

@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options) { EmbedderOptions* embedder_options) {
tasks::components::proto::EmbedderOptions options_proto; proto::EmbedderOptions options_proto;
options_proto.set_l2_normalize(embedder_options->l2_normalize); options_proto.set_l2_normalize(embedder_options->l2_normalize);
options_proto.set_quantize(embedder_options->quantize); options_proto.set_quantize(embedder_options->quantize);
return options_proto; return options_proto;
} }
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Embedder options for MediaPipe C++ embedding extraction tasks. // Embedder options for MediaPipe C++ embedding extraction tasks.
struct EmbedderOptions { struct EmbedderOptions {
@ -37,11 +38,12 @@ struct EmbedderOptions {
bool quantize; bool quantize;
}; };
tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto( proto::EmbedderOptions ConvertEmbedderOptionsToProto(
EmbedderOptions* embedder_options); EmbedderOptions* embedder_options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <string> #include <string>
#include <vector> #include <vector>
@ -29,8 +29,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -39,6 +39,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using TensorsSource = using TensorsSource =
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>; ::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
// Identifies whether or not the model has quantized outputs, and performs // Identifies whether or not the model has quantized outputs, and performs
// sanity checks. // sanity checks.
@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const ModelResources& model_resources, const ModelResources& model_resources,
const EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options) { proto::EmbeddingPostprocessingGraphOptions* options) {
ASSIGN_OR_RETURN(bool has_quantized_outputs, ASSIGN_OR_RETURN(bool has_quantized_outputs,
HasQuantizedOutputs(model_resources)); HasQuantizedOutputs(model_resources));
@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
BuildEmbeddingPostprocessing( BuildEmbeddingPostprocessing(
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(), sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph)); graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)]; embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>() .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
.CopyFrom(options.tensors_to_embeddings_options()); .CopyFrom(options.tensors_to_embeddings_options());
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
return tensors_to_embeddings_node[Output<EmbeddingResult>( return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
kEmbeddingResultTag)];
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::EmbeddingPostprocessingGraph); ::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Configures an EmbeddingPostprocessingGraph using the provided model resources // Configures an EmbeddingPostprocessingGraph using the provided model resources
// and EmbedderOptions. // and EmbedderOptions.
@ -44,18 +45,19 @@ namespace components {
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult
// The output EmbeddingResult. // The output EmbeddingResult.
// //
// TODO: add support for additional optional "TIMESTAMPS" input for // TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation. // embeddings aggregation.
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const tasks::components::proto::EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,
proto::EmbeddingPostprocessingGraphOptions* options); proto::EmbeddingPostprocessingGraphOptions* options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
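For orientation, a hedged sketch of how a task graph wires this subgraph in; it mirrors the ImageEmbedderGraph change later in this commit, `inference` stands for an already-added InferenceCalculator node, and the usings/includes follow the graph files above:

```cpp
// Sketch under the assumptions stated above; not a drop-in implementation.
absl::Status WireEmbeddingPostprocessing(
    const tasks::core::ModelResources& model_resources,
    const proto::EmbedderOptions& embedder_options,
    mediapipe::api2::builder::GenericNode& inference,
    mediapipe::api2::builder::Graph& graph) {
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
  // Translate the user-facing embedder options into graph options.
  MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
      model_resources, embedder_options,
      &postprocessing
           .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
  // Model output tensors in; EmbeddingResult comes out on the renamed
  // "EMBEDDINGS" tag.
  inference.Out("TENSORS") >> postprocessing.In("TENSORS");
  return absl::OkStatus();
}
```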

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include <memory> #include <memory>
@ -25,8 +25,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -34,12 +34,10 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::EmbedderOptions;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -69,68 +67,72 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_l2_normalize(true); options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { l2_normalize: true } R"pb(tensors_to_embeddings_options {
head_names: "probability" embedder_options { l2_normalize: true }
} head_names: "probability"
has_quantized_outputs: true)pb"))); }
has_quantized_outputs: true)pb")));
} }
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_quantize(true); options_in.set_quantize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { quantize: true } R"pb(tensors_to_embeddings_options {
} embedder_options { quantize: true }
has_quantized_outputs: true)pb"))); }
has_quantized_outputs: true)pb")));
} }
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(auto model_resources, MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
CreateModelResourcesForModel(kMobileNetV3Embedder)); CreateModelResourcesForModel(kMobileNetV3Embedder));
EmbedderOptions options_in; proto::EmbedderOptions options_in;
options_in.set_quantize(true); options_in.set_quantize(true);
options_in.set_l2_normalize(true); options_in.set_l2_normalize(true);
EmbeddingPostprocessingGraphOptions options_out; proto::EmbeddingPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
&options_out)); &options_out));
EXPECT_THAT( EXPECT_THAT(
options_out, options_out,
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>( EqualsProto(
R"pb(tensors_to_embeddings_options { ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
embedder_options { quantize: true l2_normalize: true } R"pb(tensors_to_embeddings_options {
head_names: "feature" embedder_options { quantize: true l2_normalize: true }
} head_names: "feature"
has_quantized_outputs: false)pb"))); }
has_quantized_outputs: false)pb")));
} }
// TODO: add E2E Postprocessing tests once timestamp aggregation is // TODO: add E2E Postprocessing tests once timestamp aggregation is
// supported. // supported.
} // namespace } // namespace
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -34,3 +34,18 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
], ],
) )
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)

View File

@ -15,7 +15,10 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "EmbedderOptionsProto";
// Shared options used by all embedding extraction tasks. // Shared options used by all embedding extraction tasks.
message EmbedderOptions { message EmbedderOptions {

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";

View File

@ -23,21 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"], srcs = ["segmenter_options.proto"],
) )
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],
)
mediapipe_proto_library(
name = "embedding_postprocessing_graph_options_proto",
srcs = ["embedding_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "text_preprocessing_graph_options_proto", name = "text_preprocessing_graph_options_proto",
srcs = ["text_preprocessing_graph_options.proto"], srcs = ["text_preprocessing_graph_options.proto"],

View File

@ -26,7 +26,7 @@ cc_library(
hdrs = ["cosine_similarity.h"], hdrs = ["cosine_similarity.h"],
deps = [ deps = [
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers:embedding_result",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
@ -39,7 +39,7 @@ cc_test(
deps = [ deps = [
":cosine_similarity", ":cosine_similarity",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers:embedding_result",
], ],
) )

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -30,7 +30,7 @@ namespace utils {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::Embedding;
template <typename T> template <typename T>
absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v, absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
// an L2-norm of 0. // an L2-norm of 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u, absl::StatusOr<double> CosineSimilarity(const Embedding& u,
const EmbeddingEntry& v) { const Embedding& v) {
if (u.has_float_embedding() && v.has_float_embedding()) { if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
if (u.float_embedding().values().size() != if (u.float_embedding.size() != v.float_embedding.size()) {
v.float_embedding().values().size()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrFormat("Cannot compute cosine similarity between embeddings " absl::StrFormat("Cannot compute cosine similarity between embeddings "
"of different sizes (%d vs. %d)", "of different sizes (%d vs. %d)",
u.float_embedding().values().size(), u.float_embedding.size(), v.float_embedding.size()),
v.float_embedding().values().size()),
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
return ComputeCosineSimilarity(u.float_embedding().values().data(), return ComputeCosineSimilarity(u.float_embedding.data(),
v.float_embedding().values().data(), v.float_embedding.data(),
u.float_embedding().values().size()); u.float_embedding.size());
} }
if (u.has_quantized_embedding() && v.has_quantized_embedding()) { if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
if (u.quantized_embedding().values().size() != if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
v.quantized_embedding().values().size()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrFormat("Cannot compute cosine similarity between embeddings " absl::StrFormat("Cannot compute cosine similarity between embeddings "
"of different sizes (%d vs. %d)", "of different sizes (%d vs. %d)",
u.quantized_embedding().values().size(), u.quantized_embedding.size(),
v.quantized_embedding().values().size()), v.quantized_embedding.size()),
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>( return ComputeCosineSimilarity(
u.quantized_embedding().values().data()), reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
reinterpret_cast<const int8_t*>( reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
v.quantized_embedding().values().data()), u.quantized_embedding.size());
u.quantized_embedding().values().size());
} }
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,

View File

@ -17,22 +17,20 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace utils { namespace utils {
// Utility function to compute cosine similarity [1] between two embedding // Utility function to compute cosine similarity [1] between two embeddings. May
// entries. May return an InvalidArgumentError if e.g. the feature vectors are // return an InvalidArgumentError if e.g. the embeddings are of different types
// of different types (quantized vs. float), have different sizes, or have a // (quantized vs. float), have different sizes, or have an L2-norm of 0.
// an L2-norm of 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
absl::StatusOr<double> CosineSimilarity( absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
const containers::proto::EmbeddingEntry& u, const containers::Embedding& v);
const containers::proto::EmbeddingEntry& v);
} // namespace utils } // namespace utils
} // namespace components } // namespace components
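Behind this helper is the textbook formula cos(u, v) = dot(u, v) / (||u|| * ||v||). A standalone sketch of the computation (the real API additionally rejects mismatched types, mismatched sizes, and zero norms with an InvalidArgumentError, as the implementation above shows):

```cpp
#include <cmath>
#include <cstddef>

// Cosine similarity of two same-length vectors; works for float and int8.
template <typename T>
double CosineSimilaritySketch(const T* u, const T* v, size_t n) {
  double dot = 0.0, norm_u = 0.0, norm_v = 0.0;
  for (size_t i = 0; i < n; ++i) {
    dot += static_cast<double>(u[i]) * v[i];
    norm_u += static_cast<double>(u[i]) * u[i];
    norm_v += static_cast<double>(v[i]) * v[i];
  }
  // Caller must ensure neither norm is zero.
  return dot / (std::sqrt(norm_u) * std::sqrt(norm_v));
}
```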

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -30,29 +30,27 @@ namespace components {
namespace utils { namespace utils {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::Embedding;
using ::testing::HasSubstr; using ::testing::HasSubstr;
// Helper function to generate float EmbeddingEntry. // Helper function to generate float Embedding.
EmbeddingEntry BuildFloatEntry(std::vector<float> values) { Embedding BuildFloatEmbedding(std::vector<float> values) {
EmbeddingEntry entry; Embedding embedding;
for (const float value : values) { embedding.float_embedding = values;
entry.mutable_float_embedding()->add_values(value); return embedding;
}
return entry;
} }
// Helper function to generate quantized EmbeddingEntry. // Helper function to generate quantized Embedding.
EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) { Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
EmbeddingEntry entry; Embedding embedding;
entry.mutable_quantized_embedding()->set_values( uint8_t* data = reinterpret_cast<uint8_t*>(values.data());
reinterpret_cast<uint8_t*>(values.data()), values.size()); embedding.quantized_embedding = {data, data + values.size()};
return entry; return embedding;
} }
TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) { TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildQuantizedEntry({0, 1}); auto v = BuildQuantizedEmbedding({0, 1});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
} }
TEST(CosineSimilarity, FailsWithZeroNorm) { TEST(CosineSimilarity, FailsWithZeroNorm) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEntry({0.0, 0.0}); auto v = BuildFloatEmbedding({0.0, 0.0});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
} }
TEST(CosineSimilarity, FailsWithDifferentSizes) { TEST(CosineSimilarity, FailsWithDifferentSizes) {
auto u = BuildFloatEntry({0.1, 0.2}); auto u = BuildFloatEmbedding({0.1, 0.2});
auto v = BuildFloatEntry({0.1, 0.2, 0.3}); auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
auto status = CosineSimilarity(u, v); auto status = CosineSimilarity(u, v);
@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
} }
TEST(CosineSimilarity, SucceedsWithFloatEntries) { TEST(CosineSimilarity, SucceedsWithFloatEntries) {
auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0}); auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5}); auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
} }
TEST(CosineSimilarity, SucceedsWithQuantizedEntries) { TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
auto u = BuildQuantizedEntry({127, 0, 0, 0}); auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
auto v = BuildQuantizedEntry({-128, 0, 0, 0}); auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v)); MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
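Both expected values follow directly from the formula: for u = (1, 0, 0, 0) and v = (0.5, 0.5, 0.5, 0.5) the dot product is 0.5 and both norms are 1, giving 0.5; for the quantized pair (127, 0, 0, 0) and (-128, 0, 0, 0) the dot product is -16256 against norms of 127 and 128, giving exactly -1.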

View File

@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
hand_gesture_subgraph[Output<std::vector<ClassificationList>>( hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
kHandGesturesTag)]; kHandGesturesTag)];
return {{.gesture = hand_gestures, return GestureRecognizerOutputs{
.handedness = handedness, /*gesture=*/hand_gestures,
.hand_landmarks = hand_landmarks, /*handedness=*/handedness,
.hand_world_landmarks = hand_world_landmarks, /*hand_landmarks=*/hand_landmarks,
.image = hand_landmarker_graph[Output<Image>(kImageTag)]}}; /*hand_world_landmarks=*/hand_world_landmarks,
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
} }
// Populate normalized non rotated face bounding box // Populate normalized non rotated face bounding box
return {.left = bounding_box_left, return Rect{/*left=*/bounding_box_left,
.top = bounding_box_top, /*top=*/bounding_box_top,
.right = bounding_box_right, /*right=*/bounding_box_right,
.bottom = bounding_box_bottom}; /*bottom=*/bounding_box_bottom};
} }
// Uses IoU and distance of some corresponding hand landmarks to detect // Uses IoU and distance of some corresponding hand landmarks to detect

View File

@ -26,12 +26,12 @@ cc_library(
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
@ -49,9 +49,10 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/components:embedder_options", "//mediapipe/tasks/cc/components/containers:embedding_result",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors:embedder_options",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:cosine_similarity", "//mediapipe/tasks/cc/components/utils:cosine_similarity",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",

View File

@ -21,9 +21,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/tool/options_map.h" #include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h" #include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -41,8 +42,8 @@ namespace image_embedder {
namespace { namespace {
constexpr char kEmbeddingResultStreamName[] = "embedding_result_out"; constexpr char kEmbeddingsStreamName[] = "embeddings_out";
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
"mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"; "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto:: using ::mediapipe::tasks::vision::image_embedder::proto::
@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kNormRectTag).SetName(kNormRectStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName);
auto& task_graph = graph.AddNode(kGraphTypeName); auto& task_graph = graph.AddNode(kGraphTypeName);
task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get()); task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >> task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
graph.Out(kEmbeddingResultTag); graph.Out(kEmbeddingsTag);
task_graph.Out(kImageTag).SetName(kImageOutStreamName) >> task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator( return tasks::core::AddFlowLimiterCalculator(
graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag); graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
} }
graph.In(kImageTag) >> task_graph.In(kImageTag); graph.In(kImageTag) >> task_graph.In(kImageTag);
graph.In(kNormRectTag) >> task_graph.In(kNormRectTag); graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);
@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE); options->running_mode != core::RunningMode::IMAGE);
auto embedder_options_proto = auto embedder_options_proto =
std::make_unique<tasks::components::proto::EmbedderOptions>( std::make_unique<components::processors::proto::EmbedderOptions>(
components::ConvertEmbedderOptionsToProto( components::processors::ConvertEmbedderOptionsToProto(
&(options->embedder_options))); &(options->embedder_options)));
options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get()); options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
return options_proto; return options_proto;
@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
return; return;
} }
Packet embedding_result_packet = Packet embedding_result_packet =
status_or_packets.value()[kEmbeddingResultStreamName]; status_or_packets.value()[kEmbeddingsStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(embedding_result_packet.Get<EmbeddingResult>(), result_callback(ConvertToEmbeddingResult(
embedding_result_packet.Get<EmbeddingResult>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
embedding_result_packet.Timestamp().Value() / embedding_result_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed( absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
Image image, Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo( absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return ConvertToEmbeddingResult(
output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
} }
absl::Status ImageEmbedder::EmbedAsync( absl::Status ImageEmbedder::EmbedAsync(
@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
} }
absl::StatusOr<double> ImageEmbedder::CosineSimilarity( absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
const EmbeddingEntry& u, const EmbeddingEntry& v) { const components::containers::Embedding& u,
const components::containers::Embedding& v) {
return components::utils::CosineSimilarity(u, v); return components::utils::CosineSimilarity(u, v);
} }

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/processors/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
@ -33,6 +33,10 @@ namespace tasks {
namespace vision { namespace vision {
namespace image_embedder { namespace image_embedder {
// Alias the shared EmbeddingResult struct as the result type.
using ImageEmbedderResult =
::mediapipe::tasks::components::containers::EmbeddingResult;
// The options for configuring a MediaPipe image embedder task. // The options for configuring a MediaPipe image embedder task.
struct ImageEmbedderOptions { struct ImageEmbedderOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model // Base options for configuring MediaPipe Tasks, such as specifying the model
@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
// Options for configuring the embedder behavior, such as L2-normalization or // Options for configuring the embedder behavior, such as L2-normalization or
// scalar-quantization. // scalar-quantization.
components::EmbedderOptions embedder_options; components::processors::EmbedderOptions embedder_options;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void( std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
absl::StatusOr<components::containers::proto::EmbeddingResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// running mode. // running mode.
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed( absl::StatusOr<ImageEmbedderResult> Embed(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo( absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// Shuts down the ImageEmbedder when all works are done. // Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
// Utility function to compute cosine similarity [1] between two embedding // Utility function to compute cosine similarity [1] between two embeddings.
// entries. May return an InvalidArgumentError if e.g. the feature vectors are // May return an InvalidArgumentError if e.g. the embeddings are of different
// of different types (quantized vs. float), have different sizes, or have a // types (quantized vs. float), have different sizes, or have an L2-norm of
// an L2-norm of 0. // 0.
// //
// [1]: https://en.wikipedia.org/wiki/Cosine_similarity // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
static absl::StatusOr<double> CosineSimilarity( static absl::StatusOr<double> CosineSimilarity(
const components::containers::proto::EmbeddingEntry& u, const components::containers::Embedding& u,
const components::containers::proto::EmbeddingEntry& v); const components::containers::Embedding& v);
}; };
} // namespace image_embedder } // namespace image_embedder
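Putting the renamed pieces together, a hedged end-to-end sketch (the model filename and the two Image inputs are placeholders, and the base_options field name is assumed rather than shown in this diff):

```cpp
// Sketch only: embed two images with the same model and compare them.
absl::StatusOr<double> CompareImages(mediapipe::Image image_a,
                                     mediapipe::Image image_b) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path = "mobilenet_v3_embedder.tflite";
  options->embedder_options.l2_normalize = true;

  ASSIGN_OR_RETURN(std::unique_ptr<ImageEmbedder> embedder,
                   ImageEmbedder::Create(std::move(options)));
  ASSIGN_OR_RETURN(ImageEmbedderResult first, embedder->Embed(image_a));
  ASSIGN_OR_RETURN(ImageEmbedderResult second, embedder->Embed(image_b));
  ASSIGN_OR_RETURN(double similarity,
                   ImageEmbedder::CosineSimilarity(first.embeddings[0],
                                                   second.embeddings[0]));
  MP_RETURN_IF_ERROR(embedder->Close());
  return similarity;
}
```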

View File

@ -20,10 +20,10 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::components::proto::
EmbeddingPostprocessingGraphOptions;
constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
// Describes region of image to perform embedding extraction on. // Describes region of image to perform embedding extraction on.
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult
// The embedding result. // The embedding result.
// IMAGE - Image // IMAGE - Image
// The image that embedding extraction runs on. // The image that embedding extraction runs on.
@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
// node { // node {
// calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph" // calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// output_stream: "EMBEDDING_RESULT:embedding_result_out" // output_stream: "EMBEDDINGS:embedding_result_out"
// output_stream: "IMAGE:image_out" // output_stream: "IMAGE:image_out"
// options { // options {
// [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext] // [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]
@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.embedding_result >> output_streams.embedding_result >>
graph[Output<EmbeddingResult>(kEmbeddingResultTag)]; graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)]; output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects its input stream to the // Adds postprocessing calculators and connects its input stream to the
// inference results. // inference results.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.EmbeddingPostprocessingGraph"); "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing( MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
model_resources, task_options.embedder_options(), model_resources, task_options.embedder_options(),
&postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>())); &postprocessing.GetOptions<components::processors::proto::
EmbeddingPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the embedding results. // Outputs the embedding results.
return ImageEmbedderOutputStreams{ return ImageEmbedderOutputStreams{
/*embedding_result=*/postprocessing[Output<EmbeddingResult>( /*embedding_result=*/postprocessing[Output<EmbeddingResult>(
kEmbeddingResultTag)], kEmbeddingsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]}; /*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -42,7 +42,6 @@ namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
// Utility function to check the sizes, head_index and head_names of a result // Utility function to check the sizes, head_index and head_names of a result
// produced by kMobileNetV3Embedder. // produced by kMobileNetV3Embedder.
void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) { void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
EXPECT_EQ(result.embeddings().size(), 1); EXPECT_EQ(result.embeddings.size(), 1);
EXPECT_EQ(result.embeddings(0).head_index(), 0); EXPECT_EQ(result.embeddings[0].head_index, 0);
EXPECT_EQ(result.embeddings(0).head_name(), "feature"); EXPECT_EQ(result.embeddings[0].head_name, "feature");
EXPECT_EQ(result.embeddings(0).entries().size(), 1);
if (quantized) { if (quantized) {
EXPECT_EQ( EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
result.embeddings(0).entries(0).quantized_embedding().values().size(),
1024);
} else { } else {
EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(), EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
1024);
} }
} }
@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = running_mode; options->running_mode = running_mode;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
auto image_embedder = ImageEmbedder::Create(std::move(options)); auto image_embedder = ImageEmbedder::Create(std::move(options));
@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.925519; double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.925519; double expected_similarity = 0.925519;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, true); CheckMobileNetV3Result(image_result, true);
CheckMobileNetV3Result(crop_result, true); CheckMobileNetV3Result(crop_result, true);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.926791; double expected_similarity = 0.926791;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& image_result, const ImageEmbedderResult& image_result,
image_embedder->Embed(image, image_processing_options)); image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), crop_result.embeddings[0]));
crop_result.embeddings(0).entries(0)));
double expected_similarity = 0.999931; double expected_similarity = 0.999931;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
image_processing_options.rotation_degrees = -90; image_processing_options.rotation_degrees = -90;
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
image_embedder->Embed(image)); image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result, const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options)); image_embedder->Embed(rotated, image_processing_options));
// Check results. // Check results.
CheckMobileNetV3Result(image_result, false); CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false); CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, image_result.embeddings[0],
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), rotated_result.embeddings[0]));
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.572265; double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
/*rotation_degrees=*/-90}; /*rotation_degrees=*/-90};
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result, const ImageEmbedderResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options)); image_embedder->Embed(rotated, image_processing_options));
// Check results. // Check results.
CheckMobileNetV3Result(crop_result, false); CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false); CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity. // CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
double similarity, crop_result.embeddings[0],
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), rotated_result.embeddings[0]));
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.62838; double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
EmbeddingResult previous_results; ImageEmbedderResult previous_results;
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results, MP_ASSERT_OK_AND_ASSIGN(auto results,
image_embedder->EmbedForVideo(image, i)); image_embedder->EmbedForVideo(image, i));
CheckMobileNetV3Result(results, false); CheckMobileNetV3Result(results, false);
if (i > 0) { if (i > 0) {
MP_ASSERT_OK_AND_ASSIGN(double similarity, MP_ASSERT_OK_AND_ASSIGN(
ImageEmbedder::CosineSimilarity( double similarity,
results.embeddings(0).entries(0), ImageEmbedder::CosineSimilarity(results.embeddings[0],
previous_results.embeddings(0).entries(0))); previous_results.embeddings[0]));
double expected_similarity = 1.000000; double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<EmbeddingResult>, options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options))); ImageEmbedder::Create(std::move(options)));
@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
} }
struct LiveStreamModeResults { struct LiveStreamModeResults {
EmbeddingResult embedding_result; ImageEmbedderResult embedding_result;
std::pair<int, int> image_size; std::pair<int, int> image_size;
int64 timestamp_ms; int64 timestamp_ms;
}; };
@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback =
[&results](absl::StatusOr<EmbeddingResult> embedding_result, [&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
const Image& image, int64 timestamp_ms) { const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(embedding_result.status()); MP_ASSERT_OK(embedding_result.status());
results.push_back( results.push_back(
@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
double similarity, double similarity,
ImageEmbedder::CosineSimilarity( ImageEmbedder::CosineSimilarity(
result.embedding_result.embeddings(0).entries(0), result.embedding_result.embeddings[0],
results[i - 1].embedding_result.embeddings(0).entries(0))); results[i - 1].embedding_result.embeddings[0]));
double expected_similarity = 1.000000; double expected_similarity = 1.000000;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:embedder_options_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_embedder.proto; package mediapipe.tasks.vision.image_embedder.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageEmbedderGraphOptions { message ImageEmbedderGraphOptions {
@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
// Options for configuring the embedder behavior, such as normalization or // Options for configuring the embedder behavior, such as normalization or
// quantization. // quantization.
optional components.proto.EmbedderOptions embedder_options = 2; optional components.processors.proto.EmbedderOptions embedder_options = 2;
} }

View File

@ -287,10 +287,9 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i])); tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
} }
} }
return {{ return ImageSegmenterOutputs{
.segmented_masks = segmented_masks, /*segmented_masks=*/segmented_masks,
.image = preprocessing[Output<Image>(kImageTag)], /*image=*/preprocessing[Output<Image>(kImageTag)]};
}};
} }
}; };

View File

@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",

View File

@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
py_library(
name = "audio_data",
srcs = ["audio_data.py"],
)
py_library( py_library(
name = "bounding_box", name = "bounding_box",
srcs = ["bounding_box.py"], srcs = ["bounding_box.py"],
@ -36,6 +41,29 @@ py_library(
], ],
) )
py_library(
name = "landmark",
srcs = ["landmark.py"],
deps = [
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "landmark_detection_result",
srcs = ["landmark_detection_result.py"],
deps = [
":landmark",
":rect",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library( py_library(
name = "category", name = "category",
srcs = ["category.py"], srcs = ["category.py"],

View File

@ -0,0 +1,109 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio data."""
import dataclasses
from typing import Optional
import numpy as np
@dataclasses.dataclass
class AudioFormat:
"""Audio format metadata.
Attributes:
num_channels: the number of channels of the audio data.
sample_rate: the audio sample rate.
"""
num_channels: int = 1
sample_rate: Optional[float] = None
class AudioData(object):
"""MediaPipe Tasks' audio container."""
def __init__(
self, buffer_length: int,
audio_format: AudioFormat = AudioFormat()) -> None:
"""Initializes the `AudioData` object.
Args:
buffer_length: the length of the audio buffer.
audio_format: the audio format metadata.
"""
self._audio_format = audio_format
self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
dtype=np.float32)
def clear(self):
"""Clears the internal buffer and fill it with zeros."""
self._buffer.fill(0)
def load_from_array(self,
src: np.ndarray,
offset: int = 0,
size: int = -1) -> None:
"""Loads the audio data from a NumPy array.
Args:
src: A NumPy source array containing the input audio.
offset: An optional offset for loading a slice of the `src` array to the
buffer.
size: An optional size parameter denoting the number of samples to load
from the `src` array.
Raises:
ValueError: If the input array has an incorrect shape or if
`offset` + `size` exceeds the length of the `src` array.
"""
if src.shape[1] != self._audio_format.num_channels:
raise ValueError(f"Input audio contains an invalid number of channels. "
f"Expect {self._audio_format.num_channels}.")
if size < 0:
size = len(src)
if offset + size > len(src):
raise ValueError(
f"Index out of range. offset {offset} + size {size} should be <= "
f"src's length: {len(src)}")
if len(src) >= len(self._buffer):
# If the internal buffer is shorter than the load target (src), copy
# values from the end of the src array to the internal buffer.
new_offset = offset + size - len(self._buffer)
new_size = len(self._buffer)
self._buffer = src[new_offset:new_offset + new_size].copy()
else:
# Shift the internal buffer backward and add the incoming data to the end
# of the buffer.
shift = size
self._buffer = np.roll(self._buffer, -shift, axis=0)
self._buffer[-shift:, :] = src[offset:offset + size].copy()
@property
def audio_format(self) -> AudioFormat:
"""Gets the audio format of the audio."""
return self._audio_format
@property
def buffer_length(self) -> int:
"""Gets the sample count of the audio."""
return self._buffer.shape[0]
@property
def buffer(self) -> np.ndarray:
"""Gets the internal buffer."""
return self._buffer

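To make the ring-buffer behavior of `load_from_array` concrete, here is a small usage sketch (buffer length and sample values are illustrative):

```python
import numpy as np

audio = AudioData(buffer_length=4)  # mono: AudioFormat() defaults to 1 channel
audio.load_from_array(np.array([[1.0], [2.0]], dtype=np.float32))
print(audio.buffer[:, 0])  # [0. 0. 1. 2.] - buffer shifted, new samples at the end
audio.load_from_array(np.array([[3.0], [4.0], [5.0]], dtype=np.float32))
print(audio.buffer[:, 0])  # [2. 3. 4. 5.] - oldest samples dropped
```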
View File

@ -14,7 +14,7 @@
"""Category data class.""" """Category data class."""
import dataclasses import dataclasses
from typing import Any from typing import Any, Optional
from mediapipe.tasks.cc.components.containers.proto import category_pb2 from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -39,10 +39,10 @@ class Category:
category_name: The label of this category object. category_name: The label of this category object.
""" """
index: int index: Optional[int] = None
score: float score: Optional[float] = None
display_name: str display_name: Optional[str] = None
category_name: str category_name: Optional[str] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _CategoryProto: def to_pb2(self) -> _CategoryProto:

View File

@ -0,0 +1,122 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Landmark data class."""
import dataclasses
from typing import Optional
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_LandmarkProto = landmark_pb2.Landmark
_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
@dataclasses.dataclass
class Landmark:
"""A landmark that can have 1 to 3 dimensions.
Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.
Attributes:
x: The x coordinate.
y: The y coordinate.
z: The z coordinate.
visibility: Landmark visibility. Should stay unset if not supported. Float
score of whether the landmark is visible or occluded by other objects. A
landmark is also considered invisible if it is not present on the screen
(out of scene bounds). Depending on the model, the visibility value is
either a sigmoid output or an argument to a sigmoid.
presence: Landmark presence. Should stay unset if not supported. Float score
of whether the landmark is present in the scene (located within scene
bounds). Depending on the model, the presence value is either a sigmoid
output or an argument to a sigmoid yielding the presence probability.
"""
x: Optional[float] = None
y: Optional[float] = None
z: Optional[float] = None
visibility: Optional[float] = None
presence: Optional[float] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _LandmarkProto:
"""Generates a Landmark protobuf object."""
return _LandmarkProto(
x=self.x,
y=self.y,
z=self.z,
visibility=self.visibility,
presence=self.presence)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _LandmarkProto) -> 'Landmark':
"""Creates a `Landmark` object from the given protobuf object."""
return Landmark(
x=pb2_obj.x,
y=pb2_obj.y,
z=pb2_obj.z,
visibility=pb2_obj.visibility,
presence=pb2_obj.presence)
@dataclasses.dataclass
class NormalizedLandmark:
"""A normalized version of above Landmark proto.
All coordinates should be within [0, 1].
Attributes:
x: The normalized x coordinate.
y: The normalized y coordinate.
z: The normalized z coordinate.
visibility: Landmark visibility. Should stay unset if not supported. Float
score of whether the landmark is visible or occluded by other objects. A
landmark is also considered invisible if it is not present on the screen
(out of scene bounds). Depending on the model, the visibility value is
either a sigmoid output or an argument to a sigmoid.
presence: Landmark presence. Should stay unset if not supported. Float score
of whether the landmark is present in the scene (located within scene
bounds). Depending on the model, the presence value is either a sigmoid
output or an argument to a sigmoid yielding the presence probability.
"""
x: Optional[float] = None
y: Optional[float] = None
z: Optional[float] = None
visibility: Optional[float] = None
presence: Optional[float] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _NormalizedLandmarkProto:
"""Generates a NormalizedLandmark protobuf object."""
return _NormalizedLandmarkProto(
x=self.x,
y=self.y,
z=self.z,
visibility=self.visibility,
presence=self.presence)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _NormalizedLandmarkProto) -> 'NormalizedLandmark':
"""Creates a `NormalizedLandmark` object from the given protobuf object."""
return NormalizedLandmark(
x=pb2_obj.x,
y=pb2_obj.y,
z=pb2_obj.z,
visibility=pb2_obj.visibility,
presence=pb2_obj.presence)

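As a quick illustration of the pb2 round-trip these helpers provide (values are arbitrary):

```python
lm = NormalizedLandmark(x=0.25, y=0.5, z=0.0, visibility=0.9)
proto = lm.to_pb2()  # a landmark_pb2.NormalizedLandmark message
restored = NormalizedLandmark.create_from_pb2(proto)
assert restored.x == lm.x and restored.y == lm.y
```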
View File

@ -0,0 +1,96 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Landmarks Detection Result data class."""
import dataclasses
from typing import Optional, List
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_LandmarkListProto = landmark_pb2.LandmarkList
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
_NormalizedRect = rect_module.NormalizedRect
_Category = category_module.Category
_NormalizedLandmark = landmark_module.NormalizedLandmark
_Landmark = landmark_module.Landmark
@dataclasses.dataclass
class LandmarksDetectionResult:
"""Represents the landmarks detection result.
Attributes:
  landmarks: A list of `NormalizedLandmark` objects.
  categories: A list of `Category` objects.
  world_landmarks: A list of `Landmark` objects.
  rect: A `NormalizedRect` object.
"""
landmarks: Optional[List[_NormalizedLandmark]]
categories: Optional[List[_Category]]
world_landmarks: Optional[List[_Landmark]]
rect: _NormalizedRect
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _LandmarksDetectionResultProto:
"""Generates a LandmarksDetectionResult protobuf object."""
classifications = _ClassificationListProto()
for category in self.categories:
classifications.classification.append(
_ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name))
# Proto message constructors take keyword arguments; build the repeated
# `landmark` fields from the dataclass lists.
return _LandmarksDetectionResultProto(
    landmarks=_NormalizedLandmarkListProto(
        landmark=[landmark.to_pb2() for landmark in self.landmarks]),
    classifications=classifications,
    world_landmarks=_LandmarkListProto(
        landmark=[landmark.to_pb2() for landmark in self.world_landmarks]),
    rect=self.rect.to_pb2())
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult':
"""Creates a `LandmarksDetectionResult` object from the given protobuf object.
"""
categories = []
for classification in pb2_obj.classifications.classification:
categories.append(
category_module.Category(
score=classification.score,
index=classification.index,
category_name=classification.label,
display_name=classification.display_name))
return LandmarksDetectionResult(
landmarks=[
_NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmarks.landmark
],
categories=categories,
world_landmarks=[
_Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.world_landmarks.landmark
],
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))

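A sketch of the intended usage, assuming `serialized_bytes` holds a LandmarksDetectionResult message produced elsewhere (hypothetical input):

```python
proto = _LandmarksDetectionResultProto()
proto.ParseFromString(serialized_bytes)  # hypothetical wire-format input
result = LandmarksDetectionResult.create_from_pb2(proto)
print(len(result.landmarks), [c.category_name for c in result.categories])
```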
View File

@ -19,80 +19,49 @@ from typing import Any, Optional
from mediapipe.framework.formats import rect_pb2 from mediapipe.framework.formats import rect_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_RectProto = rect_pb2.Rect
_NormalizedRectProto = rect_pb2.NormalizedRect _NormalizedRectProto = rect_pb2.NormalizedRect
@dataclasses.dataclass @dataclasses.dataclass
class Rect: class Rect:
"""A rectangle with rotation in image coordinates. """A rectangle, used as part of detection results or as input region-of-interest.
Attributes: x_center : The X coordinate of the top-left corner, in pixels. The coordinates are normalized wrt the image dimensions, i.e. generally in
y_center : The Y coordinate of the top-left corner, in pixels. [0,1] but they may exceed these bounds if describing a region overlapping the
width: The width of the rectangle, in pixels. image. The origin is on the top-left corner of the image.
height: The height of the rectangle, in pixels.
rotation: Rotation angle is clockwise in radians. Attributes:
rect_id: Optional unique id to help associate different rectangles to each left: The X coordinate of the left side of the rectangle.
other. top: The Y coordinate of the top of the rectangle.
right: The X coordinate of the right side of the rectangle.
bottom: The Y coordinate of the bottom of the rectangle.
""" """
x_center: int left: float
y_center: int top: float
width: int right: float
height: int bottom: float
rotation: Optional[float] = 0.0
rect_id: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _RectProto:
"""Generates a Rect protobuf object."""
return _RectProto(
x_center=self.x_center,
y_center=self.y_center,
width=self.width,
height=self.height,
)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
"""Creates a `Rect` object from the given protobuf object."""
return Rect(
x_center=pb2_obj.x_center,
y_center=pb2_obj.y_center,
width=pb2_obj.width,
height=pb2_obj.height)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Rect):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass @dataclasses.dataclass
class NormalizedRect: class NormalizedRect:
"""A rectangle with rotation in normalized coordinates. """A rectangle with rotation in normalized coordinates.
The values of box Location of the center of the rectangle in image coordinates. The (0.0, 0.0)
point is at the (top, left) corner.
center location and size are within [0, 1]. The values of box center location and size are within [0, 1].
Attributes: x_center : The X normalized coordinate of the top-left corner. Attributes:
y_center : The Y normalized coordinate of the top-left corner. x_center: The normalized X coordinate of the rectangle, in image
coordinates.
y_center: The normalized Y coordinate of the rectangle, in image
coordinates.
width: The width of the rectangle. width: The width of the rectangle.
height: The height of the rectangle. height: The height of the rectangle.
rotation: Rotation angle is clockwise in radians. rotation: Rotation angle is clockwise in radians.
rect_id: Optional unique id to help associate different rectangles to each rect_id: Optional unique id to help associate different rectangles to each
other. other.
""" """
x_center: float x_center: float

View File
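To relate the two containers above, a short sketch (using the ROI values from the classifier tests in this commit) deriving a center/size `NormalizedRect` from an edge-based `Rect`:

```python
edges = Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
center = NormalizedRect(
    x_center=(edges.left + edges.right) / 2,   # 0.532
    y_center=(edges.top + edges.bottom) / 2,   # 0.521
    width=edges.right - edges.left,            # 0.164
    height=edges.bottom - edges.top)           # 0.427
```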

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -24,7 +24,7 @@ py_library(
srcs = ["test_utils.py"], srcs = ["test_utils.py"],
srcs_version = "PY3", srcs_version = "PY3",
visibility = [ visibility = [
"//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", "//mediapipe/model_maker/python:__subpackages__",
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",
], ],
deps = [ deps = [

View File

@ -53,6 +53,7 @@ py_test(
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_classifier", "//mediapipe/tasks/python/vision:image_classifier",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],
) )

View File

@ -30,9 +30,10 @@ from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
_NormalizedRect = rect.NormalizedRect _Rect = rect.Rect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions _ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category _Category = category.Category
@ -43,6 +44,7 @@ _Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier _ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions _ImageClassifierOptions = image_classifier.ImageClassifierOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg' _IMAGE_FILE = 'burger.jpg'
@ -227,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path( test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg'))) os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
# NormalizedRect around the soccer ball. # Region-of-interest around the soccer ball.
roi = _NormalizedRect( roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
x_center=0.532, y_center=0.521, width=0.164, height=0.427) image_processing_options = _ImageProcessingOptions(roi)
# Performs image classification on the input. # Performs image classification on the input.
image_result = classifier.classify(test_image, roi) image_result = classifier.classify(test_image, image_processing_options)
# Comparing results. # Comparing results.
test_utils.assert_proto_equals(self, image_result.to_pb2(), test_utils.assert_proto_equals(self, image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2()) _generate_soccer_ball_results(0).to_pb2())
@ -417,12 +419,12 @@ class ImageClassifierTest(parameterized.TestCase):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path( test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg'))) os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
# NormalizedRect around the soccer ball. # Region-of-interest around the soccer ball.
roi = _NormalizedRect( roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
x_center=0.532, y_center=0.521, width=0.164, height=0.427) image_processing_options = _ImageProcessingOptions(roi)
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video( classification_result = classifier.classify_for_video(
test_image, timestamp, roi) test_image, timestamp, image_processing_options)
test_utils.assert_proto_equals( test_utils.assert_proto_equals(
self, classification_result.to_pb2(), self, classification_result.to_pb2(),
_generate_soccer_ball_results(timestamp).to_pb2()) _generate_soccer_ball_results(timestamp).to_pb2())
@ -491,9 +493,9 @@ class ImageClassifierTest(parameterized.TestCase):
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path( test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg'))) os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
# NormalizedRect around the soccer ball. # Region-of-interest around the soccer ball.
roi = _NormalizedRect( roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
x_center=0.532, y_center=0.521, width=0.164, height=0.427) image_processing_options = _ImageProcessingOptions(roi)
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image, def check_result(result: _ClassificationResult, output_image: _Image,
@ -514,7 +516,8 @@ class ImageClassifierTest(parameterized.TestCase):
result_callback=check_result) result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
classifier.classify_async(test_image, timestamp, roi) classifier.classify_async(test_image, timestamp,
image_processing_options)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -33,8 +33,8 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.OutputType _OutputType = image_segmenter.ImageSegmenterOptions.OutputType
_Activation = image_segmenter.Activation _Activation = image_segmenter.ImageSegmenterOptions.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode

View File

@ -55,6 +55,7 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api", "//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],
) )
@ -77,3 +78,27 @@ py_library(
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],
) )
py_library(
name = "gesture_recognizer",
srcs = [
"gesture_recognizer.py",
],
deps = [
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)

View File

@ -15,17 +15,25 @@
"""MediaPipe Tasks Vision API.""" """MediaPipe Tasks Vision API."""
import mediapipe.tasks.python.vision.core import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.gesture_recognizer
import mediapipe.tasks.python.vision.image_classifier import mediapipe.tasks.python.vision.image_classifier
import mediapipe.tasks.python.vision.image_segmenter
import mediapipe.tasks.python.vision.object_detector import mediapipe.tasks.python.vision.object_detector
GestureRecognizer = gesture_recognizer.GestureRecognizer
GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
ImageClassifier = image_classifier.ImageClassifier ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions ImageClassifierOptions = image_classifier.ImageClassifierOptions
ImageSegmenter = image_segmenter.ImageSegmenter
ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
ObjectDetector = object_detector.ObjectDetector ObjectDetector = object_detector.ObjectDetector
ObjectDetectorOptions = object_detector.ObjectDetectorOptions ObjectDetectorOptions = object_detector.ObjectDetectorOptions
RunningMode = core.vision_task_running_mode.VisionTaskRunningMode RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
# Remove unnecessary modules to avoid duplication in API docs. # Remove unnecessary modules to avoid duplication in API docs.
del core del core
del gesture_recognizer
del image_classifier del image_classifier
del image_segmenter
del object_detector del object_detector
del mediapipe del mediapipe

View File

@ -23,15 +23,25 @@ py_library(
srcs = ["vision_task_running_mode.py"], srcs = ["vision_task_running_mode.py"],
) )
py_library(
name = "image_processing_options",
srcs = ["image_processing_options.py"],
deps = [
"//mediapipe/tasks/python/components/containers:rect",
],
)
py_library( py_library(
name = "base_vision_task_api", name = "base_vision_task_api",
srcs = [ srcs = [
"base_vision_task_api.py", "base_vision_task_api.py",
], ],
deps = [ deps = [
":image_processing_options",
":vision_task_running_mode", ":vision_task_running_mode",
"//mediapipe/framework:calculator_py_pb2", "//mediapipe/framework:calculator_py_pb2",
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
], ],
) )

View File

@ -13,17 +13,22 @@
# limitations under the License. # limitations under the License.
"""MediaPipe vision task base api.""" """MediaPipe vision task base api."""
import math
from typing import Callable, Mapping, Optional from typing import Callable, Mapping, Optional
from mediapipe.framework import calculator_pb2 from mediapipe.framework import calculator_pb2
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_TaskRunner = task_runner_module.TaskRunner _TaskRunner = task_runner_module.TaskRunner
_Packet = packet_module.Packet _Packet = packet_module.Packet
_NormalizedRect = rect_module.NormalizedRect
_RunningMode = running_mode_module.VisionTaskRunningMode _RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
class BaseVisionTaskApi(object): class BaseVisionTaskApi(object):
@ -122,6 +127,49 @@ class BaseVisionTaskApi(object):
+ self._running_mode.name) + self._running_mode.name)
self._runner.send(inputs) self._runner.send(inputs)
def convert_to_normalized_rect(self,
options: _ImageProcessingOptions,
roi_allowed: bool = True) -> _NormalizedRect:
"""Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly.
If the input ImageProcessingOptions is not present, returns a default
NormalizedRect covering the whole image with rotation set to 0. If
'roi_allowed' is False, a ValueError is raised if the input
ImageProcessingOptions has its 'region_of_interest' field set.
Args:
options: Options for image processing.
roi_allowed: Indicates if the `region_of_interest` field is allowed to be
set. By default, it's set to True.
Returns:
A normalized rect that represents the image processing options.
"""
normalized_rect = _NormalizedRect(
rotation=0, x_center=0.5, y_center=0.5, width=1, height=1)
if options is None:
return normalized_rect
if options.rotation_degrees % 90 != 0:
raise ValueError('Expected rotation to be a multiple of 90°.')
# Convert to radians counter-clockwise.
normalized_rect.rotation = -options.rotation_degrees * math.pi / 180.0
if options.region_of_interest:
if not roi_allowed:
raise ValueError("This task doesn't support region-of-interest.")
roi = options.region_of_interest
if roi.left >= roi.right or roi.top >= roi.bottom:
raise ValueError('Expected Rect with left < right and top < bottom.')
if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
raise ValueError('Expected Rect values to be in [0,1].')
normalized_rect.x_center = (roi.left + roi.right) / 2.0
normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
normalized_rect.width = roi.right - roi.left
normalized_rect.height = roi.bottom - roi.top
return normalized_rect
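A small worked check of the conversion above; `task` is assumed to be an instance of any BaseVisionTaskApi subclass (hypothetical here):

```python
import math

opts = _ImageProcessingOptions(
    region_of_interest=rect_module.Rect(left=0.2, top=0.1, right=0.8, bottom=0.5),
    rotation_degrees=90)
rect = task.convert_to_normalized_rect(opts)  # `task`: hypothetical instance
assert math.isclose(rect.rotation, -math.pi / 2)  # 90 deg clockwise -> -pi/2 rad
assert math.isclose(rect.x_center, 0.5) and math.isclose(rect.y_center, 0.3)
assert math.isclose(rect.width, 0.6) and math.isclose(rect.height, 0.4)
```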
def close(self) -> None: def close(self) -> None:
"""Shuts down the mediapipe vision task instance. """Shuts down the mediapipe vision task instance.

View File

@ -0,0 +1,39 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe vision options for image processing."""
import dataclasses
from typing import Optional
from mediapipe.tasks.python.components.containers import rect as rect_module
@dataclasses.dataclass
class ImageProcessingOptions:
"""Options for image processing.
If both region-of-interest and rotation are specified, the crop around the
region-of-interest is extracted first, then the specified rotation is applied
to the crop.
Attributes:
region_of_interest: The optional region-of-interest to crop from the image.
If not specified, the full image is used. Coordinates must be in [0,1]
with 'left' < 'right' and 'top' < 'bottom'.
rotation_degrees: The rotation to apply to the image (or cropped
region-of-interest), in degrees clockwise. The rotation must be a multiple
(positive or negative) of 90°.
"""
region_of_interest: Optional[rect_module.Rect] = None
rotation_degrees: int = 0

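A usage sketch; `classifier` and `image` are hypothetical stand-ins for any vision task and input that accept these options:

```python
roi = rect_module.Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
options = ImageProcessingOptions(region_of_interest=roi, rotation_degrees=90)
# result = classifier.classify(image, options)  # hypothetical task call
```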
View File

@ -0,0 +1,426 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe gesture recognizer task."""
import dataclasses
from typing import Callable, Mapping, Optional, List
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_HAND_GESTURE_STREAM_NAME = 'hand_gestures'
_HAND_GESTURE_TAG = 'HAND_GESTURES'
_HANDEDNESS_STREAM_NAME = 'handedness'
_HANDEDNESS_TAG = 'HANDEDNESS'
_HAND_LANDMARKS_STREAM_NAME = 'landmarks'
_HAND_LANDMARKS_TAG = 'LANDMARKS'
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
_GESTURE_DEFAULT_INDEX = -1
@dataclasses.dataclass
class GestureRecognitionResult:
"""The gesture recognition result from GestureRecognizer, where each vector element represents a single hand detected in the image.
Attributes:
gestures: Recognized hand gestures of detected hands. Note that the index of
the gesture is always -1, because the raw indices from multiple gesture
classifiers cannot consolidate to a meaningful index.
handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates.
hand_world_landmarks: Detected hand landmarks in world coordinates.
"""
gestures: List[List[category_module.Category]]
handedness: List[List[category_module.Category]]
hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
hand_world_landmarks: List[List[landmark_module.Landmark]]
def _build_recognition_result(
output_packets: Mapping[str,
packet_module.Packet]) -> GestureRecognitionResult:
"""Consturcts a `GestureRecognitionResult` from output packets."""
gestures_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_GESTURE_STREAM_NAME])
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
gesture_results = []
for proto in gestures_proto_list:
gesture_categories = []
gesture_classifications = classification_pb2.ClassificationList()
gesture_classifications.MergeFrom(proto)
for gesture in gesture_classifications.classification:
gesture_categories.append(
category_module.Category(
index=_GESTURE_DEFAULT_INDEX,
score=gesture.score,
display_name=gesture.display_name,
category_name=gesture.label))
gesture_results.append(gesture_categories)
handedness_results = []
for proto in handedness_proto_list:
handedness_categories = []
handedness_classifications = classification_pb2.ClassificationList()
handedness_classifications.MergeFrom(proto)
for handedness in handedness_classifications.classification:
handedness_categories.append(
category_module.Category(
index=handedness.index,
score=handedness.score,
display_name=handedness.display_name,
category_name=handedness.label))
handedness_results.append(handedness_categories)
hand_landmarks_results = []
for proto in hand_landmarks_proto_list:
hand_landmarks = landmark_pb2.NormalizedLandmarkList()
hand_landmarks.MergeFrom(proto)
hand_landmarks_results.append([
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
for hand_landmark in hand_landmarks.landmark
])
hand_world_landmarks_results = []
for proto in hand_world_landmarks_proto_list:
hand_world_landmarks = landmark_pb2.LandmarkList()
hand_world_landmarks.MergeFrom(proto)
hand_world_landmarks_results.append([
landmark_module.Landmark.create_from_pb2(hand_world_landmark)
for hand_world_landmark in hand_world_landmarks.landmark
])
return GestureRecognitionResult(gesture_results, handedness_results,
hand_landmarks_results,
hand_world_landmarks_results)
@dataclasses.dataclass
class GestureRecognizerOptions:
"""Options for the gesture recognizer task.
Attributes:
base_options: Base options for the hand gesture recognizer task.
running_mode: The running mode of the task. Defaults to the image mode.
Gesture recognizer task has three running modes: 1) The image mode for
recognizing hand gestures on single image inputs. 2) The video mode for
recognizing hand gestures on the decoded frames of a video. 3) The live
stream mode for recognizing hand gestures on a live stream of input data,
such as from a camera.
num_hands: The maximum number of hands that can be detected by the
recognizer.
min_hand_detection_confidence: The minimum confidence score for the hand
detection to be considered successful.
min_hand_presence_confidence: The minimum confidence score of hand presence
in the hand landmark detection.
min_tracking_confidence: The minimum confidence score for the hand tracking
to be considered successful.
canned_gesture_classifier_options: Options for configuring the canned
gestures classifier, such as score threshold, allow list and deny list of
gestures. The categories for canned gesture classifiers are: ["None",
"Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down", "Thumb_Up",
"Victory", "ILoveYou"]. Note this option is subject to change.
custom_gesture_classifier_options: Options for configuring the custom
gestures classifier, such as score threshold, allow list and deny list of
gestures. Note this option is subject to change.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
canned_gesture_classifier_options: Optional[
_ClassifierOptions] = _ClassifierOptions()
custom_gesture_classifier_options: Optional[
_ClassifierOptions] = _ClassifierOptions()
result_callback: Optional[Callable[
[GestureRecognitionResult, image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
"""Generates an GestureRecognizerOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# Initialize gesture recognizer options from base options.
gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
base_options=base_options_proto)
# Configure hand detector and hand landmarker options.
hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
# Configure hand gesture recognizer options.
hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.canned_gesture_classifier_options.to_pb2())
hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(
self.custom_gesture_classifier_options.to_pb2())
return gesture_recognizer_options_proto
class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs gesture recognition on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'GestureRecognizer':
"""Creates an `GestureRecognizer` object from a TensorFlow Lite model and the default `GestureRecognizerOptions`.
Note that the created `GestureRecognizer` instance is in image mode, for
recognizing hand gestures on single image inputs.
Args:
model_path: Path to the model.
Returns:
`GestureRecognizer` object that's created from the model file and the
default `GestureRecognizerOptions`.
Raises:
ValueError: If the `GestureRecognizer` object cannot be created from the
provided file, e.g., due to an invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = GestureRecognizerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
return cls.create_from_options(options)
@classmethod
def create_from_options(
cls, options: GestureRecognizerOptions) -> 'GestureRecognizer':
"""Creates the `GestureRecognizer` object from gesture recognizer options.
Args:
options: Options for the gesture recognizer task.
Returns:
`GestureRecognizer` object that's created from `options`.
Raises:
ValueError: If the `GestureRecognizer` object cannot be created from
`GestureRecognizerOptions`, e.g., because the model asset is missing.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
options.result_callback(
GestureRecognitionResult([], [], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
return
gesture_recognition_result = _build_recognition_result(output_packets)
timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
options.result_callback(gesture_recognition_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]),
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG,
_HAND_LANDMARKS_STREAM_NAME]), ':'.join([
_HAND_WORLD_LANDMARKS_TAG,
_HAND_WORLD_LANDMARKS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
def recognize(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> GestureRecognitionResult:
"""Performs hand gesture recognition on the given image.
Only use this method when the GestureRecognizer is created with the image
running mode.
The image can be of any size with format RGB or RGBA.
TODO: Describe how the input image will be preprocessed once YUV
support is implemented.
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
The hand gesture recognition results.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If gesture recognition failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
})
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
return GestureRecognitionResult([], [], [], [])
return _build_recognition_result(output_packets)
def recognize_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> GestureRecognitionResult:
"""Performs gesture recognition on the provided video frame.
Only use this method when the GestureRecognizer is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns:
The hand gesture recognition results.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If gesture recognition failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
return GestureRecognitionResult([], [], [], [])
return _build_recognition_result(output_packets)
def recognize_async(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> None:
"""Sends live image data to perform gesture recognition.
The results will be available via the "result_callback" provided in the
GestureRecognizerOptions. Only use this method when the GestureRecognizer
is created with the live stream running mode.
Only use this method when the GestureRecognizer is created with the live
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input image is accepted. The results will be available via the
`result_callback` provided in the `GestureRecognizerOptions`. The
`recognize_async` method is designed to process live stream data such as
camera input. To lower the overall latency, gesture recognizer may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` provides:
- The hand gesture recognition results.
- The input image that the gesture recognizer runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.
Raises:
ValueError: If the current input timestamp is smaller than what the
gesture recognizer has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
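To make the API surface above concrete, a minimal end-to-end sketch in image mode. The model and image file names are hypothetical, and `Image.create_from_file` is assumed to be available on the framework bindings:

```python
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.vision import gesture_recognizer

# Hypothetical asset paths; substitute real files.
recognizer = gesture_recognizer.GestureRecognizer.create_from_model_path(
    'gesture_recognizer.task')
image = image_module.Image.create_from_file('hand.jpg')

result = recognizer.recognize(image)
for hand_gestures in result.gestures:  # one inner list per detected hand
  top_gesture = hand_gestures[0]       # scored Category objects
  print(top_gesture.category_name, top_gesture.score)
recognizer.close()                     # releases the underlying graph
```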

View File

@ -30,6 +30,7 @@ from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.core import task_info as task_info_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 from mediapipe.tasks.python.vision.core import base_vision_task_api
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 from mediapipe.tasks.python.vision.core import vision_task_running_mode

 _NormalizedRect = rect.NormalizedRect
@ -37,6 +38,7 @@ _BaseOptions = base_options_module.BaseOptions
 _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
 _ClassifierOptions = classifier_options.ClassifierOptions
 _RunningMode = vision_task_running_mode.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 _TaskInfo = task_info_module.TaskInfo

 _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
 _IMAGE_IN_STREAM_NAME = 'image_in'
 _IMAGE_OUT_STREAM_NAME = 'image_out'
 _IMAGE_TAG = 'IMAGE'
-_NORM_RECT_NAME = 'norm_rect_in'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
 _NORM_RECT_TAG = 'NORM_RECT'
 _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 _MICRO_SECONDS_PER_MILLISECOND = 1000

-
-def _build_full_image_norm_rect() -> _NormalizedRect:
-  # Builds a NormalizedRect covering the entire image.
-  return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
-
-
 @dataclasses.dataclass
 class ImageClassifierOptions:
   """Options for the image classifier task.
@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         task_graph=_TASK_GRAPH_NAME,
         input_streams=[
             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
-            ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
+            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
         ],
         output_streams=[
             ':'.join([
@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
             _RunningMode.LIVE_STREAM), options.running_mode,
         packets_callback if options.result_callback else None)

-  # TODO: Replace _NormalizedRect with ImageProcessingOption
   def classify(
       self,
       image: image_module.Image,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
   ) -> classifications.ClassificationResult:
     """Performs image classification on the provided MediaPipe Image.

     Args:
       image: MediaPipe Image.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.

     Returns:
       A classification result object that contains a list of classifications.
@ -190,10 +186,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If image classification failed to run.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     output_packets = self._process_image_data({
-        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
-        _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
+        _IMAGE_IN_STREAM_NAME:
+            packet_creator.create_image(image),
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2())
     })

     classification_result_proto = classifications_pb2.ClassificationResult()
@ -210,7 +208,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       self,
       image: image_module.Image,
       timestamp_ms: int,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
   ) -> classifications.ClassificationResult:
     """Performs image classification on the provided video frames.

@ -222,7 +220,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input video frame in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.

     Returns:
       A classification result object that contains a list of classifications.
@ -231,13 +229,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If image classification failed to run.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     output_packets = self._process_video_data({
         _IMAGE_IN_STREAM_NAME:
             packet_creator.create_image(image).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
-        _NORM_RECT_NAME:
-            packet_creator.create_proto(norm_rect.to_pb2()).at(
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
     })

@ -251,10 +249,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         for classification in classification_result_proto.classifications
     ])

-  def classify_async(self,
-                     image: image_module.Image,
-                     timestamp_ms: int,
-                     roi: Optional[_NormalizedRect] = None) -> None:
+  def classify_async(
+      self,
+      image: image_module.Image,
+      timestamp_ms: int,
+      image_processing_options: Optional[_ImageProcessingOptions] = None
+  ) -> None:
     """Sends live image data (an Image with a unique timestamp) to perform image classification.

     Only use this method when the ImageClassifier is created with the live
@ -275,18 +275,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input image in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.

     Raises:
       ValueError: If the current input timestamp is smaller than what the image
         classifier has already processed.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     self._send_live_stream_data({
         _IMAGE_IN_STREAM_NAME:
             packet_creator.create_image(image).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
-        _NORM_RECT_NAME:
-            packet_creator.create_proto(norm_rect.to_pb2()).at(
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
     })
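The practical upshot of this hunk: the public `roi` parameter is replaced by `image_processing_options`. A hedged before/after sketch for callers, where `classifier` and `image` are assumed to already exist:

```python
# Before (old API): pass a NormalizedRect covering the region of interest.
#   result = classifier.classify(image, roi=_NormalizedRect(
#       x_center=0.5, y_center=0.5, width=0.5, height=0.5))

# After (new API): wrap the ROI (and optional rotation) in ImageProcessingOptions.
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.vision.core import image_processing_options as ipo_module

roi = rect_module.Rect(left=0.25, top=0.25, right=0.75, bottom=0.75)
result = classifier.classify(
    image, ipo_module.ImageProcessingOptions(region_of_interest=roi))
```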

View File

@ -44,18 +44,6 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
 _MICRO_SECONDS_PER_MILLISECOND = 1000


-class OutputType(enum.Enum):
-  UNSPECIFIED = 0
-  CATEGORY_MASK = 1
-  CONFIDENCE_MASK = 2
-
-
-class Activation(enum.Enum):
-  NONE = 0
-  SIGMOID = 1
-  SOFTMAX = 2
-
-
 @dataclasses.dataclass
 class ImageSegmenterOptions:
   """Options for the image segmenter task.
@ -74,6 +62,17 @@ class ImageSegmenterOptions:
       data. The result callback should only be specified when the running mode
       is set to the live stream mode.
   """

+  class OutputType(enum.Enum):
+    UNSPECIFIED = 0
+    CATEGORY_MASK = 1
+    CONFIDENCE_MASK = 2
+
+  class Activation(enum.Enum):
+    NONE = 0
+    SIGMOID = 1
+    SOFTMAX = 2
+
   base_options: _BaseOptions
   running_mode: _RunningMode = _RunningMode.IMAGE
   output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
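Because the enums are now nested, call sites reference them through `ImageSegmenterOptions`. A hedged sketch (the model path is hypothetical, and an `activation` field alongside `output_type` is assumed but not shown in this hunk):

```python
# Old call sites:  image_segmenter.OutputType.CONFIDENCE_MASK
# New call sites:  nested under the options dataclass.
options = ImageSegmenterOptions(
    base_options=_BaseOptions(model_asset_path='segmenter.tflite'),  # hypothetical
    output_type=ImageSegmenterOptions.OutputType.CONFIDENCE_MASK,
    activation=ImageSegmenterOptions.Activation.SOFTMAX)  # assumed field
```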

View File

@ -0,0 +1,21 @@
# This package contains options shared by all MediaPipe Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "category",
srcs = ["category.d.ts"],
)
mediapipe_ts_library(
name = "classifications",
srcs = ["classifications.d.ts"],
deps = [":category"],
)
mediapipe_ts_library(
name = "landmark",
srcs = ["landmark.d.ts"],
)

View File

@ -0,0 +1,38 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** A classification category. */
export interface Category {
/** The probability score of this label category. */
score: number;
/** The index of the category in the corresponding label file. */
index: number;
/**
* The label of this category object. Defaults to an empty string if there is
* no category.
*/
categoryName: string;
/**
* The display name of the label, which may be translated for different
* locales. For example, a label, "apple", may be translated into Spanish for
* display purpose, so that the `display_name` is "manzana". Defaults to an
* empty string if there is no display name.
*/
displayName: string;
}

View File

@ -0,0 +1,51 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Category} from '../../../../tasks/web/components/containers/category';
/** List of predicted categories with an optional timestamp. */
export interface ClassificationEntry {
/**
* The array of predicted categories, usually sorted by descending scores,
* e.g., from high to low probability.
*/
categories: Category[];
/**
* The optional timestamp (in milliseconds) associated with the classification
* entry. This is useful for time series use cases, e.g., audio
* classification.
*/
timestampMs?: number;
}
/** Classifications for a given classifier head. */
export interface Classifications {
/** A list of classification entries. */
entries: ClassificationEntry[];
/**
* The index of the classifier head these categories refer to. This is
* useful for multi-head models.
*/
headIndex: number;
/**
* The name of the classifier head, which is the corresponding tensor
* metadata name.
*/
headName: string;
}

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Landmark represents a point in 3D space with x, y, z coordinates. If
* normalized is true, the landmark coordinates are normalized with respect to
* the dimensions of the image, and the coordinate values are in the range of
* [0,1]. Otherwise, it represents a point in world coordinates.
*/
export class Landmark {
/** The x coordinate of the landmark. */
x: number;
/** The y coordinate of the landmark. */
y: number;
/** The z coordinate of the landmark. */
z: number;
/** Whether this landmark is normalized with respect to the image size. */
normalized: boolean;
}

View File

@ -0,0 +1,33 @@
# This package contains options shared by all MediaPipe Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "classifier_options",
srcs = ["classifier_options.ts"],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto",
"//mediapipe/tasks/web/core:classifier_options",
],
)
mediapipe_ts_library(
name = "classifier_result",
srcs = ["classifier_result.ts"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classifications",
],
)
mediapipe_ts_library(
name = "base_options",
srcs = ["base_options.ts"],
deps = [
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core",
],
)

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
* Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model asset path nor a model asset buffer is provided, or if both are set
*/
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
Promise<BaseOptionsProto> {
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
}
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
}
let modelAssetBuffer = baseOptions.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(baseOptions.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
const proto = new BaseOptionsProto();
const externalFile = new ExternalFile();
externalFile.setFileContent(modelAssetBuffer);
proto.setModelAsset(externalFile);
return proto;
}
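The exported function enforces a simple contract: exactly one of `modelAssetPath` / `modelAssetBuffer` must be provided, and a path is resolved to bytes before the proto is built. A hedged Python analogue of that contract, for readers following the parallel Python API in this change (illustrative only, not the shipped implementation):

```python
def resolve_model_asset(model_asset_path=None, model_asset_buffer=None) -> bytes:
  """Mirrors the path-vs-buffer validation above; reads from disk instead of fetch()."""
  if model_asset_path and model_asset_buffer:
    raise ValueError(
        'Cannot set both model_asset_path and model_asset_buffer')
  if not model_asset_path and not model_asset_buffer:
    raise ValueError(
        'Either model_asset_path or model_asset_buffer must be set')
  if model_asset_buffer is not None:
    return model_asset_buffer
  with open(model_asset_path, 'rb') as f:
    return f.read()
```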

View File

@ -0,0 +1,62 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb';
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
/**
* Converts a ClassifierOptions object to its Proto representation, optionally
* based on an existing definition.
* @param options The options object to convert to a Proto. Only options that
*     are explicitly provided are set.
* @param baseOptions A base object that options can be merged into.
*/
export function convertClassifierOptionsToProto(
options: ClassifierOptions,
baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
const classifierOptions =
baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
if (options.displayNamesLocale) {
classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
} else if (options.displayNamesLocale === undefined) {
classifierOptions.clearDisplayNamesLocale();
}
if (options.maxResults) {
classifierOptions.setMaxResults(options.maxResults);
} else if ('maxResults' in options) { // Check for undefined
classifierOptions.clearMaxResults();
}
if (options.scoreThreshold) {
classifierOptions.setScoreThreshold(options.scoreThreshold);
} else if ('scoreThreshold' in options) { // Check for undefined
classifierOptions.clearScoreThreshold();
}
if (options.categoryAllowlist) {
classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
} else if ('categoryAllowlist' in options) { // Check for undefined
classifierOptions.clearCategoryAllowlistList();
}
if (options.categoryDenylist) {
classifierOptions.setCategoryDenylistList(options.categoryDenylist);
} else if ('categoryDenylist' in options) { // Check for undefined
classifierOptions.clearCategoryDenylistList();
}
return classifierOptions;
}
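The `'key' in options` checks give each option three states: set (a value is provided), explicitly cleared (`undefined` is provided), and untouched (the key is absent, so the base proto's value survives). A hedged Python sketch of the same tri-state merge, using `None` in place of `undefined`:

```python
def merge_option(options: dict, key: str, target: dict) -> None:
  """Tri-state merge: absent -> keep, None -> clear, value -> set."""
  if key not in options:
    return                    # absent: leave the existing value alone
  value = options[key]
  if value is None:
    target.pop(key, None)     # explicitly cleared
  else:
    target[key] = value       # explicitly set

# Example: clears maxResults, sets scoreThreshold, keeps displayNamesLocale.
base = {'maxResults': 3, 'displayNamesLocale': 'en'}
updates = {'maxResults': None, 'scoreThreshold': 0.5}
for key in ('maxResults', 'scoreThreshold', 'displayNamesLocale'):
  merge_option(updates, key, base)
# base is now {'displayNamesLocale': 'en', 'scoreThreshold': 0.5}
```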

View File

@ -0,0 +1,61 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0;
/**
* Converts a ClassificationEntry proto to the ClassificationEntry result
* type.
*/
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
ClassificationEntry {
const categories = source.getCategoriesList().map(category => {
return {
index: category.getIndex() ?? DEFAULT_INDEX,
score: category.getScore() ?? DEFAULT_SCORE,
displayName: category.getDisplayName() ?? '',
categoryName: category.getCategoryName() ?? '',
};
});
return {
categories,
timestampMs: source.getTimestampMs(),
};
}
/**
* Converts a ClassificationResult proto to a list of classifications.
*/
export function convertFromClassificationResultProto(
classificationResult: ClassificationResult) : Classifications[] {
const result: Classifications[] = [];
for (const classificationsProto of
classificationResult.getClassificationsList()) {
const classifications: Classifications = {
entries: classificationsProto.getEntriesList().map(
entry => convertFromClassificationEntryProto(entry)),
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
headName: classificationsProto.getHeadName() ?? '',
};
result.push(classifications);
}
return result;
}

View File

@ -0,0 +1,21 @@
# This package contains options shared by all MediaPipe Tasks for Web.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_ts_library(
name = "core",
srcs = [
"base_options.d.ts",
"wasm_loader_options.d.ts",
],
)
mediapipe_ts_library(
name = "classifier_options",
srcs = [
"classifier_options.d.ts",
],
deps = [":core"],
)

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Placeholder for internal dependency on trusted resource url
/** Options to configure MediaPipe Tasks in general. */
export interface BaseOptions {
/**
* The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetPath?: string;
/**
* A buffer containing the model asset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set.
*/
modelAssetBuffer?: Uint8Array;
}

View File

@ -0,0 +1,52 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {BaseOptions} from '../../../tasks/web/core/base_options';
/** Options to configure the Mediapipe Classifier Task. */
export interface ClassifierOptions {
/** Options to configure the loading of the model assets. */
baseOptions?: BaseOptions;
/**
* The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English.
*/
displayNamesLocale?: string|undefined;
/** The maximum number of top-scored classification results to return. */
maxResults?: number|undefined;
/**
* Overrides the value provided in the model metadata. Results below this
* value are rejected.
*/
scoreThreshold?: number|undefined;
/**
* Allowlist of category names. If non-empty, classification results whose category
* name is not in this set will be filtered out. Duplicate or unknown category
* names are ignored. Mutually exclusive with `categoryDenylist`.
*/
categoryAllowlist?: string[]|undefined;
/**
* Denylist of category names. If non-empty, classification results whose category
* name is in this set will be filtered out. Duplicate or unknown category
* names are ignored. Mutually exclusive with `categoryAllowlist`.
*/
categoryDenylist?: string[]|undefined;
}

View File

@ -0,0 +1,25 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Placeholder for internal dependency on trusted resource url
/** An object containing the locations of all Wasm assets */
export interface WasmLoaderOptions {
/** The path to the Wasm loader script. */
wasmLoaderPath: string;
/** The path to the Wasm binary. */
wasmBinaryPath: string;
}

15
package.json Normal file
View File

@ -0,0 +1,15 @@
{
"name": "medipipe-dev",
"version": "0.0.0-alphga",
"description": "MediaPipe GitHub repo",
"devDependencies": {
"@bazel/typescript": "^5.7.1",
"@types/google-protobuf": "^3.15.6",
"@types/offscreencanvas": "^2019.7.0",
"google-protobuf": "^3.21.2",
"protobufjs": "^7.1.2",
"protobufjs-cli": "^1.0.2",
"ts-protoc-gen": "^0.15.0",
"typescript": "^4.8.4"
}
}

View File

@ -1,5 +1,6 @@
 absl-py
 attrs>=19.1.0
+flatbuffers>=2.0
 matplotlib
 numpy
 opencv-contrib-python

View File

@ -129,6 +129,15 @@ def _add_mp_init_files():
   mp_dir_init_file.close()


+def _copy_to_build_lib_dir(build_lib, file):
+  """Copy a file from bazel-bin to the build lib dir."""
+  dst = os.path.join(build_lib + '/', file)
+  dst_dir = os.path.dirname(dst)
+  if not os.path.exists(dst_dir):
+    os.makedirs(dst_dir)
+  shutil.copyfile(os.path.join('bazel-bin/', file), dst)
+
+
 class GeneratePyProtos(build_ext.build_ext):
   """Generate MediaPipe Python protobuf files by Protocol Compiler."""

@ -259,7 +268,7 @@ class BuildModules(build_ext.build_ext):
       ]
       if subprocess.call(fetch_model_command) != 0:
         sys.exit(-1)
-      self._copy_to_build_lib_dir(external_file)
+      _copy_to_build_lib_dir(self.build_lib, external_file)

   def _generate_binary_graph(self, binary_graph_target):
     """Generate binary graph for a particular MediaPipe binary graph target."""
@ -277,15 +286,27 @@ class BuildModules(build_ext.build_ext):
       bazel_command.append('--define=OPENCV=source')
     if subprocess.call(bazel_command) != 0:
       sys.exit(-1)
-    self._copy_to_build_lib_dir(binary_graph_target + '.binarypb')
+    _copy_to_build_lib_dir(self.build_lib, binary_graph_target + '.binarypb')

-  def _copy_to_build_lib_dir(self, file):
-    """Copy a file from bazel-bin to the build lib dir."""
-    dst = os.path.join(self.build_lib + '/', file)
-    dst_dir = os.path.dirname(dst)
-    if not os.path.exists(dst_dir):
-      os.makedirs(dst_dir)
-    shutil.copyfile(os.path.join('bazel-bin/', file), dst)
+
+class GenerateMetadataSchema(build_ext.build_ext):
+  """Generate metadata python schema files."""
+
+  def run(self):
+    for target in ['metadata_schema_py', 'schema_py']:
+      bazel_command = [
+          'bazel',
+          'build',
+          '--compilation_mode=opt',
+          '--define=MEDIAPIPE_DISABLE_GPU=1',
+          '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
+          '//mediapipe/tasks/metadata:' + target,
+      ]
+      if subprocess.call(bazel_command) != 0:
+        sys.exit(-1)
+      _copy_to_build_lib_dir(
+          self.build_lib,
+          'mediapipe/tasks/metadata/' + target + '_generated.py')


 class BazelExtension(setuptools.Extension):
@ -375,6 +396,7 @@ class BuildPy(build_py.build_py):
     build_ext_obj = self.distribution.get_command_obj('build_ext')
     build_ext_obj.link_opencv = self.link_opencv
     self.run_command('gen_protos')
+    self.run_command('generate_metadata_schema')
     self.run_command('build_modules')
     self.run_command('build_ext')
     build_py.build_py.run(self)
@ -434,18 +456,25 @@ setuptools.setup(
     author_email='mediapipe@google.com',
     long_description=_get_long_description(),
     long_description_content_type='text/markdown',
-    packages=setuptools.find_packages(exclude=['mediapipe.examples.desktop.*']),
+    packages=setuptools.find_packages(
+        exclude=['mediapipe.examples.desktop.*', 'mediapipe.model_maker.*']),
     install_requires=_parse_requirements('requirements.txt'),
     cmdclass={
         'build_py': BuildPy,
-        'gen_protos': GeneratePyProtos,
         'build_modules': BuildModules,
         'build_ext': BuildExtension,
+        'generate_metadata_schema': GenerateMetadataSchema,
+        'gen_protos': GeneratePyProtos,
         'install': Install,
         'restore': Restore,
     },
     ext_modules=[
         BazelExtension('//mediapipe/python:_framework_bindings'),
+        BazelExtension(
+            '//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version'),
+        BazelExtension(
+            '//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers'
+        ),
     ],
     zip_safe=False,
     include_package_data=True,
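For orientation, each `cmdclass` key becomes its own setup.py subcommand, which is how `BuildPy.run` can invoke the new step via `self.run_command('generate_metadata_schema')`. A hedged toy sketch of the mechanism, with a made-up command name:

```python
import setuptools
from setuptools import Command

class HelloCommand(Command):
  """Toy stand-in for a custom build step such as GenerateMetadataSchema."""
  user_options = []
  def initialize_options(self): pass
  def finalize_options(self): pass
  def run(self):
    print('running a custom build step')

# `python setup.py hello` now dispatches to HelloCommand.run().
setuptools.setup(name='demo', version='0.0.0', cmdclass={'hello': HelloCommand})
```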

47
tsconfig.json Normal file
View File

@ -0,0 +1,47 @@
{
"compilerOptions": {
"target": "es2017",
"module": "commonjs",
"lib": ["ES2017", "dom"],
"declaration": true,
"moduleResolution": "node",
"esModuleInterop": true,
"noImplicitAny": true,
"inlineSourceMap": true,
"inlineSources": true,
"strict": true,
"types": ["@types/offscreencanvas"],
"rootDirs": [
".",
"./bazel-out/host/bin",
"./bazel-out/darwin-dbg/bin",
"./bazel-out/darwin-fastbuild/bin",
"./bazel-out/darwin-opt/bin",
"./bazel-out/darwin_arm64-dbg/bin",
"./bazel-out/darwin_arm64-fastbuild/bin",
"./bazel-out/darwin_arm64-opt/bin",
"./bazel-out/k8-dbg/bin",
"./bazel-out/k8-fastbuild/bin",
"./bazel-out/k8-opt/bin",
"./bazel-out/x64_windows-dbg/bin",
"./bazel-out/x64_windows-fastbuild/bin",
"./bazel-out/x64_windows-opt/bin",
"./bazel-out/darwin-dbg/bin",
"./bazel-out/darwin-fastbuild/bin",
"./bazel-out/darwin-opt/bin",
"./bazel-out/k8-dbg/bin",
"./bazel-out/k8-fastbuild/bin",
"./bazel-out/k8-opt/bin",
"./bazel-out/x64_windows-dbg/bin",
"./bazel-out/x64_windows-fastbuild/bin",
"./bazel-out/x64_windows-opt/bin",
"./bazel-out/k8-fastbuild-ST-4a519fd6d3e4/bin"
]
},
"exclude": [
"./_bazel_bin",
"./_bazel_buildbot",
"./_bazel_out",
"./_bazel_testlogs"
]
}

626
yarn.lock Normal file
View File

@ -0,0 +1,626 @@
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
# yarn lockfile v1
"@babel/parser@^7.9.4":
version "7.20.1"
resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.1.tgz#3e045a92f7b4623cafc2425eddcb8cf2e54f9cc5"
integrity sha512-hp0AYxaZJhxULfM1zyp7Wgr+pSUKBcP3M+PHnSzWGdXOzg/kHWIgiUWARvubhUKGOEw3xqY4x+lyZ9ytBVcELw==
"@bazel/typescript@^5.7.1":
version "5.7.1"
resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682"
integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g==
dependencies:
"@bazel/worker" "5.7.1"
semver "5.6.0"
source-map-support "0.5.9"
tsutils "3.21.0"
"@bazel/worker@5.7.1":
version "5.7.1"
resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad"
integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg==
dependencies:
google-protobuf "^3.6.1"
"@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2":
version "1.1.2"
resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf"
integrity sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==
"@protobufjs/base64@^1.1.2":
version "1.1.2"
resolved "https://registry.yarnpkg.com/@protobufjs/base64/-/base64-1.1.2.tgz#4c85730e59b9a1f1f349047dbf24296034bb2735"
integrity sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==
"@protobufjs/codegen@^2.0.4":
version "2.0.4"
resolved "https://registry.yarnpkg.com/@protobufjs/codegen/-/codegen-2.0.4.tgz#7ef37f0d010fb028ad1ad59722e506d9262815cb"
integrity sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==
"@protobufjs/eventemitter@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz#355cbc98bafad5978f9ed095f397621f1d066b70"
integrity sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==
"@protobufjs/fetch@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/fetch/-/fetch-1.1.0.tgz#ba99fb598614af65700c1619ff06d454b0d84c45"
integrity sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==
dependencies:
"@protobufjs/aspromise" "^1.1.1"
"@protobufjs/inquire" "^1.1.0"
"@protobufjs/float@^1.0.2":
version "1.0.2"
resolved "https://registry.yarnpkg.com/@protobufjs/float/-/float-1.0.2.tgz#5e9e1abdcb73fc0a7cb8b291df78c8cbd97b87d1"
integrity sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==
"@protobufjs/inquire@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/inquire/-/inquire-1.1.0.tgz#ff200e3e7cf2429e2dcafc1140828e8cc638f089"
integrity sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==
"@protobufjs/path@^1.1.2":
version "1.1.2"
resolved "https://registry.yarnpkg.com/@protobufjs/path/-/path-1.1.2.tgz#6cc2b20c5c9ad6ad0dccfd21ca7673d8d7fbf68d"
integrity sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==
"@protobufjs/pool@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/pool/-/pool-1.1.0.tgz#09fd15f2d6d3abfa9b65bc366506d6ad7846ff54"
integrity sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==
"@protobufjs/utf8@^1.1.0":
version "1.1.0"
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"
integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==
"@types/google-protobuf@^3.15.6":
version "3.15.6"
resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504"
integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw==
"@types/linkify-it@*":
version "3.0.2"
resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9"
integrity sha512-HZQYqbiFVWufzCwexrvh694SOim8z2d+xJl5UNamcvQFejLY/2YUtzXHYi3cHdI7PMlS8ejH2slRAOJQ32aNbA==
"@types/markdown-it@^12.2.3":
version "12.2.3"
resolved "https://registry.yarnpkg.com/@types/markdown-it/-/markdown-it-12.2.3.tgz#0d6f6e5e413f8daaa26522904597be3d6cd93b51"
integrity sha512-GKMHFfv3458yYy+v/N8gjufHO6MSZKCOXpZc5GXIWWy8uldwfmPn98vp81gZ5f9SVw8YYBctgfJ22a2d7AOMeQ==
dependencies:
"@types/linkify-it" "*"
"@types/mdurl" "*"
"@types/mdurl@*":
version "1.0.2"
resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9"
integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA==
"@types/node@>=13.7.0":
version "18.11.9"
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4"
integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg==
"@types/offscreencanvas@^2019.7.0":
version "2019.7.0"
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d"
  integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg==

acorn-jsx@^5.3.2:
  version "5.3.2"
  resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937"
  integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==

acorn@^8.8.0:
  version "8.8.1"
  resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73"
  integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==

ansi-styles@^4.1.0:
  version "4.3.0"
  resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937"
  integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==
  dependencies:
    color-convert "^2.0.1"

argparse@^2.0.1:
  version "2.0.1"
  resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38"
  integrity sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==

balanced-match@^1.0.0:
  version "1.0.2"
  resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee"
  integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==

bluebird@^3.7.2:
  version "3.7.2"
  resolved "https://registry.yarnpkg.com/bluebird/-/bluebird-3.7.2.tgz#9f229c15be272454ffa973ace0dbee79a1b0c36f"
  integrity sha512-XpNj6GDQzdfW+r2Wnn7xiSAd7TM3jzkxGXBGTtWKuSXv1xUV+azxAm8jdWZN06QTQk+2N2XB9jRDkvbmQmcRtg==

brace-expansion@^1.1.7:
  version "1.1.11"
  resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd"
  integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==
  dependencies:
    balanced-match "^1.0.0"
    concat-map "0.0.1"

brace-expansion@^2.0.1:
  version "2.0.1"
  resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-2.0.1.tgz#1edc459e0f0c548486ecf9fc99f2221364b9a0ae"
  integrity sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==
  dependencies:
    balanced-match "^1.0.0"

buffer-from@^1.0.0:
  version "1.1.2"
  resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.2.tgz#2b146a6fd72e80b4f55d255f35ed59a3a9a41bd5"
  integrity sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==

catharsis@^0.9.0:
  version "0.9.0"
  resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121"
  integrity sha512-prMTQVpcns/tzFgFVkVp6ak6RykZyWb3gu8ckUpd6YkTlacOd3DXGJjIpD4Q6zJirizvaiAjSSHlOsA+6sNh2A==
  dependencies:
    lodash "^4.17.15"

chalk@^4.0.0:
  version "4.1.2"
  resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01"
  integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==
  dependencies:
    ansi-styles "^4.1.0"
    supports-color "^7.1.0"

color-convert@^2.0.1:
  version "2.0.1"
  resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3"
  integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==
  dependencies:
    color-name "~1.1.4"

color-name@~1.1.4:
  version "1.1.4"
  resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2"
  integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==

concat-map@0.0.1:
  version "0.0.1"
  resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b"
  integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==

deep-is@~0.1.3:
  version "0.1.4"
  resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831"
  integrity sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==

entities@~2.1.0:
  version "2.1.0"
  resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5"
  integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w==

escape-string-regexp@^2.0.0:
  version "2.0.0"
  resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344"
  integrity sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w==

escodegen@^1.13.0:
  version "1.14.3"
  resolved "https://registry.yarnpkg.com/escodegen/-/escodegen-1.14.3.tgz#4e7b81fba61581dc97582ed78cab7f0e8d63f503"
  integrity sha512-qFcX0XJkdg+PB3xjZZG/wKSuT1PnQWx57+TVSjIMmILd2yC/6ByYElPwJnslDsuWuSAp4AwJGumarAAmJch5Kw==
  dependencies:
    esprima "^4.0.1"
    estraverse "^4.2.0"
    esutils "^2.0.2"
    optionator "^0.8.1"
  optionalDependencies:
    source-map "~0.6.1"

eslint-visitor-keys@^3.3.0:
  version "3.3.0"
  resolved "https://registry.yarnpkg.com/eslint-visitor-keys/-/eslint-visitor-keys-3.3.0.tgz#f6480fa6b1f30efe2d1968aa8ac745b862469826"
  integrity sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA==

espree@^9.0.0:
  version "9.4.0"
  resolved "https://registry.yarnpkg.com/espree/-/espree-9.4.0.tgz#cd4bc3d6e9336c433265fc0aa016fc1aaf182f8a"
  integrity sha512-DQmnRpLj7f6TgN/NYb0MTzJXL+vJF9h3pHy4JhCIs3zwcgez8xmGg3sXHcEO97BrmO2OSvCwMdfdlyl+E9KjOw==
  dependencies:
    acorn "^8.8.0"
    acorn-jsx "^5.3.2"
    eslint-visitor-keys "^3.3.0"

esprima@^4.0.1:
  version "4.0.1"
  resolved "https://registry.yarnpkg.com/esprima/-/esprima-4.0.1.tgz#13b04cdb3e6c5d19df91ab6987a8695619b0aa71"
  integrity sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==

estraverse@^4.2.0:
  version "4.3.0"
  resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-4.3.0.tgz#398ad3f3c5a24948be7725e83d11a7de28cdbd1d"
  integrity sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==

estraverse@^5.1.0:
  version "5.3.0"
  resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.3.0.tgz#2eea5290702f26ab8fe5370370ff86c965d21123"
  integrity sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==

esutils@^2.0.2:
  version "2.0.3"
  resolved "https://registry.yarnpkg.com/esutils/-/esutils-2.0.3.tgz#74d2eb4de0b8da1293711910d50775b9b710ef64"
  integrity sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==

fast-levenshtein@~2.0.6:
  version "2.0.6"
  resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917"
  integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==

fs.realpath@^1.0.0:
  version "1.0.0"
  resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f"
  integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==

glob@^7.1.3:
  version "7.2.3"
  resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b"
  integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==
  dependencies:
    fs.realpath "^1.0.0"
    inflight "^1.0.4"
    inherits "2"
    minimatch "^3.1.1"
    once "^1.3.0"
    path-is-absolute "^1.0.0"

glob@^8.0.0:
  version "8.0.3"
  resolved "https://registry.yarnpkg.com/glob/-/glob-8.0.3.tgz#415c6eb2deed9e502c68fa44a272e6da6eeca42e"
  integrity sha512-ull455NHSHI/Y1FqGaaYFaLGkNMMJbavMrEGFXG/PGrg6y7sutWHUHrz6gy6WEBH6akM1M414dWKCNs+IhKdiQ==
  dependencies:
    fs.realpath "^1.0.0"
    inflight "^1.0.4"
    inherits "2"
    minimatch "^5.0.1"
    once "^1.3.0"

google-protobuf@^3.15.5, google-protobuf@^3.21.2, google-protobuf@^3.6.1:
  version "3.21.2"
  resolved "https://registry.yarnpkg.com/google-protobuf/-/google-protobuf-3.21.2.tgz#4580a2bea8bbb291ee579d1fefb14d6fa3070ea4"
  integrity sha512-3MSOYFO5U9mPGikIYCzK0SaThypfGgS6bHqrUGXG3DPHCrb+txNqeEcns1W0lkGfk0rCyNXm7xB9rMxnCiZOoA==

graceful-fs@^4.1.9:
  version "4.2.10"
  resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.10.tgz#147d3a006da4ca3ce14728c7aefc287c367d7a6c"
  integrity sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==

has-flag@^4.0.0:
  version "4.0.0"
  resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b"
  integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==

inflight@^1.0.4:
  version "1.0.6"
  resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9"
  integrity sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==
  dependencies:
    once "^1.3.0"
    wrappy "1"

inherits@2:
  version "2.0.4"
  resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c"
  integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==

js2xmlparser@^4.0.2:
  version "4.0.2"
  resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a"
  integrity sha512-6n4D8gLlLf1n5mNLQPRfViYzu9RATblzPEtm1SthMX1Pjao0r9YI9nw7ZIfRxQMERS87mcswrg+r/OYrPRX6jA==
  dependencies:
    xmlcreate "^2.0.4"

jsdoc@^3.6.3:
  version "3.6.11"
  resolved "https://registry.yarnpkg.com/jsdoc/-/jsdoc-3.6.11.tgz#8bbb5747e6f579f141a5238cbad4e95e004458ce"
  integrity sha512-8UCU0TYeIYD9KeLzEcAu2q8N/mx9O3phAGl32nmHlE0LpaJL71mMkP4d+QE5zWfNt50qheHtOZ0qoxVrsX5TUg==
  dependencies:
    "@babel/parser" "^7.9.4"
    "@types/markdown-it" "^12.2.3"
    bluebird "^3.7.2"
    catharsis "^0.9.0"
    escape-string-regexp "^2.0.0"
    js2xmlparser "^4.0.2"
    klaw "^3.0.0"
    markdown-it "^12.3.2"
    markdown-it-anchor "^8.4.1"
    marked "^4.0.10"
    mkdirp "^1.0.4"
    requizzle "^0.2.3"
    strip-json-comments "^3.1.0"
    taffydb "2.6.2"
    underscore "~1.13.2"

klaw@^3.0.0:
  version "3.0.0"
  resolved "https://registry.yarnpkg.com/klaw/-/klaw-3.0.0.tgz#b11bec9cf2492f06756d6e809ab73a2910259146"
  integrity sha512-0Fo5oir+O9jnXu5EefYbVK+mHMBeEVEy2cmctR1O1NECcCkPRreJKrS6Qt/j3KC2C148Dfo9i3pCmCMsdqGr0g==
  dependencies:
    graceful-fs "^4.1.9"

levn@~0.3.0:
  version "0.3.0"
  resolved "https://registry.yarnpkg.com/levn/-/levn-0.3.0.tgz#3b09924edf9f083c0490fdd4c0bc4421e04764ee"
  integrity sha512-0OO4y2iOHix2W6ujICbKIaEQXvFQHue65vUG3pb5EUomzPI90z9hsA1VsO/dbIIpC53J8gxM9Q4Oho0jrCM/yA==
  dependencies:
    prelude-ls "~1.1.2"
    type-check "~0.3.2"

linkify-it@^3.0.1:
  version "3.0.3"
  resolved "https://registry.yarnpkg.com/linkify-it/-/linkify-it-3.0.3.tgz#a98baf44ce45a550efb4d49c769d07524cc2fa2e"
  integrity sha512-ynTsyrFSdE5oZ/O9GEf00kPngmOfVwazR5GKDq6EYfhlpFug3J2zybX56a2PRRpc9P+FuSoGNAwjlbDs9jJBPQ==
  dependencies:
    uc.micro "^1.0.1"

lodash@^4.17.14, lodash@^4.17.15:
  version "4.17.21"
  resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
  integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==

long@^5.0.0:
  version "5.2.0"
  resolved "https://registry.yarnpkg.com/long/-/long-5.2.0.tgz#2696dadf4b4da2ce3f6f6b89186085d94d52fd61"
  integrity sha512-9RTUNjK60eJbx3uz+TEGF7fUr29ZDxR5QzXcyDpeSfeH28S9ycINflOgOlppit5U+4kNTe83KQnMEerw7GmE8w==

lru-cache@^6.0.0:
  version "6.0.0"
  resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-6.0.0.tgz#6d6fe6570ebd96aaf90fcad1dafa3b2566db3a94"
  integrity sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==
  dependencies:
    yallist "^4.0.0"

markdown-it-anchor@^8.4.1:
  version "8.6.5"
  resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8"
  integrity sha512-PI1qEHHkTNWT+X6Ip9w+paonfIQ+QZP9sCeMYi47oqhH+EsW8CrJ8J7CzV19QVOj6il8ATGbK2nTECj22ZHGvQ==

markdown-it@^12.3.2:
  version "12.3.2"
  resolved "https://registry.yarnpkg.com/markdown-it/-/markdown-it-12.3.2.tgz#bf92ac92283fe983fe4de8ff8abfb5ad72cd0c90"
  integrity sha512-TchMembfxfNVpHkbtriWltGWc+m3xszaRD0CZup7GFFhzIgQqxIfn3eGj1yZpfuflzPvfkt611B2Q/Bsk1YnGg==
  dependencies:
    argparse "^2.0.1"
    entities "~2.1.0"
    linkify-it "^3.0.1"
    mdurl "^1.0.1"
    uc.micro "^1.0.5"

marked@^4.0.10:
  version "4.2.1"
  resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.1.tgz#eaa32594e45b4e58c02e4d118531fd04345de3b4"
  integrity sha512-VK1/jNtwqDLvPktNpL0Fdg3qoeUZhmRsuiIjPEy/lHwXW4ouLoZfO4XoWd4ClDt+hupV1VLpkZhEovjU0W/kqA==

mdurl@^1.0.1:
  version "1.0.1"
  resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e"
  integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g==

minimatch@^3.1.1:
  version "3.1.2"
  resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b"
  integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==
  dependencies:
    brace-expansion "^1.1.7"

minimatch@^5.0.1:
  version "5.1.0"
  resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7"
  integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg==
  dependencies:
    brace-expansion "^2.0.1"

minimist@^1.2.0:
  version "1.2.7"
  resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.7.tgz#daa1c4d91f507390437c6a8bc01078e7000c4d18"
  integrity sha512-bzfL1YUZsP41gmu/qjrEk0Q6i2ix/cVeAhbCbqH9u3zYutS1cLg00qhrD0M2MVdCcx4Sc0UpP2eBWo9rotpq6g==

mkdirp@^1.0.4:
  version "1.0.4"
  resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-1.0.4.tgz#3eb5ed62622756d79a5f0e2a221dfebad75c2f7e"
  integrity sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==

once@^1.3.0:
  version "1.4.0"
  resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1"
  integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==
  dependencies:
    wrappy "1"

optionator@^0.8.1:
  version "0.8.3"
  resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.3.tgz#84fa1d036fe9d3c7e21d99884b601167ec8fb495"
  integrity sha512-+IW9pACdk3XWmmTXG8m3upGUJst5XRGzxMRjXzAuJ1XnIFNvfhjjIuYkDvysnPQ7qzqVzLt78BCruntqRhWQbA==
  dependencies:
    deep-is "~0.1.3"
    fast-levenshtein "~2.0.6"
    levn "~0.3.0"
    prelude-ls "~1.1.2"
    type-check "~0.3.2"
    word-wrap "~1.2.3"

path-is-absolute@^1.0.0:
  version "1.0.1"
  resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f"
  integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==

prelude-ls@~1.1.2:
  version "1.1.2"
  resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.1.2.tgz#21932a549f5e52ffd9a827f570e04be62a97da54"
  integrity sha512-ESF23V4SKG6lVSGZgYNpbsiaAkdab6ZgOxe52p7+Kid3W3u3bxR4Vfd/o21dmN7jSt0IwgZ4v5MUd26FEtXE9w==

protobufjs-cli@^1.0.2:
  version "1.0.2"
  resolved "https://registry.yarnpkg.com/protobufjs-cli/-/protobufjs-cli-1.0.2.tgz#905fc49007cf4aaf3c45d5f250eb294eedeea062"
  integrity sha512-cz9Pq9p/Zs7okc6avH20W7QuyjTclwJPgqXG11jNaulfS3nbVisID8rC+prfgq0gbZE0w9LBFd1OKFF03kgFzg==
  dependencies:
    chalk "^4.0.0"
    escodegen "^1.13.0"
    espree "^9.0.0"
    estraverse "^5.1.0"
    glob "^8.0.0"
    jsdoc "^3.6.3"
    minimist "^1.2.0"
    semver "^7.1.2"
    tmp "^0.2.1"
    uglify-js "^3.7.7"

protobufjs@^7.1.2:
  version "7.1.2"
  resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-7.1.2.tgz#a0cf6aeaf82f5625bffcf5a38b7cd2a7de05890c"
  integrity sha512-4ZPTPkXCdel3+L81yw3dG6+Kq3umdWKh7Dc7GW/CpNk4SX3hK58iPCWeCyhVTDrbkNeKrYNZ7EojM5WDaEWTLQ==
  dependencies:
    "@protobufjs/aspromise" "^1.1.2"
    "@protobufjs/base64" "^1.1.2"
    "@protobufjs/codegen" "^2.0.4"
    "@protobufjs/eventemitter" "^1.1.0"
    "@protobufjs/fetch" "^1.1.0"
    "@protobufjs/float" "^1.0.2"
    "@protobufjs/inquire" "^1.1.0"
    "@protobufjs/path" "^1.1.2"
    "@protobufjs/pool" "^1.1.0"
    "@protobufjs/utf8" "^1.1.0"
    "@types/node" ">=13.7.0"
    long "^5.0.0"

requizzle@^0.2.3:
  version "0.2.3"
  resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded"
  integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ==
  dependencies:
    lodash "^4.17.14"

rimraf@^3.0.0:
  version "3.0.2"
  resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-3.0.2.tgz#f1a5402ba6220ad52cc1282bac1ae3aa49fd061a"
  integrity sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==
  dependencies:
    glob "^7.1.3"

semver@5.6.0:
  version "5.6.0"
  resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004"
  integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg==

semver@^7.1.2:
  version "7.3.8"
  resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798"
  integrity sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==
  dependencies:
    lru-cache "^6.0.0"

source-map-support@0.5.9:
  version "0.5.9"
  resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f"
  integrity sha512-gR6Rw4MvUlYy83vP0vxoVNzM6t8MUXqNuRsuBmBHQDu1Fh6X015FrLdgoDKcNdkwGubozq0P4N0Q37UyFVr1EA==
  dependencies:
    buffer-from "^1.0.0"
    source-map "^0.6.0"

source-map@^0.6.0, source-map@~0.6.1:
  version "0.6.1"
  resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263"
  integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==

strip-json-comments@^3.1.0:
  version "3.1.1"
  resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006"
  integrity sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==

supports-color@^7.1.0:
  version "7.2.0"
  resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.2.0.tgz#1b7dcdcb32b8138801b3e478ba6a51caa89648da"
  integrity sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==
  dependencies:
    has-flag "^4.0.0"

taffydb@2.6.2:
  version "2.6.2"
  resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268"
  integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA==

tmp@^0.2.1:
  version "0.2.1"
  resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14"
  integrity sha512-76SUhtfqR2Ijn+xllcI5P1oyannHNHByD80W1q447gU3mp9G9PSpGdWmjUOHRDPiHYacIk66W7ubDTuPF3BEtQ==
  dependencies:
    rimraf "^3.0.0"

ts-protoc-gen@^0.15.0:
  version "0.15.0"
  resolved "https://registry.yarnpkg.com/ts-protoc-gen/-/ts-protoc-gen-0.15.0.tgz#2fec5930b46def7dcc9fa73c060d770b7b076b7b"
  integrity sha512-TycnzEyrdVDlATJ3bWFTtra3SCiEP0W0vySXReAuEygXCUr1j2uaVyL0DhzjwuUdQoW5oXPwk6oZWeA0955V+g==
  dependencies:
    google-protobuf "^3.15.5"

tslib@^1.8.1:
  version "1.14.1"
  resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00"
  integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==

tsutils@3.21.0:
  version "3.21.0"
  resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.21.0.tgz#b48717d394cea6c1e096983eed58e9d61715b623"
  integrity sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==
  dependencies:
    tslib "^1.8.1"

type-check@~0.3.2:
  version "0.3.2"
  resolved "https://registry.yarnpkg.com/type-check/-/type-check-0.3.2.tgz#5884cab512cf1d355e3fb784f30804b2b520db72"
  integrity sha512-ZCmOJdvOWDBYJlzAoFkC+Q0+bUyEOS1ltgp1MGU03fqHG+dbi9tBFU2Rd9QKiDZFAYrhPh2JUf7rZRIuHRKtOg==
  dependencies:
    prelude-ls "~1.1.2"

typescript@^4.8.4:
  version "4.8.4"
  resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6"
  integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ==

uc.micro@^1.0.1, uc.micro@^1.0.5:
  version "1.0.6"
  resolved "https://registry.yarnpkg.com/uc.micro/-/uc.micro-1.0.6.tgz#9c411a802a409a91fc6cf74081baba34b24499ac"
  integrity sha512-8Y75pvTYkLJW2hWQHXxoqRgV7qb9B+9vFEtidML+7koHUFapnVJAZ6cKs+Qjz5Aw3aZWHMC6u0wJE3At+nSGwA==

uglify-js@^3.7.7:
  version "3.17.4"
  resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
  integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==

underscore@~1.13.2:
  version "1.13.6"
  resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441"
  integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==

word-wrap@~1.2.3:
  version "1.2.3"
  resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
  integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==

wrappy@1:
  version "1.0.2"
  resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f"
  integrity sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==

xmlcreate@^2.0.4:
  version "2.0.4"
  resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be"
  integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg==

yallist@^4.0.0:
  version "4.0.0"
  resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
  integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==