Merge branch 'google:master' into text-classifier-python

commit 740d2e47b5
.gitignore

@@ -2,5 +2,6 @@ bazel-*
 mediapipe/MediaPipe.xcodeproj
 mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user
 mediapipe/provisioning_profile.mobileprovision
+node_modules/
 .configure.bazelrc
 .user.bazelrc

BUILD

@@ -1,4 +1,4 @@
-# Copyright 2019 The MediaPipe Authors.
+# Copyright 2022 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,4 +14,9 @@

 licenses(["notice"])

-exports_files(["LICENSE"])
+exports_files([
+    "LICENSE",
+    "tsconfig.json",
+    "package.json",
+    "yarn.lock",
+])

WORKSPACE

@@ -501,5 +501,48 @@ libedgetpu_dependencies()
 load("@coral_crosstool//:configure.bzl", "cc_crosstool")
 cc_crosstool(name = "crosstool")

+
+# Node dependencies
+http_archive(
+    name = "build_bazel_rules_nodejs",
+    sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
+    urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
+)
+
+load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
+build_bazel_rules_nodejs_dependencies()
+
+# fetches nodejs, npm, and yarn
+load("@build_bazel_rules_nodejs//:index.bzl", "node_repositories", "yarn_install")
+node_repositories()
+yarn_install(
+    name = "npm",
+    package_json = "//:package.json",
+    yarn_lock = "//:yarn.lock",
+)
+
+# Protobuf for Node dependencies
+http_archive(
+    name = "rules_proto_grpc",
+    sha256 = "bbe4db93499f5c9414926e46f9e35016999a4e9f6e3522482d3760dc61011070",
+    strip_prefix = "rules_proto_grpc-4.2.0",
+    urls = ["https://github.com/rules-proto-grpc/rules_proto_grpc/archive/4.2.0.tar.gz"],
+)
+
+http_archive(
+    name = "com_google_protobuf_javascript",
+    sha256 = "35bca1729532b0a77280bf28ab5937438e3dcccd6b31a282d9ae84c896b6f6e3",
+    strip_prefix = "protobuf-javascript-3.21.2",
+    urls = ["https://github.com/protocolbuffers/protobuf-javascript/archive/refs/tags/v3.21.2.tar.gz"],
+)
+
+load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains", "rules_proto_grpc_repos")
+rules_proto_grpc_toolchains()
+rules_proto_grpc_repos()
+
+load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
+rules_proto_dependencies()
+rules_proto_toolchains()
+
 load("//third_party:external_files.bzl", "external_files")
 external_files()

@@ -141,3 +141,4 @@ Nvidia Jetson and Raspberry Pi, please read
 ```bash
 (mp_env)mediapipe$ python3 setup.py bdist_wheel
 ```
+7. Exit from the MediaPipe repo directory and launch the Python interpreter.

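Note: the new step 7 assumes the wheel built in the previous step has been installed into the active environment. A quick interpreter check might look like this (illustrative only; the `__version__` attribute is assumed to be exposed by the installed package):

```python
# Run from outside the MediaPipe repo so Python imports the installed wheel
# rather than the source tree.
import mediapipe as mp

print(mp.__version__)  # sanity-check that the freshly built wheel loads
```
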
mediapipe/calculators/core/BUILD

@@ -936,6 +936,7 @@ cc_test(
         "//mediapipe/framework/tool:simulation_clock",
         "//mediapipe/framework/tool:simulation_clock_executor",
         "//mediapipe/framework/tool:sink",
+        "//mediapipe/util:packet_test_util",
         "@com_google_absl//absl/time",
     ],
 )

mediapipe/calculators/core/flow_limiter_calculator.cc

@@ -18,7 +18,6 @@

 #include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
-#include "mediapipe/framework/port/ret_check.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/util/header_util.h"

@@ -68,7 +67,7 @@ constexpr char kOptionsTag[] = "OPTIONS";
 // FlowLimiterCalculator provides limited support for multiple input streams.
 // The first input stream is treated as the main input stream and successive
 // input streams are treated as auxiliary input streams. The auxiliary input
-// streams are limited to timestamps passed on the main input stream.
+// streams are limited to timestamps allowed by the "ALLOW" stream.
 //
 class FlowLimiterCalculator : public CalculatorBase {
  public:
@@ -100,64 +99,11 @@ class FlowLimiterCalculator : public CalculatorBase {
           cc->InputSidePackets().Tag(kMaxInFlightTag).Get<int>());
     }
     input_queues_.resize(cc->Inputs().NumEntries(""));
+    allowed_[Timestamp::Unset()] = true;
     RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs())));
     return absl::OkStatus();
   }

-  // Returns true if an additional frame can be released for processing.
-  // The "ALLOW" output stream indicates this condition at each input frame.
-  bool ProcessingAllowed() {
-    return frames_in_flight_.size() < options_.max_in_flight();
-  }
-
-  // Outputs a packet indicating whether a frame was sent or dropped.
-  void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
-    if (cc->Outputs().HasTag(kAllowTag)) {
-      cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
-    }
-  }
-
-  // Sets the timestamp bound or closes an output stream.
-  void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
-    if (bound > Timestamp::Max()) {
-      stream->Close();
-    } else {
-      stream->SetNextTimestampBound(bound);
-    }
-  }
-
-  // Returns true if a certain timestamp is being processed.
-  bool IsInFlight(Timestamp timestamp) {
-    return std::find(frames_in_flight_.begin(), frames_in_flight_.end(),
-                     timestamp) != frames_in_flight_.end();
-  }
-
-  // Releases input packets up to the latest settled input timestamp.
-  void ProcessAuxiliaryInputs(CalculatorContext* cc) {
-    Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
-    for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
-      // Release settled frames from each input queue.
-      while (!input_queues_[i].empty() &&
-             input_queues_[i].front().Timestamp() < settled_bound) {
-        Packet packet = input_queues_[i].front();
-        input_queues_[i].pop_front();
-        if (IsInFlight(packet.Timestamp())) {
-          cc->Outputs().Get("", i).AddPacket(packet);
-        }
-      }
-
-      // Propagate each input timestamp bound.
-      if (!input_queues_[i].empty()) {
-        Timestamp bound = input_queues_[i].front().Timestamp();
-        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
-      } else {
-        Timestamp bound =
-            cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
-        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
-      }
-    }
-  }
-
   // Releases input packets allowed by the max_in_flight constraint.
   absl::Status Process(CalculatorContext* cc) final {
     options_ = tool::RetrieveOptions(options_, cc->Inputs());
@@ -224,13 +170,97 @@ class FlowLimiterCalculator : public CalculatorBase {
     }

     ProcessAuxiliaryInputs(cc);

+    // Discard old ALLOW ranges.
+    Timestamp input_bound = InputTimestampBound(cc);
+    auto first_range = std::prev(allowed_.upper_bound(input_bound));
+    allowed_.erase(allowed_.begin(), first_range);
     return absl::OkStatus();
   }

+  int LedgerSize() {
+    int result = frames_in_flight_.size() + allowed_.size();
+    for (const auto& queue : input_queues_) {
+      result += queue.size();
+    }
+    return result;
+  }
+
+ private:
+  // Returns true if an additional frame can be released for processing.
+  // The "ALLOW" output stream indicates this condition at each input frame.
+  bool ProcessingAllowed() {
+    return frames_in_flight_.size() < options_.max_in_flight();
+  }
+
+  // Outputs a packet indicating whether a frame was sent or dropped.
+  void SendAllow(bool allow, Timestamp ts, CalculatorContext* cc) {
+    if (cc->Outputs().HasTag(kAllowTag)) {
+      cc->Outputs().Tag(kAllowTag).AddPacket(MakePacket<bool>(allow).At(ts));
+    }
+    allowed_[ts] = allow;
+  }
+
+  // Returns true if a timestamp falls within a range of allowed timestamps.
+  bool IsAllowed(Timestamp timestamp) {
+    auto it = allowed_.upper_bound(timestamp);
+    return std::prev(it)->second;
+  }
+
+  // Sets the timestamp bound or closes an output stream.
+  void SetNextTimestampBound(Timestamp bound, OutputStream* stream) {
+    if (bound > Timestamp::Max()) {
+      stream->Close();
+    } else {
+      stream->SetNextTimestampBound(bound);
+    }
+  }
+
+  // Returns the lowest unprocessed input Timestamp.
+  Timestamp InputTimestampBound(CalculatorContext* cc) {
+    Timestamp result = Timestamp::Done();
+    for (int i = 0; i < input_queues_.size(); ++i) {
+      auto& queue = input_queues_[i];
+      auto& stream = cc->Inputs().Get("", i);
+      Timestamp bound = queue.empty()
+                            ? stream.Value().Timestamp().NextAllowedInStream()
+                            : queue.front().Timestamp();
+      result = std::min(result, bound);
+    }
+    return result;
+  }
+
+  // Releases input packets up to the latest settled input timestamp.
+  void ProcessAuxiliaryInputs(CalculatorContext* cc) {
+    Timestamp settled_bound = cc->Outputs().Get("", 0).NextTimestampBound();
+    for (int i = 1; i < cc->Inputs().NumEntries(""); ++i) {
+      // Release settled frames from each input queue.
+      while (!input_queues_[i].empty() &&
+             input_queues_[i].front().Timestamp() < settled_bound) {
+        Packet packet = input_queues_[i].front();
+        input_queues_[i].pop_front();
+        if (IsAllowed(packet.Timestamp())) {
+          cc->Outputs().Get("", i).AddPacket(packet);
+        }
+      }
+
+      // Propagate each input timestamp bound.
+      if (!input_queues_[i].empty()) {
+        Timestamp bound = input_queues_[i].front().Timestamp();
+        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
+      } else {
+        Timestamp bound =
+            cc->Inputs().Get("", i).Value().Timestamp().NextAllowedInStream();
+        SetNextTimestampBound(bound, &cc->Outputs().Get("", i));
+      }
+    }
+  }
+
  private:
   FlowLimiterCalculatorOptions options_;
   std::vector<std::deque<Packet>> input_queues_;
   std::deque<Timestamp> frames_in_flight_;
+  std::map<Timestamp, bool> allowed_;
 };
 REGISTER_CALCULATOR(FlowLimiterCalculator);

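Note on the `allowed_` bookkeeping introduced above: the map is a step function over timestamps. `SendAllow()` records one ALLOW/drop decision per main-stream timestamp, `IsAllowed()` answers range queries via `upper_bound()` plus `std::prev()`, and `Process()` garbage-collects ranges older than the lowest unprocessed input. A minimal Python sketch of the same interval-map idiom (illustration only; method names mirror the C++ above, and `float('-inf')` stands in for the `Timestamp::Unset()` sentinel seeded in `Open()`):

```python
import bisect


class AllowRanges:
  """Step function over timestamps; each key starts a new allow/drop range."""

  def __init__(self):
    # Mirrors allowed_[Timestamp::Unset()] = true: a lookup always finds a
    # preceding range.
    self.keys = [float('-inf')]
    self.values = [True]

  def send_allow(self, ts, allow):
    """SendAllow(): record whether the frame at `ts` was sent or dropped."""
    i = bisect.bisect_left(self.keys, ts)
    if i < len(self.keys) and self.keys[i] == ts:
      self.values[i] = allow
    else:
      self.keys.insert(i, ts)
      self.values.insert(i, allow)

  def is_allowed(self, ts):
    """IsAllowed(): std::prev(allowed_.upper_bound(ts))->second."""
    i = bisect.bisect_right(self.keys, ts)
    return self.values[i - 1]

  def discard_before(self, bound):
    """Process(): erase ranges wholly before the lowest unprocessed input."""
    i = bisect.bisect_right(self.keys, bound)
    del self.keys[:i - 1]
    del self.values[:i - 1]


ranges = AllowRanges()
ranges.send_allow(0, True)
ranges.send_allow(10000, False)
assert ranges.is_allowed(5000)        # inside the allowed range [0, 10000)
assert not ranges.is_allowed(12000)   # inside the dropped range [10000, ...)
```
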
mediapipe/calculators/core/flow_limiter_calculator_test.cc

@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>

 #include "absl/time/clock.h"
@@ -32,6 +33,7 @@
 #include "mediapipe/framework/tool/simulation_clock.h"
 #include "mediapipe/framework/tool/simulation_clock_executor.h"
 #include "mediapipe/framework/tool/sink.h"
+#include "mediapipe/util/packet_test_util.h"

 namespace mediapipe {

@@ -77,6 +79,77 @@ std::vector<T> PacketValues(const std::vector<Packet>& packets) {
   return result;
 }

+template <typename T>
+std::vector<Packet> MakePackets(std::vector<std::pair<Timestamp, T>> contents) {
+  std::vector<Packet> result;
+  for (auto& entry : contents) {
+    result.push_back(MakePacket<T>(entry.second).At(entry.first));
+  }
+  return result;
+}
+
+std::string SourceString(Timestamp t) {
+  return (t.IsSpecialValue())
+             ? t.DebugString()
+             : absl::StrCat("Timestamp(", t.DebugString(), ")");
+}
+
+template <typename PacketContainer, typename PacketContent>
+class PacketsEqMatcher
+    : public ::testing::MatcherInterface<const PacketContainer&> {
+ public:
+  PacketsEqMatcher(PacketContainer packets) : packets_(packets) {}
+  void DescribeTo(::std::ostream* os) const override {
+    *os << "The expected packet contents: \n";
+    Print(packets_, os);
+  }
+  bool MatchAndExplain(
+      const PacketContainer& value,
+      ::testing::MatchResultListener* listener) const override {
+    if (!Equals(packets_, value)) {
+      if (listener->IsInterested()) {
+        *listener << "The actual packet contents: \n";
+        Print(value, listener->stream());
+      }
+      return false;
+    }
+    return true;
+  }
+
+ private:
+  bool Equals(const PacketContainer& c1, const PacketContainer& c2) const {
+    if (c1.size() != c2.size()) {
+      return false;
+    }
+    for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) {
+      Packet p1 = *i1, p2 = *i2;
+      if (p1.Timestamp() != p2.Timestamp() ||
+          p1.Get<PacketContent>() != p2.Get<PacketContent>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+  void Print(const PacketContainer& packets, ::std::ostream* os) const {
+    for (auto it = packets.begin(); it != packets.end(); ++it) {
+      const Packet& packet = *it;
+      *os << (it == packets.begin() ? "{" : "") << "{"
+          << SourceString(packet.Timestamp()) << ", "
+          << packet.Get<PacketContent>() << "}"
+          << (std::next(it) == packets.end() ? "}" : ", ");
+    }
+  }
+
+  const PacketContainer packets_;
+};
+
+template <typename PacketContainer, typename PacketContent>
+::testing::Matcher<const PacketContainer&> PackestEq(
+    const PacketContainer& packets) {
+  return MakeMatcher(
+      new PacketsEqMatcher<PacketContainer, PacketContent>(packets));
+}
+
 // A Calculator::Process callback function.
 typedef std::function<absl::Status(const InputStreamShardSet&,
                                    OutputStreamShardSet*)>
@@ -651,11 +724,12 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
       input_packets_[17], input_packets_[19], input_packets_[20],
   };
   EXPECT_EQ(out_1_packets_, expected_output);
-  // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
+  // The timestamps released by FlowLimiterCalculator for in_1_sampled,
+  // plus input_packets_[21].
   std::vector<Packet> expected_output_2 = {
       input_packets_[0],  input_packets_[2],  input_packets_[4],
       input_packets_[14], input_packets_[17], input_packets_[19],
-      input_packets_[20],
+      input_packets_[20], input_packets_[21],
   };
   EXPECT_EQ(out_2_packets, expected_output_2);
 }
@@ -665,6 +739,9 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
 // The processing time "sleep_time" is reduced from 22ms to 12ms to create
 // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
 TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
+  auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
+  auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
+
   // Configure the test.
   SetUpInputData();
   SetUpSimulationClock();
@@ -699,10 +776,9 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
     }
   )pb");

-  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(R"pb(
-    max_in_flight: 1
-    max_in_queue: 0
-    in_flight_timeout: 100000  # 100 ms
-  )pb");
+  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
+      R"pb(
+        max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 100000  # 100 ms
+      )pb");
   std::map<std::string, Packet> side_packets = {
       {"limiter_options",
@@ -759,13 +835,131 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
       input_packets_[0],  input_packets_[2],  input_packets_[15],
       input_packets_[17], input_packets_[19],
   };
-  EXPECT_EQ(out_1_packets_, expected_output);
+  EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
   // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
   std::vector<Packet> expected_output_2 = {
       input_packets_[0],  input_packets_[2],  input_packets_[4],
       input_packets_[15], input_packets_[17], input_packets_[19],
   };
-  EXPECT_EQ(out_2_packets, expected_output_2);
+  EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
+
+  // Validate the ALLOW stream output.
+  std::vector<Packet> expected_allow = MakePackets<bool>(  //
+      {{Timestamp(0), true},       {Timestamp(10000), false},
+       {Timestamp(20000), true},   {Timestamp(30000), false},
+       {Timestamp(40000), true},   {Timestamp(50000), false},
+       {Timestamp(60000), false},  {Timestamp(70000), false},
+       {Timestamp(80000), false},  {Timestamp(90000), false},
+       {Timestamp(100000), false}, {Timestamp(110000), false},
+       {Timestamp(120000), false}, {Timestamp(130000), false},
+       {Timestamp(140000), false}, {Timestamp(150000), true},
+       {Timestamp(160000), false}, {Timestamp(170000), true},
+       {Timestamp(180000), false}, {Timestamp(190000), true},
+       {Timestamp(200000), false}});
+  EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
+}
+
+// Shows how FlowLimiterCalculator releases auxiliary input packets.
+// In this test, auxiliary input packets arrive at twice the primary rate.
+TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
+  auto BoolPackestEq = PackestEq<std::vector<Packet>, bool>;
+  auto IntPackestEq = PackestEq<std::vector<Packet>, int>;
+
+  // Configure the test.
+  SetUpInputData();
+  SetUpSimulationClock();
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: 'in_1'
+        input_stream: 'in_2'
+        node {
+          calculator: 'FlowLimiterCalculator'
+          input_side_packet: 'OPTIONS:limiter_options'
+          input_stream: 'in_1'
+          input_stream: 'in_2'
+          input_stream: 'FINISHED:out_1'
+          input_stream_info: { tag_index: 'FINISHED' back_edge: true }
+          output_stream: 'in_1_sampled'
+          output_stream: 'in_2_sampled'
+          output_stream: 'ALLOW:allow'
+        }
+        node {
+          calculator: 'SleepCalculator'
+          input_side_packet: 'WARMUP_TIME:warmup_time'
+          input_side_packet: 'SLEEP_TIME:sleep_time'
+          input_side_packet: 'CLOCK:clock'
+          input_stream: 'PACKET:in_1_sampled'
+          output_stream: 'PACKET:out_1'
+        }
+      )pb");
+
+  auto limiter_options = ParseTextProtoOrDie<FlowLimiterCalculatorOptions>(
+      R"pb(
+        max_in_flight: 1 max_in_queue: 0 in_flight_timeout: 1000000  # 1s
+      )pb");
+  std::map<std::string, Packet> side_packets = {
+      {"limiter_options",
+       MakePacket<FlowLimiterCalculatorOptions>(limiter_options)},
+      {"warmup_time", MakePacket<int64>(22000)},
+      {"sleep_time", MakePacket<int64>(22000)},
+      {"clock", MakePacket<mediapipe::Clock*>(clock_)},
+  };
+
+  // Start the graph.
+  MP_ASSERT_OK(graph_.Initialize(graph_config));
+  MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) {
+    out_1_packets_.push_back(p);
+    return absl::OkStatus();
+  }));
+  std::vector<Packet> out_2_packets;
+  MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) {
+    out_2_packets.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) {
+    allow_packets_.push_back(p);
+    return absl::OkStatus();
+  }));
+  simulation_clock_->ThreadStart();
+  MP_ASSERT_OK(graph_.StartRun(side_packets));
+
+  // Add packets 2,4,6,8 to stream in_1 and 1..9 to stream in_2.
+  clock_->Sleep(absl::Microseconds(10000));
+  for (int i = 1; i < 10; ++i) {
+    if (i % 2 == 0) {
+      MP_EXPECT_OK(graph_.AddPacketToInputStream("in_1", input_packets_[i]));
+    }
+    MP_EXPECT_OK(graph_.AddPacketToInputStream("in_2", input_packets_[i]));
+    clock_->Sleep(absl::Microseconds(10000));
+  }
+
+  // Finish the graph run.
+  MP_EXPECT_OK(graph_.CloseAllPacketSources());
+  clock_->Sleep(absl::Microseconds(40000));
+  MP_EXPECT_OK(graph_.WaitUntilDone());
+  simulation_clock_->ThreadFinish();
+
+  // Validate the output.
+  // Input packets 4 and 8 are dropped due to max_in_flight.
+  std::vector<Packet> expected_output = {
+      input_packets_[2],
+      input_packets_[6],
+  };
+  EXPECT_THAT(out_1_packets_, IntPackestEq(expected_output));
+  // Packets following input packets 2 and 6, and not input packets 4 and 8.
+  std::vector<Packet> expected_output_2 = {
+      input_packets_[1], input_packets_[2], input_packets_[3],
+      input_packets_[6], input_packets_[7],
+  };
+  EXPECT_THAT(out_2_packets, IntPackestEq(expected_output_2));
+
+  // Validate the ALLOW stream output.
+  std::vector<Packet> expected_allow =
+      MakePackets<bool>({{Timestamp(20000), 1},
+                         {Timestamp(40000), 0},
+                         {Timestamp(60000), 1},
+                         {Timestamp(80000), 0}});
+  EXPECT_THAT(allow_packets_, BoolPackestEq(expected_allow));
 }

 }  // anonymous namespace

mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc

@@ -25,12 +25,12 @@
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "mediapipe/util/tflite/tflite_gpu_runner.h"

-#if defined(MEDIAPIPE_ANDROID)
+#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/util/android/file/base/file.h"
 #include "mediapipe/util/android/file/base/filesystem.h"
 #include "mediapipe/util/android/file/base/helpers.h"
-#endif  // MEDIAPIPE_ANDROID
+#endif  // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

 #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe

@@ -231,7 +231,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner(
   return tflite_gpu_runner_->Build();
 }

-#if defined(MEDIAPIPE_ANDROID)
+#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)
 absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init(
     const mediapipe::InferenceCalculatorOptions& options,
     const mediapipe::InferenceCalculatorOptions::Delegate::Gpu&
@@ -318,7 +318,7 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches(
     tflite::gpu::TFLiteGPURunner* gpu_runner) const {
   return absl::OkStatus();
 }
-#endif  // MEDIAPIPE_ANDROID
+#endif  // defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS)

 absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
     CalculatorContract* cc) {

mediapipe/framework/deps/BUILD

@@ -15,7 +15,7 @@
 # Description:
 #   The dependencies of mediapipe.

-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
 load("@bazel_skylib//:bzl_library.bzl", "bzl_library")

 licenses(["notice"])
@@ -38,7 +38,7 @@ bzl_library(
     visibility = ["//mediapipe/framework:__subpackages__"],
 )

-proto_library(
+mediapipe_proto_library(
     name = "proto_descriptor_proto",
     srcs = ["proto_descriptor.proto"],
     visibility = [
@@ -47,13 +47,6 @@ proto_library(
     ],
 )

-mediapipe_cc_proto_library(
-    name = "proto_descriptor_cc_proto",
-    srcs = ["proto_descriptor.proto"],
-    visibility = ["//mediapipe/framework:__subpackages__"],
-    deps = [":proto_descriptor_proto"],
-)
-
 cc_library(
     name = "aligned_malloc_and_free",
     hdrs = ["aligned_malloc_and_free.h"],

mediapipe/framework/port/build_config.bzl

@@ -4,6 +4,9 @@
 """.bzl file for mediapipe open source build configs."""

 load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library")
+load("@npm//@bazel/typescript:index.bzl", "ts_project")
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@rules_proto_grpc//js:defs.bzl", "js_proto_library")
 load("//mediapipe/framework/tool:mediapipe_graph.bzl", "mediapipe_options_library")

 def provided_args(**kwargs):
@@ -71,7 +74,7 @@ def mediapipe_proto_library(
       def_jspb_proto: define the jspb_proto_library target
       def_options_lib: define the mediapipe_options_library target
     """
-    _ignore = [def_portable_proto, def_objc_proto, def_java_proto, def_jspb_proto, portable_deps]
+    _ignore = [def_portable_proto, def_objc_proto, def_java_proto, portable_deps]  # buildifier: disable=unused-variable

     # The proto_library targets for the compiled ".proto" source files.
     proto_deps = [":" + name]
@@ -119,6 +122,24 @@ def mediapipe_proto_library(
             compatible_with = compatible_with,
         ))

+    if def_jspb_proto:
+        js_deps = replace_deps(deps, "_proto", "_jspb_proto", False)
+        proto_library(
+            name = replace_suffix(name, "_proto", "_lib_proto"),
+            srcs = srcs,
+            deps = deps,
+        )
+        js_proto_library(
+            name = replace_suffix(name, "_proto", "_jspb_proto"),
+            protos = [replace_suffix(name, "_proto", "_lib_proto")],
+            output_mode = "NO_PREFIX_FLAT",
+            # Need to specify this to work around bug in js_proto_library()
+            # https://github.com/bazelbuild/rules_nodejs/issues/3503
+            legacy_path = "unused",
+            deps = js_deps,
+            visibility = visibility,
+        )
+
     if def_options_lib:
         cc_deps = replace_deps(deps, "_proto", "_cc_proto")
         mediapipe_options_library(**provided_args(
@@ -182,3 +203,35 @@ def mediapipe_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps
         default_runtime = "@com_google_protobuf//:protobuf",
         alwayslink = 1,
     ))
+
+def mediapipe_ts_library(
+        name,
+        srcs,
+        visibility = None,
+        deps = [],
+        testonly = 0,
+        allow_unoptimized_namespaces = False):
+    """Generate ts_project for MediaPipe open source version.
+
+    Args:
+      name: the name of the cc_proto_library.
+      srcs: the .proto files of the cc_proto_library for Bazel use.
+      visibility: visibility of this target.
+      deps: a list of dependency labels for Bazel use; must be cc_proto_library.
+      testonly: test only or not.
+      allow_unoptimized_namespaces: ignored, used only internally
+    """
+    _ignore = [allow_unoptimized_namespaces]  # buildifier: disable=unused-variable
+
+    ts_project(**provided_args(
+        name = name,
+        srcs = srcs,
+        visibility = visibility,
+        deps = deps + [
+            "@npm//@types/offscreencanvas",
+            "@npm//@types/google-protobuf",
+        ],
+        testonly = testonly,
+        declaration = True,
+        tsconfig = "//:tsconfig.json",
+    ))

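The jspb targets above rely on a `_proto` → `_jspb_proto` naming convention implemented by `replace_suffix()`/`replace_deps()`, which are defined elsewhere in build_config.bzl and not shown in this diff. A Python sketch of the assumed suffix rewrite (stand-ins, not the real helpers):

```python
def replace_suffix(string, old, new):
  """Stand-in for the build_config.bzl helper of the same name (assumed)."""
  # 'foo_proto' -> 'foo_jspb_proto'; other strings pass through unchanged.
  return string[:-len(old)] + new if string.endswith(old) else string


def replace_deps(deps, old, new):
  """Sketch: rewrite each dep label's suffix (the real helper takes options)."""
  return [replace_suffix(dep, old, new) for dep in deps]


assert (replace_suffix('detection_proto', '_proto', '_jspb_proto')
        == 'detection_jspb_proto')
```
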
mediapipe/model_maker/python/core/tasks/custom_model.py

@@ -52,6 +52,7 @@ class CustomModel(abc.ABC):
     """Prints a summary of the model."""
     self._model.summary()

+  # TODO: Remove this method when all tasks use Metadata writer
   def export_tflite(
       self,
       export_dir: str,
@@ -62,7 +63,7 @@ class CustomModel(abc.ABC):

     Args:
       export_dir: The directory to save exported files.
-      tflite_filename: File name to save tflite model. The full export path is
+      tflite_filename: File name to save TFLite model. The full export path is
         {export_dir}/{tflite_filename}.
       quantization_config: The configuration for model quantization.
       preprocess: A callable to preprocess the representative dataset for
@@ -73,11 +74,11 @@ class CustomModel(abc.ABC):
     tf.io.gfile.makedirs(export_dir)

     tflite_filepath = os.path.join(export_dir, tflite_filename)
-    # TODO: Populate metadata to the exported TFLite model.
-    model_util.export_tflite(
+    tflite_model = model_util.convert_to_tflite(
         model=self._model,
-        tflite_filepath=tflite_filepath,
         quantization_config=quantization_config,
         preprocess=preprocess)
+    model_util.save_tflite(
+        tflite_model=tflite_model, tflite_file=tflite_filepath)
     tf.compat.v1.logging.info(
         'TensorFlow Lite model exported successfully: %s' % tflite_filepath)

mediapipe/model_maker/python/core/utils/model_util.py

@@ -89,28 +89,25 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
   return len(train_data) // batch_size


-def export_tflite(
+def convert_to_tflite(
     model: tf.keras.Model,
-    tflite_filepath: str,
     quantization_config: Optional[quantization.QuantizationConfig] = None,
     supported_ops: Tuple[tf.lite.OpsSet,
                          ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
-    preprocess: Optional[Callable[..., bool]] = None):
-  """Converts the model to tflite format and saves it.
+    preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
+  """Converts the input Keras model to TFLite format.

   Args:
-    model: model to be converted to tflite.
-    tflite_filepath: File path to save tflite model.
+    model: Keras model to be converted to TFLite.
     quantization_config: Configuration for post-training quantization.
     supported_ops: A list of supported ops in the converted TFLite file.
     preprocess: A callable to preprocess the representative dataset for
       quantization. The callable takes three arguments in order: feature, label,
       and is_training.
-  """
-  if tflite_filepath is None:
-    raise ValueError(
-        "TFLite filepath couldn't be None when exporting to tflite.")
-
+
+  Returns:
+    bytearray of TFLite model
+  """
   with tempfile.TemporaryDirectory() as temp_dir:
     save_path = os.path.join(temp_dir, 'saved_model')
     model.save(save_path, include_optimizer=False, save_format='tf')
@@ -122,9 +119,22 @@ def convert_to_tflite(

     converter.target_spec.supported_ops = supported_ops
     tflite_model = converter.convert()
+    return tflite_model

-  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
+
+def save_tflite(tflite_model: bytearray, tflite_file: str) -> None:
+  """Saves TFLite file to tflite_file.
+
+  Args:
+    tflite_model: A valid flatbuffer representing the TFLite model.
+    tflite_file: File path to save TFLite model.
+  """
+  if tflite_file is None:
+    raise ValueError("TFLite filepath can't be None when exporting to TFLite.")
+  with tf.io.gfile.GFile(tflite_file, 'wb') as f:
     f.write(tflite_model)
+  tf.compat.v1.logging.info(
+      'TensorFlow Lite model exported successfully to: %s' % tflite_file)


 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -176,14 +186,12 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 class LiteRunner(object):
   """A runner to do inference with the TFLite model."""

-  def __init__(self, tflite_filepath: str):
-    """Initializes Lite runner with tflite model file.
+  def __init__(self, tflite_model: bytearray):
+    """Initializes Lite runner from TFLite model buffer.

     Args:
-      tflite_filepath: File path to the TFLite model.
+      tflite_model: A valid flatbuffer representing the TFLite model.
     """
-    with tf.io.gfile.GFile(tflite_filepath, 'rb') as f:
-      tflite_model = f.read()
     self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
     self.interpreter.allocate_tensors()
     self.input_details = self.interpreter.get_input_details()
@@ -250,9 +258,9 @@ class LiteRunner(object):
   return output_tensors


-def get_lite_runner(tflite_filepath: str) -> 'LiteRunner':
-  """Returns a `LiteRunner` from file path to TFLite model."""
-  lite_runner = LiteRunner(tflite_filepath)
+def get_lite_runner(tflite_buffer: bytearray) -> 'LiteRunner':
+  """Returns a `LiteRunner` from flatbuffer of the TFLite model."""
+  lite_runner = LiteRunner(tflite_buffer)
   return lite_runner

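Splitting conversion from saving means callers can post-process the flatbuffer (for example, attach metadata) before writing it out. A minimal usage sketch of the new API surface, assuming a trivial illustrative Keras model:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import model_util

# Any Keras model will do; this tiny net is purely illustrative.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])

# Convert in memory, then persist as a separate step.
tflite_model = model_util.convert_to_tflite(model=model)
model_util.save_tflite(tflite_model=tflite_model,
                       tflite_file='/tmp/model.tflite')

# LiteRunner now consumes the flatbuffer directly instead of a file path.
lite_runner = model_util.get_lite_runner(tflite_model)
output = lite_runner.run(tf.random.uniform((1, 4)))
```
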
mediapipe/model_maker/python/core/utils/model_util_test.py

@@ -95,13 +95,12 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
         'name': 'test'
     })

-  def test_export_tflite(self):
+  def test_convert_to_tflite(self):
     input_dim = 4
     model = test_util.build_model(input_shape=[input_dim], num_classes=2)
-    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
-    model_util.export_tflite(model, tflite_file)
+    tflite_model = model_util.convert_to_tflite(model)
     test_util.test_tflite(
-        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])
+        keras_model=model, tflite_model=tflite_model, size=[1, input_dim])

   @parameterized.named_parameters(
       dict(
@@ -118,25 +117,32 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
           testcase_name='float16_quantize',
          config=quantization.QuantizationConfig.for_float16(),
          model_size=1468))
-  def test_export_tflite_quantized(self, config, model_size):
+  def test_convert_to_tflite_quantized(self, config, model_size):
     input_dim = 16
     num_classes = 2
     max_input_value = 5
     model = test_util.build_model(
         input_shape=[input_dim], num_classes=num_classes)
-    tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')

-    model_util.export_tflite(
-        model=model, tflite_filepath=tflite_file, quantization_config=config)
+    tflite_model = model_util.convert_to_tflite(
+        model=model, quantization_config=config)
     self.assertTrue(
         test_util.test_tflite(
             keras_model=model,
-            tflite_file=tflite_file,
+            tflite_model=tflite_model,
             size=[1, input_dim],
             high=max_input_value,
             atol=1e-00))
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
+    self.assertNear(len(tflite_model), model_size, 300)

+  def test_save_tflite(self):
+    input_dim = 4
+    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
+    tflite_model = model_util.convert_to_tflite(model)
+    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
+    model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file)
+    test_util.test_tflite_file(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])

 if __name__ == '__main__':
   tf.test.main()

mediapipe/model_maker/python/core/utils/test_util.py

@@ -79,13 +79,13 @@ def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model:
   return model


-def is_same_output(tflite_file: str,
+def is_same_output(tflite_model: bytearray,
                    keras_model: tf.keras.Model,
                    input_tensors: Union[List[tf.Tensor], tf.Tensor],
                    atol: float = 1e-04) -> bool:
   """Returns if the output of TFLite model and keras model are identical."""
   # Gets output from lite model.
-  lite_runner = model_util.get_lite_runner(tflite_file)
+  lite_runner = model_util.get_lite_runner(tflite_model)
   lite_output = lite_runner.run(input_tensors)

   # Gets output from keras model.
@@ -95,7 +95,36 @@ def is_same_output(tflite_file: str,


 def test_tflite(keras_model: tf.keras.Model,
-                tflite_file: str,
-                size: Union[int, List[int]],
-                high: float = 1,
-                atol: float = 1e-04) -> bool:
+                tflite_model: bytearray,
+                size: Union[int, List[int]],
+                high: float = 1,
+                atol: float = 1e-04) -> bool:
+  """Verifies if the output of TFLite model and TF Keras model are identical.
+
+  Args:
+    keras_model: Input TensorFlow Keras model.
+    tflite_model: Input TFLite model flatbuffer.
+    size: Size of the input tesnor.
+    high: Higher boundary of the values in input tensors.
+    atol: Absolute tolerance of the difference between the outputs of Keras
+      model and TFLite model.
+
+  Returns:
+    True if the output of TFLite model and TF Keras model are identical.
+    Otherwise, False.
+  """
+  random_input = create_random_sample(size=size, high=high)
+  random_input = tf.convert_to_tensor(random_input)
+
+  return is_same_output(
+      tflite_model=tflite_model,
+      keras_model=keras_model,
+      input_tensors=random_input,
+      atol=atol)
+
+
+def test_tflite_file(keras_model: tf.keras.Model,
+                     tflite_file: bytearray,
+                     size: Union[int, List[int]],
+                     high: float = 1,
+                     atol: float = 1e-04) -> bool:
@@ -113,11 +142,6 @@ def test_tflite(keras_model: tf.keras.Model,
     True if the output of TFLite model and TF Keras model are identical.
     Otherwise, False.
   """
-  random_input = create_random_sample(size=size, high=high)
-  random_input = tf.convert_to_tensor(random_input)
-
-  return is_same_output(
-      tflite_file=tflite_file,
-      keras_model=keras_model,
-      input_tensors=random_input,
-      atol=atol)
+  with tf.io.gfile.GFile(tflite_file, "rb") as f:
+    tflite_model = f.read()
+  return test_tflite(keras_model, tflite_model, size, high, atol)

(deleted file)

@@ -1,4 +0,0 @@
-# MediaPipe Model Maker Internal Library
-
-This directory contains model maker library for internal users and experimental
-purposes.

(deleted file)

@@ -1 +0,0 @@
-"""Model maker internal library."""

mediapipe/model_maker/python/vision/image_classifier/BUILD

@@ -81,6 +81,8 @@ py_library(
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
         "//mediapipe/model_maker/python/vision/core:image_preprocessing",
+        "//mediapipe/tasks/python/metadata/metadata_writers:image_classifier",
+        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
     ],
 )

@@ -88,7 +90,11 @@ py_library(
     name = "image_classifier_test_lib",
     testonly = 1,
     srcs = ["image_classifier_test.py"],
-    deps = [":image_classifier_import"],
+    data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
+    deps = [
+        ":image_classifier_import",
+        "//mediapipe/tasks/python/test:test_utils",
+    ],
 )

 py_test(

mediapipe/model_maker/python/vision/image_classifier/image_classifier.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """APIs to train image classifier model."""
+import os

 from typing import List, Optional

@@ -26,6 +27,8 @@ from mediapipe.model_maker.python.vision.core import image_preprocessing
 from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
 from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
+from mediapipe.tasks.python.metadata.metadata_writers import image_classifier as image_classifier_writer
+from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer


 class ImageClassifier(classifier.Classifier):
@@ -156,15 +159,32 @@ class ImageClassifier(classifier.Classifier):
       self,
       model_name: str = 'model.tflite',
       quantization_config: Optional[quantization.QuantizationConfig] = None):
-    """Converts the model to the requested formats and exports to a file.
+    """Converts and saves the model to a TFLite file with metadata included.

+    Note that only the TFLite file is needed for deployment. This function also
+    saves a metadata.json file to the same directory as the TFLite file which
+    can be used to interpret the metadata content in the TFLite file.
+
     Args:
-      model_name: File name to save tflite model. The full export path is
-        {export_dir}/{tflite_filename}.
+      model_name: File name to save TFLite model with metadata. The full export
+        path is {self._hparams.model_dir}/{model_name}.
       quantization_config: The configuration for model quantization.
     """
-    super().export_tflite(
-        self._hparams.model_dir,
-        model_name,
-        quantization_config,
+    if not tf.io.gfile.exists(self._hparams.model_dir):
+      tf.io.gfile.makedirs(self._hparams.model_dir)
+    tflite_file = os.path.join(self._hparams.model_dir, model_name)
+    metadata_file = os.path.join(self._hparams.model_dir, 'metadata.json')
+
+    tflite_model = model_util.convert_to_tflite(
+        model=self._model,
+        quantization_config=quantization_config,
         preprocess=self._preprocess)
+    writer = image_classifier_writer.MetadataWriter.create(
+        tflite_model,
+        self._model_spec.mean_rgb,
+        self._model_spec.stddev_rgb,
+        labels=metadata_writer.Labels().add(self._label_names))
+    tflite_model_with_metadata, metadata_json = writer.populate()
+    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
+    with open(metadata_file, 'w') as f:
+      f.write(metadata_json)

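With this change `export_model()` is self-contained: it converts the model, attaches metadata, and writes both artifacts under `hparams.model_dir`. A usage sketch (dataset preparation elided; `train_data`/`validation_data` are assumed prepared Datasets, and the spec constant is an assumption):

```python
from mediapipe.model_maker.python.vision import image_classifier

# Assumed: EFFICIENTNET_LITE0 is one of the supported model specs.
spec = image_classifier.SupportedModels.EFFICIENTNET_LITE0
hparams = image_classifier.HParams(train_epochs=1, batch_size=1, shuffle=True)

model = image_classifier.ImageClassifier.create(
    model_spec=spec,
    train_data=train_data,            # assumed prepared Dataset
    validation_data=validation_data,  # assumed prepared Dataset
    hparams=hparams)

# Writes {hparams.model_dir}/model.tflite (metadata embedded) and a
# human-readable {hparams.model_dir}/metadata.json next to it.
model.export_model()
```
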
mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import filecmp
 import os

 from absl.testing import parameterized
@@ -19,6 +20,7 @@ import numpy as np
 import tensorflow as tf

 from mediapipe.model_maker.python.vision import image_classifier
+from mediapipe.tasks.python.test import test_utils


 def _fill_image(rgb, image_size):
@@ -86,7 +88,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
         validation_data=self.test_data)
     self._test_accuracy(model)

-  def test_efficientnetlite0_model_with_model_maker_retraining_lib(self):
+  def test_efficientnetlite0_model_train_and_export(self):
     hparams = image_classifier.HParams(
         train_epochs=1, batch_size=1, shuffle=True)
     model = image_classifier.ImageClassifier.create(
@@ -96,6 +98,19 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
         validation_data=self.test_data)
     self._test_accuracy(model)

+    # Test export_model
+    model.export_model()
+    output_metadata_file = os.path.join(hparams.model_dir, 'metadata.json')
+    output_tflite_file = os.path.join(hparams.model_dir, 'model.tflite')
+    expected_metadata_file = test_utils.get_test_data_path('metadata.json')
+
+    self.assertTrue(os.path.exists(output_tflite_file))
+    self.assertGreater(os.path.getsize(output_tflite_file), 0)
+
+    self.assertTrue(os.path.exists(output_metadata_file))
+    self.assertGreater(os.path.getsize(output_metadata_file), 0)
+    self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file))
+
   def _test_accuracy(self, model, threshold=0.0):
     _, accuracy = model.evaluate(self.test_data)
     self.assertGreaterEqual(accuracy, threshold)

mediapipe/model_maker/python/vision/image_classifier/testdata/BUILD (new file)

@@ -0,0 +1,23 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_visibility = ["//mediapipe/model_maker/python/vision/image_classifier:__subpackages__"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+filegroup(
+    name = "testdata",
+    srcs = ["metadata.json"],
+)

mediapipe/model_maker/python/vision/image_classifier/testdata/metadata.json (new file, 68 lines)
@@ -0,0 +1,68 @@
+{
+  "name": "ImageClassifier",
+  "description": "Identify the most prominent object in the image from a known set of categories.",
+  "subgraph_metadata": [
+    {
+      "input_tensor_metadata": [
+        {
+          "name": "image",
+          "description": "Input image to be processed.",
+          "content": {
+            "content_properties_type": "ImageProperties",
+            "content_properties": {
+              "color_space": "RGB"
+            }
+          },
+          "process_units": [
+            {
+              "options_type": "NormalizationOptions",
+              "options": {
+                "mean": [
+                  0.0
+                ],
+                "std": [
+                  255.0
+                ]
+              }
+            }
+          ],
+          "stats": {
+            "max": [
+              1.0
+            ],
+            "min": [
+              0.0
+            ]
+          }
+        }
+      ],
+      "output_tensor_metadata": [
+        {
+          "name": "score",
+          "description": "Score of the labels respectively.",
+          "content": {
+            "content_properties_type": "FeatureProperties",
+            "content_properties": {
+            }
+          },
+          "stats": {
+            "max": [
+              1.0
+            ],
+            "min": [
+              0.0
+            ]
+          },
+          "associated_files": [
+            {
+              "name": "labels.txt",
+              "description": "Labels for categories that the model can recognize.",
+              "type": "TENSOR_AXIS_LABELS"
+            }
+          ]
+        }
+      ]
+    }
+  ],
+  "min_parser_version": "1.0.0"
+}
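[Editor's note] For readers decoding this metadata: the NormalizationOptions above (mean 0.0, std 255.0) together with the input stats (min 0.0, max 1.0) describe the usual (value - mean) / std preprocessing that maps 8-bit pixels into [0, 1]. A tiny illustrative sketch, not part of the commit:

#include <cstdint>

// Metadata-described preprocessing: value' = (value - mean) / std.
// With mean = 0 and std = 255, a uint8 pixel lands in the declared
// [0.0, 1.0] input stats range.
float NormalizePixel(uint8_t pixel) {
  constexpr float kMean = 0.0f;
  constexpr float kStd = 255.0f;
  return (static_cast<float>(pixel) - kMean) / kStd;
}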
@@ -87,6 +87,7 @@
 cc_library(
     name = "builtin_task_graphs",
     deps = [
+        "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
         "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
@@ -602,8 +602,11 @@ void PublicPacketCreators(pybind11::module* m) {
       // TODO: Should take "const Eigen::Ref<const Eigen::MatrixXf>&"
       // as the input argument. Investigate why bazel non-optimized mode
       // triggers a memory allocation bug in Eigen::internal::aligned_free().
-      [](const Eigen::MatrixXf& matrix) {
+      [](const Eigen::MatrixXf& matrix, bool transpose) {
         // MakePacket copies the data.
+        if (transpose) {
+          return MakePacket<Matrix>(matrix.transpose());
+        }
         return MakePacket<Matrix>(matrix);
       },
       R"doc(Create a MediaPipe Matrix Packet from a 2d numpy float ndarray.
@@ -613,6 +616,8 @@ void PublicPacketCreators(pybind11::module* m) {

  Args:
    matrix: A 2d numpy float ndarray.
+   transpose: A boolean to indicate if the input matrix needs to be transposed.
+     Defaults to False.

  Returns:
    A MediaPipe Matrix Packet.
@@ -625,6 +630,7 @@ void PublicPacketCreators(pybind11::module* m) {
    np.array([[.1, .2, .3], [.4, .5, .6]])
    matrix = mp.packet_getter.get_matrix(packet)
  )doc",
+      py::arg("matrix"), py::arg("transpose") = false,
       py::return_value_policy::move);
}  // NOLINT(readability/fn_size)
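[Editor's note] On the new transpose flag above: mediapipe::Matrix is an Eigen float matrix, which is column-major by default, whereas numpy arrays are row-major, so callers occasionally need the copied packet transposed. matrix.transpose() is a lazy Eigen expression that MakePacket materializes into an independent copy; this standalone sketch (editor's illustration, not commit code) shows the copy semantics:

#include <iostream>

#include "Eigen/Core"

int main() {
  Eigen::MatrixXf m(2, 3);
  m << 1, 2, 3,
       4, 5, 6;
  // Assigning the lazy transpose expression to a concrete MatrixXf
  // evaluates it into an independent 3x2 copy, much like
  // MakePacket<Matrix>(matrix.transpose()) does above.
  Eigen::MatrixXf t = m.transpose();
  std::cout << t.rows() << "x" << t.cols() << "\n";  // prints "3x2"
  return 0;
}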
@@ -65,6 +65,7 @@ cc_library(
         "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
         "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
         "//mediapipe/tasks/cc/audio/core:running_mode",
+        "//mediapipe/tasks/cc/components/containers:classification_result",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
@@ -18,12 +18,14 @@ limitations under the License.
 #include <map>
 #include <memory>
 #include <utility>
+#include <vector>

 #include "absl/status/statusor.h"
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
 #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@@ -38,12 +40,16 @@ namespace audio_classifier {

 namespace {

+using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
 using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

 constexpr char kAudioStreamName[] = "audio_in";
 constexpr char kAudioTag[] = "AUDIO";
-constexpr char kClassificationResultStreamName[] = "classification_result_out";
-constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kClassificationsName[] = "classifications_out";
+constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
+constexpr char kTimestampedClassificationsName[] =
+    "timestamped_classifications_out";
 constexpr char kSampleRateName[] = "sample_rate_in";
 constexpr char kSampleRateTag[] = "SAMPLE_RATE";
 constexpr char kSubgraphTypeName[] =
@@ -63,9 +69,11 @@ CalculatorGraphConfig CreateGraphConfig(
   }
   subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
       options_proto.get());
-  subgraph.Out(kClassificationResultTag)
-          .SetName(kClassificationResultStreamName) >>
-      graph.Out(kClassificationResultTag);
+  subgraph.Out(kClassificationsTag).SetName(kClassificationsName) >>
+      graph.Out(kClassificationsTag);
+  subgraph.Out(kTimestampedClassificationsTag)
+          .SetName(kTimestampedClassificationsName) >>
+      graph.Out(kTimestampedClassificationsTag);
   return graph.GetConfig();
 }
@@ -91,13 +99,30 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
   return options_proto;
 }

-absl::StatusOr<ClassificationResult> ConvertOutputPackets(
+absl::StatusOr<std::vector<AudioClassifierResult>> ConvertOutputPackets(
     absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
   if (!status_or_packets.ok()) {
     return status_or_packets.status();
   }
-  return status_or_packets.value()[kClassificationResultStreamName]
-      .Get<ClassificationResult>();
+  auto classification_results =
+      status_or_packets.value()[kTimestampedClassificationsName]
+          .Get<std::vector<ClassificationResult>>();
+  std::vector<AudioClassifierResult> results;
+  results.reserve(classification_results.size());
+  for (const auto& classification_result : classification_results) {
+    results.emplace_back(ConvertToClassificationResult(classification_result));
+  }
+  return results;
+}
+
+absl::StatusOr<AudioClassifierResult> ConvertAsyncOutputPackets(
+    absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
+  if (!status_or_packets.ok()) {
+    return status_or_packets.status();
+  }
+  return ConvertToClassificationResult(
+      status_or_packets.value()[kClassificationsName]
+          .Get<ClassificationResult>());
 }
 }  // namespace
@@ -118,7 +143,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
     auto result_callback = options->result_callback;
     packets_callback =
         [=](absl::StatusOr<tasks::core::PacketMap> status_or_packets) {
-          result_callback(ConvertOutputPackets(status_or_packets));
+          result_callback(ConvertAsyncOutputPackets(status_or_packets));
         };
   }
   return core::AudioTaskApiFactory::Create<AudioClassifier,
@@ -128,7 +153,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
       std::move(packets_callback));
 }

-absl::StatusOr<ClassificationResult> AudioClassifier::Classify(
+absl::StatusOr<std::vector<AudioClassifierResult>> AudioClassifier::Classify(
     Matrix audio_clip, double audio_sample_rate) {
   return ConvertOutputPackets(ProcessAudioClip(
       {{kAudioStreamName, MakePacket<Matrix>(std::move(audio_clip))},
@@ -18,12 +18,13 @@ limitations under the License.

 #include <memory>
 #include <utility>
+#include <vector>

 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
 #include "mediapipe/tasks/cc/audio/core/running_mode.h"
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
@@ -32,6 +33,10 @@ namespace tasks {
 namespace audio {
 namespace audio_classifier {

+// Alias the shared ClassificationResult struct as result type.
+using AudioClassifierResult =
+    ::mediapipe::tasks::components::containers::ClassificationResult;
+
 // The options for configuring a mediapipe audio classifier task.
 struct AudioClassifierOptions {
   // Base options for configuring Task library, such as specifying the TfLite
@@ -59,9 +64,8 @@ struct AudioClassifierOptions {
   // The user-defined result callback for processing audio stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::AUDIO_STREAM.
-  std::function<void(
-      absl::StatusOr<components::containers::proto::ClassificationResult>)>
-      result_callback = nullptr;
+  std::function<void(absl::StatusOr<AudioClassifierResult>)> result_callback =
+      nullptr;
 };

 // Performs audio classification on audio clips or audio stream.
@@ -117,23 +121,36 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // required to provide the corresponding audio sample rate along with the
   // input audio clips.
   //
-  // For each audio clip, the output classifications are grouped in a
-  // ClassificationResult object that has three dimensions:
-  //   Classification head:
-  //     The prediction heads targeting different audio classification tasks
-  //     such as audio event classification and bird sound classification.
-  //   Classification timestamp:
-  //     The start time (in milliseconds) of each audio clip that is sent to the
-  //     model for audio classification. As the audio classification models take
-  //     a fixed number of audio samples, long audio clips will be framed to
-  //     multiple buffers (with the desired number of audio samples) during
-  //     preprocessing.
-  //   Classification category:
-  //     The list of the classification categories that model predicts per
-  //     framed audio clip.
+  // The input audio clip may be longer than what the model is able to process
+  // in a single inference. When this occurs, the input audio clip is split into
+  // multiple chunks starting at different timestamps. For this reason, this
+  // function returns a vector of ClassificationResult objects, each associated
+  // with a timestamp corresponding to the start (in milliseconds) of the chunk
+  // data that was classified, e.g.:
+  //
+  // ClassificationResult #0 (first chunk of data):
+  //   timestamp_ms: 0 (starts at 0ms)
+  //   classifications #0 (single head model):
+  //     category #0:
+  //       category_name: "Speech"
+  //       score: 0.6
+  //     category #1:
+  //       category_name: "Music"
+  //       score: 0.2
+  // ClassificationResult #1 (second chunk of data):
+  //   timestamp_ms: 800 (starts at 800ms)
+  //   classifications #0 (single head model):
+  //     category #0:
+  //       category_name: "Speech"
+  //       score: 0.5
+  //     category #1:
+  //       category_name: "Silence"
+  //       score: 0.1
+  //   ...
+  //
   // TODO: Use `sample_rate` in AudioClassifierOptions by default
   // and make `audio_sample_rate` optional.
-  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
+  absl::StatusOr<std::vector<AudioClassifierResult>> Classify(
       mediapipe::Matrix audio_clip, double audio_sample_rate);

   // Sends audio data (a block in a continuous audio stream) to perform audio
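[Editor's note] To make the new per-chunk Classify() contract above concrete, here is a minimal caller sketch. It is not part of the commit: the model path is hypothetical, and it assumes MediaPipe's ASSIGN_OR_RETURN status macro is available.

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h"

namespace classifier = mediapipe::tasks::audio::audio_classifier;

absl::Status ClassifyClip(mediapipe::Matrix audio_clip) {
  auto options = std::make_unique<classifier::AudioClassifierOptions>();
  // Hypothetical asset path, for illustration only.
  options->base_options.model_asset_path = "/path/to/model.tflite";
  ASSIGN_OR_RETURN(auto audio_classifier,
                   classifier::AudioClassifier::Create(std::move(options)));
  // One AudioClassifierResult per chunk the clip was split into.
  ASSIGN_OR_RETURN(auto results,
                   audio_classifier->Classify(std::move(audio_clip),
                                              /*audio_sample_rate=*/16000));
  for (const auto& result : results) {
    // result.timestamp_ms marks the start of the classified chunk, and
    // result.classifications holds one entry per model head.
    (void)result;
  }
  return audio_classifier->Close();
}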
@@ -147,17 +164,10 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // milliseconds) to indicate the start time of the input audio block. The
   // timestamps must be monotonically increasing.
   //
-  // The output classifications are grouped in a ClassificationResult object
-  // that has three dimensions:
-  //   Classification head:
-  //     The prediction heads targeting different audio classification tasks
-  //     such as audio event classification and bird sound classification.
-  //   Classification timestamp:
-  //     The start time (in milliseconds) of the framed audio block that is sent
-  //     to the model for audio classification.
-  //   Classification category:
-  //     The list of the classification categories that model predicts per
-  //     framed audio clip.
+  // The input audio block may be longer than what the model is able to process
+  // in a single inference. When this occurs, the input audio block is split
+  // into multiple chunks. For this reason, the callback may be called multiple
+  // times (once per chunk) for each call to this function.
   absl::Status ClassifyAsync(mediapipe::Matrix audio_block, int64 timestamp_ms);

   // Shuts down the AudioClassifier when all work is done.
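[Editor's note] The AUDIO_STREAM counterpart, again a sketch rather than the library's documented usage (hypothetical model path; the right sample rate depends on the model; assumes the ASSIGN_OR_RETURN and MP_RETURN_IF_ERROR macros):

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h"

namespace classifier = mediapipe::tasks::audio::audio_classifier;

absl::Status ClassifyStream(mediapipe::Matrix audio_block, int64_t timestamp_ms) {
  auto options = std::make_unique<classifier::AudioClassifierOptions>();
  options->base_options.model_asset_path = "/path/to/model.tflite";  // hypothetical
  options->running_mode = mediapipe::tasks::audio::core::RunningMode::AUDIO_STREAM;
  options->sample_rate = 16000;
  options->result_callback =
      [](absl::StatusOr<classifier::AudioClassifierResult> result) {
        // Invoked once per classified chunk; a long block can therefore
        // trigger several callbacks for a single ClassifyAsync() call.
      };
  ASSIGN_OR_RETURN(auto audio_classifier,
                   classifier::AudioClassifier::Create(std::move(options)));
  MP_RETURN_IF_ERROR(
      audio_classifier->ClassifyAsync(std::move(audio_block), timestamp_ms));
  return audio_classifier->Close();
}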
@@ -16,6 +16,7 @@ limitations under the License.
 #include <stdint.h>

 #include <utility>
+#include <vector>

 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
@@ -57,12 +58,20 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

 constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
 constexpr char kAudioTag[] = "AUDIO";
-constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
+constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
+constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
 constexpr char kPacketTag[] = "PACKET";
 constexpr char kSampleRateTag[] = "SAMPLE_RATE";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kTimestampsTag[] = "TIMESTAMPS";

+// Struct holding the different output streams produced by the audio classifier
+// graph.
+struct AudioClassifierOutputStreams {
+  Source<ClassificationResult> classifications;
+  Source<std::vector<ClassificationResult>> timestamped_classifications;
+};
+
 absl::Status SanityCheckOptions(
     const proto::AudioClassifierGraphOptions& options) {
   if (options.base_options().use_stream_mode() &&
@@ -124,16 +133,20 @@ void ConfigureAudioToTensorCalculator(
 // series stream header with sample rate info.
 //
 // Outputs:
-//   CLASSIFICATION_RESULT - ClassificationResult
-//     The aggregated classification result object that has 3 dimensions:
-//     (classification head, classification timestamp, classification category).
+//   CLASSIFICATIONS - ClassificationResult @Optional
+//     The classification results aggregated by head. Only produces results if
+//     the graph's 'use_stream_mode' option is true.
+//   TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
+//     The classification result aggregated by timestamp, then by head. Only
+//     produces results if the graph's 'use_stream_mode' option is false.
 //
 // Example:
 // node {
 //   calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
 //   input_stream: "AUDIO:audio_in"
 //   input_stream: "SAMPLE_RATE:sample_rate_in"
-//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
+//   output_stream: "CLASSIFICATIONS:classifications"
+//   output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
 //   options {
 //     [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
 //     {
@@ -162,7 +175,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
                            .base_options()
                            .use_stream_mode();
     ASSIGN_OR_RETURN(
-        auto classification_result_out,
+        auto output_streams,
         BuildAudioClassificationTask(
             sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
             graph[Input<Matrix>(kAudioTag)],
@@ -170,8 +183,11 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
                 ? absl::nullopt
                 : absl::make_optional(graph[Input<double>(kSampleRateTag)]),
             graph));
-    classification_result_out >>
-        graph[Output<ClassificationResult>(kClassificationResultTag)];
+    output_streams.classifications >>
+        graph[Output<ClassificationResult>(kClassificationsTag)];
+    output_streams.timestamped_classifications >>
+        graph[Output<std::vector<ClassificationResult>>(
+            kTimestampedClassificationsTag)];
     return graph.GetConfig();
   }
@@ -187,7 +203,7 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
   //   audio_in: (mediapipe::Matrix) stream to run audio classification on.
   //   sample_rate_in: (double) optional stream of the input audio sample rate.
   //   graph: the mediapipe builder::Graph instance to be updated.
-  absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
+  absl::StatusOr<AudioClassifierOutputStreams> BuildAudioClassificationTask(
       const proto::AudioClassifierGraphOptions& task_options,
       const core::ModelResources& model_resources, Source<Matrix> audio_in,
       absl::optional<Source<double>> sample_rate_in, Graph& graph) {
@@ -250,16 +266,20 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Time aggregation is only needed for performing audio classification on
-    // audio files. Disables time aggregration by not connecting the
+    // audio files. Disables timestamp aggregation by not connecting the
     // "TIMESTAMPS" streams.
     if (!use_stream_mode) {
       audio_to_tensor.Out(kTimestampsTag) >> postprocessing.In(kTimestampsTag);
     }

-    // Outputs the aggregated classification result as the subgraph output
-    // stream.
-    return postprocessing[Output<ClassificationResult>(
-        kClassificationResultTag)];
+    // Output both streams as graph output streams.
+    return AudioClassifierOutputStreams{
+        /*classifications=*/postprocessing[Output<ClassificationResult>(
+            kClassificationsTag)],
+        /*timestamped_classifications=*/
+        postprocessing[Output<std::vector<ClassificationResult>>(
+            kTimestampedClassificationsTag)],
+    };
   }
 };
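[Editor's note] For orientation, a sketch (not from the commit) of driving a graph built from the example config above: only one of the two optional outputs carries packets, selected by 'use_stream_mode'. The stream names follow the example config; the poller calls are standard MediaPipe APIs, but treat the details as illustrative.

#include "mediapipe/framework/calculator_graph.h"

// Clip mode (use_stream_mode == false): results arrive on the
// TIMESTAMPED_CLASSIFICATIONS stream; CLASSIFICATIONS stays silent.
absl::Status RunClipMode(const mediapipe::CalculatorGraphConfig& config) {
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(
      mediapipe::OutputStreamPoller poller,
      graph.AddOutputStreamPoller("timestamped_classifications"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  // ... add AUDIO (and optional SAMPLE_RATE) packets, then close inputs ...
  MP_RETURN_IF_ERROR(graph.CloseAllInputStreams());
  mediapipe::Packet packet;
  while (poller.Next(&packet)) {
    // Each packet holds the per-chunk std::vector<ClassificationResult>.
  }
  return graph.WaitUntilDone();
}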
@@ -32,13 +32,11 @@ limitations under the License.
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
-#include "mediapipe/framework/port/parse_text_proto.h"
-#include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/tasks/cc/audio/core/running_mode.h"
 #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/category.h"
+#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 #include "tensorflow/lite/core/shims/cc/shims_test_util.h"

 namespace mediapipe {
@@ -49,7 +47,6 @@ namespace {

 using ::absl::StatusOr;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 using ::testing::HasSubstr;
 using ::testing::Optional;
@@ -73,95 +70,86 @@ Matrix GetAudioData(absl::string_view filename) {
   return matrix_mapping.matrix();
 }

-void CheckSpeechClassificationResult(const ClassificationResult& result) {
-  EXPECT_THAT(result.classifications_size(), testing::Eq(1));
-  EXPECT_EQ(result.classifications(0).head_name(), "scores");
-  EXPECT_EQ(result.classifications(0).head_index(), 0);
-  EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(5));
+void CheckSpeechResult(const std::vector<AudioClassifierResult>& result,
+                       int expected_num_categories = 521) {
+  EXPECT_EQ(result.size(), 5);
+  // Ignore last result, which operates on a too small chunk to return relevant
+  // results.
   std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
   for (int i = 0; i < timestamps_ms.size(); i++) {
-    EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
-                testing::Eq(521));
-    const auto* top_category =
-        &result.classifications(0).entries(0).categories(0);
-    EXPECT_THAT(top_category->category_name(), testing::Eq("Speech"));
-    EXPECT_GT(top_category->score(), 0.9f);
-    EXPECT_EQ(result.classifications(0).entries(i).timestamp_ms(),
-              timestamps_ms[i]);
+    EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
+    EXPECT_EQ(result[i].classifications.size(), 1);
+    auto classifications = result[i].classifications[0];
+    EXPECT_EQ(classifications.head_index, 0);
+    EXPECT_EQ(classifications.head_name, "scores");
+    EXPECT_EQ(classifications.categories.size(), expected_num_categories);
+    auto category = classifications.categories[0];
+    EXPECT_EQ(category.index, 0);
+    EXPECT_EQ(category.category_name, "Speech");
+    EXPECT_GT(category.score, 0.9f);
   }
 }

-void CheckTwoHeadsClassificationResult(const ClassificationResult& result) {
-  EXPECT_THAT(result.classifications_size(), testing::Eq(2));
-  // Checks classification head #1.
-  EXPECT_EQ(result.classifications(0).head_name(), "yamnet_classification");
-  EXPECT_EQ(result.classifications(0).head_index(), 0);
-  EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
-              testing::Eq(521));
-  const auto* top_category =
-      &result.classifications(0).entries(0).categories(0);
-  EXPECT_THAT(top_category->category_name(),
-              testing::Eq("Environmental noise"));
-  EXPECT_GT(top_category->score(), 0.5f);
-  EXPECT_EQ(result.classifications(0).entries(0).timestamp_ms(), 0);
-  if (result.classifications(0).entries_size() == 2) {
-    top_category = &result.classifications(0).entries(1).categories(0);
-    EXPECT_THAT(top_category->category_name(), testing::Eq("Silence"));
-    EXPECT_GT(top_category->score(), 0.99f);
-    EXPECT_EQ(result.classifications(0).entries(1).timestamp_ms(), 975);
+void CheckTwoHeadsResult(const std::vector<AudioClassifierResult>& result) {
+  EXPECT_GE(result.size(), 1);
+  EXPECT_LE(result.size(), 2);
+  // Check first result.
+  EXPECT_EQ(result[0].timestamp_ms, 0);
+  EXPECT_EQ(result[0].classifications.size(), 2);
+  // Check first head.
+  EXPECT_EQ(result[0].classifications[0].head_index, 0);
+  EXPECT_EQ(result[0].classifications[0].head_name, "yamnet_classification");
+  EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
+  EXPECT_EQ(result[0].classifications[0].categories[0].index, 508);
+  EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
+            "Environmental noise");
+  EXPECT_GT(result[0].classifications[0].categories[0].score, 0.5f);
+  // Check second head.
+  EXPECT_EQ(result[0].classifications[1].head_index, 1);
+  EXPECT_EQ(result[0].classifications[1].head_name, "bird_classification");
+  EXPECT_EQ(result[0].classifications[1].categories.size(), 5);
+  EXPECT_EQ(result[0].classifications[1].categories[0].index, 4);
+  EXPECT_EQ(result[0].classifications[1].categories[0].category_name,
+            "Chestnut-crowned Antpitta");
+  EXPECT_GT(result[0].classifications[1].categories[0].score, 0.9f);
+  // Check second result, if present.
+  if (result.size() == 2) {
+    EXPECT_EQ(result[1].timestamp_ms, 975);
+    EXPECT_EQ(result[1].classifications.size(), 2);
+    // Check first head.
+    EXPECT_EQ(result[1].classifications[0].head_index, 0);
+    EXPECT_EQ(result[1].classifications[0].head_name, "yamnet_classification");
+    EXPECT_EQ(result[1].classifications[0].categories.size(), 521);
+    EXPECT_EQ(result[1].classifications[0].categories[0].index, 494);
+    EXPECT_EQ(result[1].classifications[0].categories[0].category_name,
+              "Silence");
+    EXPECT_GT(result[1].classifications[0].categories[0].score, 0.99f);
+    // Check second head.
+    EXPECT_EQ(result[1].classifications[1].head_index, 1);
+    EXPECT_EQ(result[1].classifications[1].head_name, "bird_classification");
+    EXPECT_EQ(result[1].classifications[1].categories.size(), 5);
+    EXPECT_EQ(result[1].classifications[1].categories[0].index, 1);
+    EXPECT_EQ(result[1].classifications[1].categories[0].category_name,
+              "White-breasted Wood-Wren");
+    EXPECT_GT(result[1].classifications[1].categories[0].score, 0.99f);
   }
-  // Checks classification head #2.
-  EXPECT_EQ(result.classifications(1).head_name(), "bird_classification");
-  EXPECT_EQ(result.classifications(1).head_index(), 1);
-  EXPECT_THAT(result.classifications(1).entries(0).categories_size(),
-              testing::Eq(5));
-  top_category = &result.classifications(1).entries(0).categories(0);
-  EXPECT_THAT(top_category->category_name(),
-              testing::Eq("Chestnut-crowned Antpitta"));
-  EXPECT_GT(top_category->score(), 0.9f);
-  EXPECT_EQ(result.classifications(1).entries(0).timestamp_ms(), 0);
 }

-ClassificationResult GenerateSpeechClassificationResult() {
-  return ParseTextProtoOrDie<ClassificationResult>(
-      R"pb(classifications {
-             head_index: 0
-             head_name: "scores"
-             entries {
-               categories { index: 0 score: 0.94140625 category_name: "Speech" }
-               timestamp_ms: 0
-             }
-             entries {
-               categories { index: 0 score: 0.9921875 category_name: "Speech" }
-               timestamp_ms: 975
-             }
-             entries {
-               categories { index: 0 score: 0.98828125 category_name: "Speech" }
-               timestamp_ms: 1950
-             }
-             entries {
-               categories { index: 0 score: 0.99609375 category_name: "Speech" }
-               timestamp_ms: 2925
-             }
-             entries {
-               # categories are filtered out due to the low scores.
-               timestamp_ms: 3900
-             }
-           })pb");
-}
-
-void CheckStreamingModeClassificationResult(
-    std::vector<ClassificationResult> outputs) {
-  ASSERT_TRUE(outputs.size() == 5 || outputs.size() == 6);
-  auto expected_results = GenerateSpeechClassificationResult();
-  for (int i = 0; i < outputs.size() - 1; ++i) {
-    EXPECT_THAT(outputs[i].classifications(0).entries(0),
-                EqualsProto(expected_results.classifications(0).entries(i)));
-  }
-  int last_elem_index = outputs.size() - 1;
-  EXPECT_EQ(
-      mediapipe::Timestamp::Done().Value() / 1000,
-      outputs[last_elem_index].classifications(0).entries(0).timestamp_ms());
+void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
+  EXPECT_EQ(outputs.size(), 5);
+  // Ignore last result, which operates on a too small chunk to return relevant
+  // results.
+  for (int i = 0; i < outputs.size() - 1; i++) {
+    EXPECT_FALSE(outputs[i].timestamp_ms.has_value());
+    EXPECT_EQ(outputs[i].classifications.size(), 1);
+    EXPECT_EQ(outputs[i].classifications[0].head_index, 0);
+    EXPECT_EQ(outputs[i].classifications[0].head_name, "scores");
+    EXPECT_EQ(outputs[i].classifications[0].categories.size(), 1);
+    EXPECT_EQ(outputs[i].classifications[0].categories[0].index, 0);
+    EXPECT_EQ(outputs[i].classifications[0].categories[0].category_name,
+              "Speech");
+    EXPECT_GT(outputs[i].classifications[0].categories[0].score, 0.9f);
+  }
 }

 class CreateFromOptionsTest : public tflite_shims::testing::Test {};
@@ -264,7 +252,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
   options->result_callback =
-      [](absl::StatusOr<ClassificationResult> status_or_result) {};
+      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
       AudioClassifier::Create(std::move(options));
@@ -284,7 +272,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) {
       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata);
   options->running_mode = core::RunningMode::AUDIO_STREAM;
   options->result_callback =
-      [](absl::StatusOr<ClassificationResult> status_or_result) {};
+      [](absl::StatusOr<AudioClassifierResult> status_or_result) {};
   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or =
       AudioClassifier::Create(std::move(options));
@@ -310,7 +298,7 @@ TEST_F(ClassifyTest, Succeeds) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckSpeechClassificationResult(result);
+  CheckSpeechResult(result);
 }

 TEST_F(ClassifyTest, SucceedsWithResampling) {
@@ -324,7 +312,7 @@ TEST_F(ClassifyTest, SucceedsWithResampling) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckSpeechClassificationResult(result);
+  CheckSpeechResult(result);
 }

 TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
@@ -339,13 +327,13 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) {
       auto result_16k_hz,
       audio_classifier->Classify(std::move(audio_buffer_16k_hz),
                                  /*audio_sample_rate=*/16000));
-  CheckSpeechClassificationResult(result_16k_hz);
+  CheckSpeechResult(result_16k_hz);
   MP_ASSERT_OK_AND_ASSIGN(
       auto result_48k_hz,
       audio_classifier->Classify(std::move(audio_buffer_48k_hz),
                                  /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckSpeechClassificationResult(result_48k_hz);
+  CheckSpeechResult(result_48k_hz);
 }

 TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
@@ -361,15 +349,16 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) {
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(zero_matrix), 16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result.classifications_size(), testing::Eq(1));
-  EXPECT_THAT(result.classifications(0).entries_size(), testing::Eq(1));
-  EXPECT_THAT(result.classifications(0).entries(0).categories_size(),
-              testing::Eq(521));
-  EXPECT_THAT(
-      result.classifications(0).entries(0).categories(0).category_name(),
-      testing::Eq("Silence"));
-  EXPECT_THAT(result.classifications(0).entries(0).categories(0).score(),
-              testing::FloatEq(.800781f));
+  EXPECT_EQ(result.size(), 1);
+  EXPECT_EQ(result[0].timestamp_ms, 0);
+  EXPECT_EQ(result[0].classifications.size(), 1);
+  EXPECT_EQ(result[0].classifications[0].head_index, 0);
+  EXPECT_EQ(result[0].classifications[0].head_name, "scores");
+  EXPECT_EQ(result[0].classifications[0].categories.size(), 521);
+  EXPECT_EQ(result[0].classifications[0].categories[0].index, 494);
+  EXPECT_EQ(result[0].classifications[0].categories[0].category_name,
+            "Silence");
+  EXPECT_FLOAT_EQ(result[0].classifications[0].categories[0].score, 0.800781f);
 }

 TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
@@ -383,7 +372,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckTwoHeadsClassificationResult(result);
+  CheckTwoHeadsResult(result);
 }

 TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
@@ -397,7 +386,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/44100));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckTwoHeadsClassificationResult(result);
+  CheckTwoHeadsResult(result);
 }

 TEST_F(ClassifyTest,
@@ -413,13 +402,13 @@ TEST_F(ClassifyTest,
       auto result_44k_hz,
       audio_classifier->Classify(std::move(audio_buffer_44k_hz),
                                  /*audio_sample_rate=*/44100));
-  CheckTwoHeadsClassificationResult(result_44k_hz);
+  CheckTwoHeadsResult(result_44k_hz);
   MP_ASSERT_OK_AND_ASSIGN(
       auto result_16k_hz,
       audio_classifier->Classify(std::move(audio_buffer_16k_hz),
                                  /*audio_sample_rate=*/16000));
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckTwoHeadsClassificationResult(result_16k_hz);
+  CheckTwoHeadsResult(result_16k_hz);
 }

 TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
@@ -428,14 +417,13 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kModelWithMetadata);
   options->classifier_options.max_results = 1;
-  options->classifier_options.score_threshold = 0.35f;
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
                           AudioClassifier::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
+  CheckSpeechResult(result, /*expected_num_categories=*/1);
 }

 TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
@@ -450,7 +438,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
+  CheckSpeechResult(result, /*expected_num_categories=*/1);
 }

 TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
@@ -466,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  EXPECT_THAT(result, EqualsProto(GenerateSpeechClassificationResult()));
+  CheckSpeechResult(result, /*expected_num_categories=*/1);
 }

 TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
@@ -482,16 +470,16 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
       auto result, audio_classifier->Classify(std::move(audio_buffer),
                                               /*audio_sample_rate=*/48000));
   MP_ASSERT_OK(audio_classifier->Close());
-  // All categroies with the "Speech" label are filtered out.
-  EXPECT_THAT(result, EqualsProto(R"pb(classifications {
-                head_index: 0
-                head_name: "scores"
-                entries { timestamp_ms: 0 }
-                entries { timestamp_ms: 975 }
-                entries { timestamp_ms: 1950 }
-                entries { timestamp_ms: 2925 }
-                entries { timestamp_ms: 3900 }
-              })pb"));
+  // All categories with the "Speech" label are filtered out.
+  std::vector<int64> timestamps_ms = {0, 975, 1950, 2925};
+  for (int i = 0; i < timestamps_ms.size(); i++) {
+    EXPECT_EQ(result[i].timestamp_ms, timestamps_ms[i]);
+    EXPECT_EQ(result[i].classifications.size(), 1);
+    auto classifications = result[i].classifications[0];
+    EXPECT_EQ(classifications.head_index, 0);
+    EXPECT_EQ(classifications.head_name, "scores");
+    EXPECT_TRUE(classifications.categories.empty());
+  }
 }

 class ClassifyAsyncTest : public tflite_shims::testing::Test {};
@@ -506,9 +494,9 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
   options->sample_rate = kSampleRateHz;
-  std::vector<ClassificationResult> outputs;
+  std::vector<AudioClassifierResult> outputs;
   options->result_callback =
-      [&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
+      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
        MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
       };
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@@ -523,7 +511,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) {
     start_col += kYamnetNumOfAudioSamples * 3;
   }
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckStreamingModeClassificationResult(outputs);
+  CheckStreamingModeResults(outputs);
 }

 TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
@@ -536,9 +524,9 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
   options->classifier_options.score_threshold = 0.3f;
   options->running_mode = core::RunningMode::AUDIO_STREAM;
   options->sample_rate = kSampleRateHz;
-  std::vector<ClassificationResult> outputs;
+  std::vector<AudioClassifierResult> outputs;
   options->result_callback =
-      [&outputs](absl::StatusOr<ClassificationResult> status_or_result) {
+      [&outputs](absl::StatusOr<AudioClassifierResult> status_or_result) {
         MP_ASSERT_OK_AND_ASSIGN(outputs.emplace_back(), status_or_result);
       };
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier,
@@ -555,7 +543,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
     start_col += num_samples;
   }
   MP_ASSERT_OK(audio_classifier->Close());
-  CheckStreamingModeClassificationResult(outputs);
+  CheckStreamingModeResults(outputs);
 }

 }  // namespace
@@ -40,6 +40,7 @@ cc_library(
         "//mediapipe/calculators/image:image_properties_calculator",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator",
        "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:inference_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
@@ -60,41 +61,6 @@ cc_library(

 # TODO: Enable this test

-cc_library(
-    name = "embedder_options",
-    srcs = ["embedder_options.cc"],
-    hdrs = ["embedder_options.h"],
-    deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"],
-)
-
-cc_library(
-    name = "embedding_postprocessing_graph",
-    srcs = ["embedding_postprocessing_graph.cc"],
-    hdrs = ["embedding_postprocessing_graph.h"],
-    deps = [
-        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
-        "//mediapipe/framework:calculator_framework",
-        "//mediapipe/framework/api2:builder",
-        "//mediapipe/framework/api2:port",
-        "//mediapipe/framework/formats:tensor",
-        "//mediapipe/framework/tool:options_map",
-        "//mediapipe/tasks/cc:common",
-        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
-        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
-        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
-        "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
-        "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
-        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
-        "//mediapipe/tasks/cc/core:model_resources",
-        "//mediapipe/tasks/cc/metadata:metadata_extractor",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings:str_format",
-        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
-    ],
-    alwayslink = 1,
-)
-
 # TODO: Investigate rewriting the build rule to only link
 # the Bert Preprocessor if it's needed.
 cc_library(
@@ -163,7 +163,7 @@ mediapipe_proto_library(
     deps = [
         "//mediapipe/framework:calculator_options_proto",
         "//mediapipe/framework:calculator_proto",
-        "//mediapipe/tasks/cc/components/proto:embedder_options_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
     ],
 )
@@ -178,7 +178,7 @@ cc_library(
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
-        "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings:str_format",
     ],
@ -26,14 +26,14 @@
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
|
using ::mediapipe::tasks::components::containers::proto::Embedding;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
|
||||||
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
|
// Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
|
||||||
|
@ -66,7 +66,7 @@ float GetInverseL2Norm(const float* values, int size) {
|
||||||
class TensorsToEmbeddingsCalculator : public Node {
|
class TensorsToEmbeddingsCalculator : public Node {
|
||||||
public:
|
public:
|
||||||
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"};
|
||||||
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"};
|
static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDINGS"};
|
||||||
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
|
MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut);
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override;
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
@ -77,8 +77,8 @@ class TensorsToEmbeddingsCalculator : public Node {
|
||||||
bool quantize_;
|
bool quantize_;
|
||||||
std::vector<std::string> head_names_;
|
std::vector<std::string> head_names_;
|
||||||
|
|
||||||
void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
void FillFloatEmbedding(const Tensor& tensor, Embedding* embedding);
|
||||||
void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry);
|
void FillQuantizedEmbedding(const Tensor& tensor, Embedding* embedding);
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
|
absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) {
|
||||||
|
@ -104,42 +104,42 @@ absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) {
|
||||||
for (int i = 0; i < tensors.size(); ++i) {
|
for (int i = 0; i < tensors.size(); ++i) {
|
||||||
const auto& tensor = tensors[i];
|
const auto& tensor = tensors[i];
|
||||||
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
|
RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32);
|
||||||
auto* embeddings = result.add_embeddings();
|
auto* embedding = result.add_embeddings();
|
||||||
embeddings->set_head_index(i);
|
embedding->set_head_index(i);
|
||||||
if (!head_names_.empty()) {
|
if (!head_names_.empty()) {
|
||||||
embeddings->set_head_name(head_names_[i]);
|
embedding->set_head_name(head_names_[i]);
|
||||||
}
|
}
|
||||||
if (quantize_) {
|
if (quantize_) {
|
||||||
FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries());
|
FillQuantizedEmbedding(tensor, embedding);
|
||||||
} else {
|
} else {
|
||||||
FillFloatEmbeddingEntry(tensor, embeddings->add_entries());
|
FillFloatEmbedding(tensor, embedding);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
kEmbeddingsOut(cc).Send(result);
|
kEmbeddingsOut(cc).Send(result);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry(
|
void TensorsToEmbeddingsCalculator::FillFloatEmbedding(const Tensor& tensor,
|
||||||
const Tensor& tensor, EmbeddingEntry* entry) {
|
Embedding* embedding) {
|
||||||
int size = tensor.shape().num_elements();
|
int size = tensor.shape().num_elements();
|
||||||
auto tensor_view = tensor.GetCpuReadView();
|
auto tensor_view = tensor.GetCpuReadView();
|
||||||
const float* tensor_buffer = tensor_view.buffer<float>();
|
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||||
float inv_l2_norm =
|
float inv_l2_norm =
|
||||||
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||||
auto* float_embedding = entry->mutable_float_embedding();
|
auto* float_embedding = embedding->mutable_float_embedding();
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
|
float_embedding->add_values(tensor_buffer[i] * inv_l2_norm);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry(
|
void TensorsToEmbeddingsCalculator::FillQuantizedEmbedding(
|
||||||
const Tensor& tensor, EmbeddingEntry* entry) {
|
const Tensor& tensor, Embedding* embedding) {
|
||||||
int size = tensor.shape().num_elements();
|
int size = tensor.shape().num_elements();
|
||||||
auto tensor_view = tensor.GetCpuReadView();
|
auto tensor_view = tensor.GetCpuReadView();
|
||||||
const float* tensor_buffer = tensor_view.buffer<float>();
|
const float* tensor_buffer = tensor_view.buffer<float>();
|
||||||
float inv_l2_norm =
|
float inv_l2_norm =
|
||||||
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f;
|
||||||
auto* values = entry->mutable_quantized_embedding()->mutable_values();
|
auto* values = embedding->mutable_quantized_embedding()->mutable_values();
|
||||||
values->resize(size);
|
values->resize(size);
|
||||||
for (int i = 0; i < size; ++i) {
|
for (int i = 0; i < size; ++i) {
|
||||||
// Normalize.
|
// Normalize.
|
||||||
|
|
|
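
The hunk above ends at the `// Normalize.` comment, so the body of the quantization loop itself is not visible here. As a hedged sketch only: the behavior implied by the test expectations further down (0.1 maps to 0x0d = 13, 0.2 to 0x1a = 26) is a scale-by-128 with rounding, clamped to the int8 range. The exact rounding mode of the real calculator is an assumption.

// A minimal sketch, not the calculator's exact loop body: scale each
// (already L2-normalized, if requested) float by 128, round, and clamp.
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

std::string QuantizeToInt8(const std::vector<float>& values) {
  std::string quantized(values.size(), 0);
  for (size_t i = 0; i < values.size(); ++i) {
    // round(0.1 * 128) = 13 (0x0d); round(0.2 * 128) = 26 (0x1a).
    int v = static_cast<int>(std::round(values[i] * 128.0f));
    quantized[i] = static_cast<char>(std::clamp(v, -128, 127));
  }
  return quantized;
}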
@@ -18,7 +18,7 @@ syntax = "proto2";
 package mediapipe;

 import "mediapipe/framework/calculator.proto";
-import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
+import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";

 message TensorsToEmbeddingsCalculatorOptions {
   extend mediapipe.CalculatorOptions {

@@ -27,8 +27,8 @@ message TensorsToEmbeddingsCalculatorOptions {

   // The embedder options defining whether to L2-normalize or scalar-quantize
   // the outputs.
-  optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options =
-      1;
+  optional mediapipe.tasks.components.processors.proto.EmbedderOptions
+      embedder_options = 1;

   // The embedder head names.
   repeated string head_names = 2;

@@ -55,7 +55,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
     input_stream: "TENSORS:tensors"
-    output_stream: "EMBEDDING_RESULT:embeddings"
+    output_stream: "EMBEDDINGS:embeddings"
     options {
       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" }
     }

@@ -73,7 +73,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
     input_stream: "TENSORS:tensors"
-    output_stream: "EMBEDDING_RESULT:embeddings"
+    output_stream: "EMBEDDINGS:embeddings"
     options {
       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
         embedder_options { l2_normalize: false quantize: false }

@@ -84,19 +84,15 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) {
   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
   MP_ASSERT_OK(runner.Run());

-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
-  EXPECT_THAT(
-      result,
-      EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
-          R"pb(embeddings {
-                 entries { float_embedding { values: 0.1 values: 0.2 } }
-                 head_index: 0
-               }
-               embeddings {
-                 entries { float_embedding { values: -0.2 values: -0.3 } }
-                 head_index: 1
-               })pb")));
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                          R"pb(embeddings {
+                                 float_embedding { values: 0.1 values: 0.2 }
+                                 head_index: 0
+                               }
+                               embeddings {
+                                 float_embedding { values: -0.2 values: -0.3 }
+                                 head_index: 1
+                               })pb")));
 }

@@ -105,7 +101,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
     input_stream: "TENSORS:tensors"
-    output_stream: "EMBEDDING_RESULT:embeddings"
+    output_stream: "EMBEDDINGS:embeddings"
     options {
       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
         embedder_options { l2_normalize: false quantize: false }

@@ -118,20 +114,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) {
   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
   MP_ASSERT_OK(runner.Run());

-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
-  EXPECT_THAT(
-      result,
-      EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
-          R"pb(embeddings {
-                 entries { float_embedding { values: 0.1 values: 0.2 } }
-                 head_index: 0
-                 head_name: "foo"
-               }
-               embeddings {
-                 entries { float_embedding { values: -0.2 values: -0.3 } }
-                 head_index: 1
-                 head_name: "bar"
-               })pb")));
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result, EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+                          R"pb(embeddings {
+                                 float_embedding { values: 0.1 values: 0.2 }
+                                 head_index: 0
+                                 head_name: "foo"
+                               }
+                               embeddings {
+                                 float_embedding { values: -0.2 values: -0.3 }
+                                 head_index: 1
+                                 head_name: "bar"
+                               })pb")));

@@ -141,7 +133,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
     input_stream: "TENSORS:tensors"
-    output_stream: "EMBEDDING_RESULT:embeddings"
+    output_stream: "EMBEDDINGS:embeddings"
     options {
       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
         embedder_options { l2_normalize: true quantize: false }

@@ -152,23 +144,17 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) {
   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
   MP_ASSERT_OK(runner.Run());

-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
   EXPECT_THAT(
       result,
       EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
           R"pb(embeddings {
-                 entries {
-                   float_embedding { values: 0.44721356 values: 0.8944271 }
-                 }
+                 float_embedding { values: 0.44721356 values: 0.8944271 }
                  head_index: 0
                }
                embeddings {
-                 entries {
-                   float_embedding { values: -0.5547002 values: -0.8320503 }
-                 }
+                 float_embedding { values: -0.5547002 values: -0.8320503 }
                  head_index: 1
                })pb")));
 }

@@ -177,7 +163,7 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
     input_stream: "TENSORS:tensors"
-    output_stream: "EMBEDDING_RESULT:embeddings"
+    output_stream: "EMBEDDINGS:embeddings"
     options {
       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
         embedder_options { l2_normalize: false quantize: true }

@@ -188,22 +174,16 @@ TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) {
   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
   MP_ASSERT_OK(runner.Run());

-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
   EXPECT_THAT(result,
               EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
                   R"pb(embeddings {
-                         entries {
-                           quantized_embedding { values: "\x0d\x1a" }  # 13,26
-                         }
+                         quantized_embedding { values: "\x0d\x1a" }  # 13,26
                          head_index: 0
                        }
                        embeddings {
-                         entries {
-                           quantized_embedding { values: "\xe6\xda" }  # -26,-38
-                         }
+                         quantized_embedding { values: "\xe6\xda" }  # -26,-38
                          head_index: 1
                        })pb")));
 }

@@ -213,7 +193,7 @@ TEST(TensorsToEmbeddingsCalculatorTest,
   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
     calculator: "TensorsToEmbeddingsCalculator"
     input_stream: "TENSORS:tensors"
-    output_stream: "EMBEDDING_RESULT:embeddings"
+    output_stream: "EMBEDDINGS:embeddings"
     options {
       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] {
         embedder_options { l2_normalize: true quantize: true }

@@ -224,23 +204,16 @@ TEST(TensorsToEmbeddingsCalculatorTest,
   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}});
   MP_ASSERT_OK(runner.Run());

-  const EmbeddingResult& result = runner.Outputs()
-                                      .Get("EMBEDDING_RESULT", 0)
-                                      .packets[0]
-                                      .Get<EmbeddingResult>();
-  EXPECT_THAT(
-      result,
-      EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
+  const EmbeddingResult& result =
+      runner.Outputs().Get("EMBEDDINGS", 0).packets[0].Get<EmbeddingResult>();
+  EXPECT_THAT(result,
+              EqualsProto(ParseTextProtoOrDie<EmbeddingResult>(
                   R"pb(embeddings {
-                         entries {
-                           quantized_embedding { values: "\x39\x72" }  # 57,114
-                         }
+                         quantized_embedding { values: "\x39\x72" }  # 57,114
                          head_index: 0
                        }
                        embeddings {
-                         entries {
-                           quantized_embedding { values: "\xb9\x95" }  # -71,-107
-                         }
+                         quantized_embedding { values: "\xb9\x95" }  # -71,-107
                          head_index: 1
                        })pb")));
 }
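
The expected bytes in the combined l2_normalize + quantize test follow from the two steps above. For the head {0.1, 0.2}: the L2 norm is sqrt(0.01 + 0.04), roughly 0.2236, so normalization gives roughly {0.4472, 0.8944}; scaling by 128 and rounding yields {57, 114}, i.e. the "\x39\x72" bytes. A short self-check (a sketch, not part of the change):

#include <cmath>
#include <cstdio>

int main() {
  const float v[] = {0.1f, 0.2f};
  const float inv_norm = 1.0f / std::sqrt(v[0] * v[0] + v[1] * v[1]);
  for (float x : v) {
    // Prints 57 then 114, matching the "\x39\x72" expectation above.
    std::printf("%d\n", static_cast<int>(std::round(x * inv_norm * 128.0f)));
  }
  return 0;
}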
@@ -49,3 +49,12 @@ cc_library(
         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
     ],
 )
+
+cc_library(
+    name = "embedding_result",
+    srcs = ["embedding_result.cc"],
+    hdrs = ["embedding_result.h"],
+    deps = [
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
+    ],
+)

mediapipe/tasks/cc/components/containers/embedding_result.cc (new file)
@@ -0,0 +1,57 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
+
+#include <iterator>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+
+namespace mediapipe::tasks::components::containers {
+
+Embedding ConvertToEmbedding(const proto::Embedding& proto) {
+  Embedding embedding;
+  if (proto.has_float_embedding()) {
+    embedding.float_embedding = {
+        std::make_move_iterator(proto.float_embedding().values().begin()),
+        std::make_move_iterator(proto.float_embedding().values().end())};
+  } else {
+    embedding.quantized_embedding = {
+        std::make_move_iterator(proto.quantized_embedding().values().begin()),
+        std::make_move_iterator(proto.quantized_embedding().values().end())};
+  }
+  embedding.head_index = proto.head_index();
+  if (proto.has_head_name()) {
+    embedding.head_name = proto.head_name();
+  }
+  return embedding;
+}
+
+EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto) {
+  EmbeddingResult embedding_result;
+  embedding_result.embeddings.reserve(proto.embeddings_size());
+  for (const auto& embedding : proto.embeddings()) {
+    embedding_result.embeddings.push_back(ConvertToEmbedding(embedding));
+  }
+  if (proto.has_timestamp_ms()) {
+    embedding_result.timestamp_ms = proto.timestamp_ms();
+  }
+  return embedding_result;
+}
+
+}  // namespace mediapipe::tasks::components::containers

mediapipe/tasks/cc/components/containers/embedding_result.h (new file)
@@ -0,0 +1,72 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+
+namespace mediapipe::tasks::components::containers {
+
+// Embedding result for a given embedder head.
+//
+// One and only one of the two 'float_embedding' and 'quantized_embedding' will
+// contain data, based on whether or not the embedder was configured to perform
+// scalar quantization.
+struct Embedding {
+  // Floating-point embedding. Empty if the embedder was configured to perform
+  // scalar-quantization.
+  std::vector<float> float_embedding;
+  // Scalar-quantized embedding. Empty if the embedder was not configured to
+  // perform scalar quantization.
+  std::string quantized_embedding;
+  // The index of the embedder head (i.e. output tensor) this embedding comes
+  // from. This is useful for multi-head models.
+  int head_index;
+  // The optional name of the embedder head, as provided in the TFLite Model
+  // Metadata [1] if present. This is useful for multi-head models.
+  //
+  // [1]: https://www.tensorflow.org/lite/convert/metadata
+  std::optional<std::string> head_name = std::nullopt;
+};
+
+// Defines embedding results of a model.
+struct EmbeddingResult {
+  // The embedding results for each head of the model.
+  std::vector<Embedding> embeddings;
+  // The optional timestamp (in milliseconds) of the start of the chunk of data
+  // corresponding to these results.
+  //
+  // This is only used for embedding extraction on time series (e.g. audio
+  // embedding). In these use cases, the amount of data to process might
+  // exceed the maximum size that the model can process: to solve this, the
+  // input data is split into multiple chunks starting at different timestamps.
+  std::optional<int64_t> timestamp_ms = std::nullopt;
+};
+
+// Utility function to convert from Embedding proto to Embedding struct.
+Embedding ConvertToEmbedding(const proto::Embedding& proto);
+
+// Utility function to convert from EmbeddingResult proto to EmbeddingResult
+// struct.
+EmbeddingResult ConvertToEmbeddingResult(const proto::EmbeddingResult& proto);
+
+}  // namespace mediapipe::tasks::components::containers
+
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_EMBEDDING_RESULT_H_
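
A minimal usage sketch for the two converters declared above. The proto field names come from the embeddings.proto change just below; the surrounding setup is assumed for illustration:

#include <cstdio>

#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

namespace containers = mediapipe::tasks::components::containers;

int main() {
  // Build a small proto result by hand (normally produced by the graph).
  containers::proto::EmbeddingResult proto_result;
  auto* embedding = proto_result.add_embeddings();
  embedding->set_head_index(0);
  embedding->mutable_float_embedding()->add_values(0.1f);
  embedding->mutable_float_embedding()->add_values(0.2f);

  // Convert the wire-format proto into the C++ struct for consumption.
  containers::EmbeddingResult result =
      containers::ConvertToEmbeddingResult(proto_result);
  std::printf("%zu values in head %d\n",
              result.embeddings[0].float_embedding.size(),
              result.embeddings[0].head_index);
  return 0;
}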
@@ -30,30 +30,31 @@ message QuantizedEmbedding {
   optional bytes values = 1;
 }

-// Floating-point or scalar-quantized embedding with an optional timestamp.
-message EmbeddingEntry {
-  // The actual embedding, either floating-point or scalar-quantized.
+// Embedding result for a given embedder head.
+message Embedding {
+  // The actual embedding, either floating-point or quantized.
   oneof embedding {
     FloatEmbedding float_embedding = 1;
     QuantizedEmbedding quantized_embedding = 2;
   }
-  // The optional timestamp (in milliseconds) associated to the embedding entry.
-  // This is useful for time series use cases, e.g. audio embedding.
-  optional int64 timestamp_ms = 3;
-}
-
-// Embeddings for a given embedder head.
-message Embeddings {
-  repeated EmbeddingEntry entries = 1;
   // The index of the embedder head that produced this embedding. This is useful
   // for multi-head models.
-  optional int32 head_index = 2;
+  optional int32 head_index = 3;
   // The name of the embedder head, which is the corresponding tensor metadata
   // name (if any). This is useful for multi-head models.
-  optional string head_name = 3;
+  optional string head_name = 4;
 }

-// Contains one set of results per embedder head.
+// Embedding results for a given embedder model.
 message EmbeddingResult {
-  repeated Embeddings embeddings = 1;
+  // The embedding results for each model head, i.e. one for each output tensor.
+  repeated Embedding embeddings = 1;
+  // The optional timestamp (in milliseconds) of the start of the chunk of data
+  // corresponding to these results.
+  //
+  // This is only used for embedding extraction on time series (e.g. audio
+  // embedding). In these use cases, the amount of data to process might
+  // exceed the maximum size that the model can process: to solve this, the
+  // input data is split into multiple chunks starting at different timestamps.
+  optional int64 timestamp_ms = 2;
 }
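
Two consequences of this schema change are worth spelling out. First, the former Embeddings/EmbeddingEntry nesting is flattened into a single Embedding message, so a result that used to serialize as `embeddings { entries { float_embedding { ... } } head_index: 0 }` now reads `embeddings { float_embedding { ... } head_index: 0 }` (exactly the rewrite applied to the calculator tests above). Second, head_index and head_name move to field numbers 3 and 4 because tags 1 and 2 are now taken by the oneof members, and the per-entry timestamp_ms is promoted to EmbeddingResult. A hedged parsing sketch, in the tests' own conventions:

#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;

// Parses a two-value, single-head result against the new schema. Under the
// old schema the same data needed an extra 'entries { ... }' level and could
// not carry a result-wide timestamp.
EmbeddingResult MakeExampleResult() {
  return mediapipe::ParseTextProtoOrDie<EmbeddingResult>(R"pb(
    embeddings {
      float_embedding { values: 0.1 values: 0.2 }
      head_index: 0
    }
    timestamp_ms: 0
  )pb");
}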
@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/calculators/image/image_clone_calculator.pb.h"
 #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
+#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/calculator_framework.h"

@@ -62,3 +62,38 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "embedder_options",
+    srcs = ["embedder_options.cc"],
+    hdrs = ["embedder_options.h"],
+    deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"],
+)
+
+cc_library(
+    name = "embedding_postprocessing_graph",
+    srcs = ["embedding_postprocessing_graph.cc"],
+    hdrs = ["embedding_postprocessing_graph.h"],
+    deps = [
+        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/tool:options_map",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
+        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:str_format",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+    alwayslink = 1,
+)

@@ -13,22 +13,24 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "mediapipe/tasks/cc/components/embedder_options.h"
+#include "mediapipe/tasks/cc/components/processors/embedder_options.h"

-#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"

 namespace mediapipe {
 namespace tasks {
 namespace components {
+namespace processors {

-tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
+proto::EmbedderOptions ConvertEmbedderOptionsToProto(
     EmbedderOptions* embedder_options) {
-  tasks::components::proto::EmbedderOptions options_proto;
+  proto::EmbedderOptions options_proto;
   options_proto.set_l2_normalize(embedder_options->l2_normalize);
   options_proto.set_quantize(embedder_options->quantize);
   return options_proto;
 }

+}  // namespace processors
 }  // namespace components
 }  // namespace tasks
 }  // namespace mediapipe
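
A hedged sketch of calling the relocated helper from task code; the struct fields are declared in the embedder_options.h hunk that follows, and the call site itself is assumed:

#include "mediapipe/tasks/cc/components/processors/embedder_options.h"

namespace processors = mediapipe::tasks::components::processors;

processors::proto::EmbedderOptions MakeOptionsProto() {
  processors::EmbedderOptions options;
  options.l2_normalize = true;  // L2-normalize the output embeddings.
  options.quantize = false;     // Keep floating-point values.
  // Note the non-const pointer parameter, as declared in the header.
  return processors::ConvertEmbedderOptionsToProto(&options);
}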
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
-#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_

-#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"

 namespace mediapipe {
 namespace tasks {
 namespace components {
+namespace processors {

 // Embedder options for MediaPipe C++ embedding extraction tasks.
 struct EmbedderOptions {

@@ -37,11 +38,12 @@ struct EmbedderOptions {
   bool quantize;
 };

-tasks::components::proto::EmbedderOptions ConvertEmbedderOptionsToProto(
+proto::EmbedderOptions ConvertEmbedderOptionsToProto(
     EmbedderOptions* embedder_options);

+}  // namespace processors
 }  // namespace components
 }  // namespace tasks
 }  // namespace mediapipe

-#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDER_OPTIONS_H_
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDER_OPTIONS_H_

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
+#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"

 #include <string>
 #include <vector>

@@ -29,8 +29,8 @@ limitations under the License.
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
-#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
-#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

@@ -39,6 +39,7 @@ limitations under the License.
 namespace mediapipe {
 namespace tasks {
 namespace components {
+namespace processors {

 namespace {

@@ -49,13 +50,12 @@ using ::mediapipe::api2::builder::GenericNode;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
 using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
-using ::mediapipe::tasks::components::proto::EmbedderOptions;
 using ::mediapipe::tasks::core::ModelResources;
 using TensorsSource =
     ::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;

 constexpr char kTensorsTag[] = "TENSORS";
-constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
+constexpr char kEmbeddingsTag[] = "EMBEDDINGS";

 // Identifies whether or not the model has quantized outputs, and performs
 // sanity checks.

@@ -144,7 +144,7 @@ absl::StatusOr<std::vector<std::string>> GetHeadNames(

 absl::Status ConfigureEmbeddingPostprocessing(
     const ModelResources& model_resources,
-    const EmbedderOptions& embedder_options,
+    const proto::EmbedderOptions& embedder_options,
     proto::EmbeddingPostprocessingGraphOptions* options) {
   ASSIGN_OR_RETURN(bool has_quantized_outputs,
                    HasQuantizedOutputs(model_resources));

@@ -188,7 +188,7 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
         BuildEmbeddingPostprocessing(
             sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
             graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
-    embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
+    embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
     return graph.GetConfig();
   }

@@ -220,13 +220,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
         .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
         .CopyFrom(options.tensors_to_embeddings_options());
     dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
-    return tensors_to_embeddings_node[Output<EmbeddingResult>(
-        kEmbeddingResultTag)];
+    return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
   }
 };
 REGISTER_MEDIAPIPE_GRAPH(
-    ::mediapipe::tasks::components::EmbeddingPostprocessingGraph);
+    ::mediapipe::tasks::components::processors::EmbeddingPostprocessingGraph);

+}  // namespace processors
 }  // namespace components
 }  // namespace tasks
 }  // namespace mediapipe

@@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
-#define MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_

 #include "absl/status/status.h"
-#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
-#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"

 namespace mediapipe {
 namespace tasks {
 namespace components {
+namespace processors {

 // Configures an EmbeddingPostprocessingGraph using the provided model resources
 // and EmbedderOptions.

@@ -44,18 +45,19 @@ namespace components {
 //     The output tensors of an InferenceCalculator, to convert into
 //     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
 // Outputs:
-//   EMBEDDING_RESULT - EmbeddingResult
+//   EMBEDDINGS - EmbeddingResult
 //     The output EmbeddingResult.
 //
 // TODO: add support for additional optional "TIMESTAMPS" input for
 // embeddings aggregation.
 absl::Status ConfigureEmbeddingPostprocessing(
     const tasks::core::ModelResources& model_resources,
-    const tasks::components::proto::EmbedderOptions& embedder_options,
+    const proto::EmbedderOptions& embedder_options,
     proto::EmbeddingPostprocessingGraphOptions* options);

+}  // namespace processors
 }  // namespace components
 }  // namespace tasks
 }  // namespace mediapipe

-#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_EMBEDDING_POSTPROCESSING_GRAPH_H_
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_EMBEDDING_POSTPROCESSING_GRAPH_H_
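
Given the documented ports, here is a sketch of wiring this subgraph with the api2 builder under the new output tag. Only the tags and the registered graph name come from this change; the wiring details around them are assumptions:

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"

mediapipe::CalculatorGraphConfig BuildExample() {
  mediapipe::api2::builder::Graph graph;
  // The registered name follows from the REGISTER_MEDIAPIPE_GRAPH call above.
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
  // Model output tensors in, EmbeddingResult out under the renamed tag.
  graph.In("TENSORS") >> postprocessing.In("TENSORS");
  postprocessing.Out("EMBEDDINGS") >> graph.Out("EMBEDDINGS");
  return graph.GetConfig();
}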
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
|
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
@ -25,8 +25,8 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
@ -34,12 +34,10 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::proto::EmbedderOptions;
|
|
||||||
using ::mediapipe::tasks::components::proto::
|
|
||||||
EmbeddingPostprocessingGraphOptions;
|
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
|
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
|
||||||
|
@ -69,16 +67,17 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
EmbedderOptions options_in;
|
proto::EmbedderOptions options_in;
|
||||||
options_in.set_l2_normalize(true);
|
options_in.set_l2_normalize(true);
|
||||||
|
|
||||||
EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||||
&options_out));
|
&options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
EqualsProto(
|
||||||
|
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||||
R"pb(tensors_to_embeddings_options {
|
R"pb(tensors_to_embeddings_options {
|
||||||
embedder_options { l2_normalize: true }
|
embedder_options { l2_normalize: true }
|
||||||
head_names: "probability"
|
head_names: "probability"
|
||||||
|
@ -90,16 +89,17 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||||
EmbedderOptions options_in;
|
proto::EmbedderOptions options_in;
|
||||||
options_in.set_quantize(true);
|
options_in.set_quantize(true);
|
||||||
|
|
||||||
EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||||
&options_out));
|
&options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
EqualsProto(
|
||||||
|
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||||
R"pb(tensors_to_embeddings_options {
|
R"pb(tensors_to_embeddings_options {
|
||||||
embedder_options { quantize: true }
|
embedder_options { quantize: true }
|
||||||
}
|
}
|
||||||
|
@ -109,17 +109,18 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
|
||||||
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
|
MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
|
||||||
CreateModelResourcesForModel(kMobileNetV3Embedder));
|
CreateModelResourcesForModel(kMobileNetV3Embedder));
|
||||||
EmbedderOptions options_in;
|
proto::EmbedderOptions options_in;
|
||||||
options_in.set_quantize(true);
|
options_in.set_quantize(true);
|
||||||
options_in.set_l2_normalize(true);
|
options_in.set_l2_normalize(true);
|
||||||
|
|
||||||
EmbeddingPostprocessingGraphOptions options_out;
|
proto::EmbeddingPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in,
|
||||||
&options_out));
|
&options_out));
|
||||||
|
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out,
|
options_out,
|
||||||
EqualsProto(ParseTextProtoOrDie<EmbeddingPostprocessingGraphOptions>(
|
EqualsProto(
|
||||||
|
ParseTextProtoOrDie<proto::EmbeddingPostprocessingGraphOptions>(
|
||||||
R"pb(tensors_to_embeddings_options {
|
R"pb(tensors_to_embeddings_options {
|
||||||
embedder_options { quantize: true l2_normalize: true }
|
embedder_options { quantize: true l2_normalize: true }
|
||||||
head_names: "feature"
|
head_names: "feature"
|
||||||
|
@ -131,6 +132,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
// supported.
|
// supported.
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -34,3 +34,18 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "embedder_options_proto",
|
||||||
|
srcs = ["embedder_options.proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "embedding_postprocessing_graph_options_proto",
|
||||||
|
srcs = ["embedding_postprocessing_graph_options.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -15,7 +15,10 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.proto;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
|
||||||
|
option java_outer_classname = "EmbedderOptionsProto";
|
||||||
|
|
||||||
// Shared options used by all embedding extraction tasks.
|
// Shared options used by all embedding extraction tasks.
|
||||||
message EmbedderOptions {
|
message EmbedderOptions {
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.proto;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";
|
import "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.proto";
|
|
@ -23,21 +23,6 @@ mediapipe_proto_library(
|
||||||
srcs = ["segmenter_options.proto"],
|
srcs = ["segmenter_options.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "embedder_options_proto",
|
|
||||||
srcs = ["embedder_options.proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "embedding_postprocessing_graph_options_proto",
|
|
||||||
srcs = ["embedding_postprocessing_graph_options.proto"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
|
||||||
"//mediapipe/framework:calculator_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "text_preprocessing_graph_options_proto",
|
name = "text_preprocessing_graph_options_proto",
|
||||||
srcs = ["text_preprocessing_graph_options.proto"],
|
srcs = ["text_preprocessing_graph_options.proto"],
|
||||||
|
|
|
@ -26,7 +26,7 @@ cc_library(
|
||||||
hdrs = ["cosine_similarity.h"],
|
hdrs = ["cosine_similarity.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -39,7 +39,7 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":cosine_similarity",
|
":cosine_similarity",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:embedding_result",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@@ -21,7 +21,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 
 namespace mediapipe {
 namespace tasks {

@@ -30,7 +30,7 @@ namespace utils {
 
 namespace {
 
-using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
+using ::mediapipe::tasks::components::containers::Embedding;
 
 template <typename T>
 absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,

@@ -66,39 +66,35 @@ absl::StatusOr<double> ComputeCosineSimilarity(const T& u, const T& v,
 // an L2-norm of 0.
 //
 // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
-absl::StatusOr<double> CosineSimilarity(const EmbeddingEntry& u,
-                                        const EmbeddingEntry& v) {
-  if (u.has_float_embedding() && v.has_float_embedding()) {
-    if (u.float_embedding().values().size() !=
-        v.float_embedding().values().size()) {
+absl::StatusOr<double> CosineSimilarity(const Embedding& u,
+                                        const Embedding& v) {
+  if (!u.float_embedding.empty() && !v.float_embedding.empty()) {
+    if (u.float_embedding.size() != v.float_embedding.size()) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
          absl::StrFormat("Cannot compute cosine similarity between embeddings "
                           "of different sizes (%d vs. %d)",
-                          u.float_embedding().values().size(),
-                          v.float_embedding().values().size()),
+                          u.float_embedding.size(), v.float_embedding.size()),
          MediaPipeTasksStatus::kInvalidArgumentError);
     }
-    return ComputeCosineSimilarity(u.float_embedding().values().data(),
-                                   v.float_embedding().values().data(),
-                                   u.float_embedding().values().size());
+    return ComputeCosineSimilarity(u.float_embedding.data(),
+                                   v.float_embedding.data(),
+                                   u.float_embedding.size());
   }
-  if (u.has_quantized_embedding() && v.has_quantized_embedding()) {
-    if (u.quantized_embedding().values().size() !=
-        v.quantized_embedding().values().size()) {
+  if (!u.quantized_embedding.empty() && !v.quantized_embedding.empty()) {
+    if (u.quantized_embedding.size() != v.quantized_embedding.size()) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
           absl::StrFormat("Cannot compute cosine similarity between embeddings "
                           "of different sizes (%d vs. %d)",
-                          u.quantized_embedding().values().size(),
-                          v.quantized_embedding().values().size()),
+                          u.quantized_embedding.size(),
+                          v.quantized_embedding.size()),
           MediaPipeTasksStatus::kInvalidArgumentError);
     }
-    return ComputeCosineSimilarity(reinterpret_cast<const int8_t*>(
-                                       u.quantized_embedding().values().data()),
-                                   reinterpret_cast<const int8_t*>(
-                                       v.quantized_embedding().values().data()),
-                                   u.quantized_embedding().values().size());
+    return ComputeCosineSimilarity(
+        reinterpret_cast<const int8_t*>(u.quantized_embedding.data()),
+        reinterpret_cast<const int8_t*>(v.quantized_embedding.data()),
+        u.quantized_embedding.size());
   }
   return CreateStatusWithPayload(
       absl::StatusCode::kInvalidArgument,
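Note: the hunk above changes only the type-dispatching wrapper; the ComputeCosineSimilarity template it calls is untouched by this commit and not shown. For orientation, a minimal sketch of such a helper, assuming a pointer-pair-plus-length interface; the name ComputeCosineSimilaritySketch and the plain absl error (in place of CreateStatusWithPayload) are our assumptions, not MediaPipe code:

#include <cmath>
#include <cstddef>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Computes cosine similarity between two vectors of num_elements entries,
// failing if either vector has an L2-norm of 0.
template <typename T>
absl::StatusOr<double> ComputeCosineSimilaritySketch(const T* u, const T* v,
                                                     size_t num_elements) {
  double dot_product = 0.0;
  double norm_u = 0.0;
  double norm_v = 0.0;
  for (size_t i = 0; i < num_elements; ++i) {
    dot_product += static_cast<double>(u[i]) * static_cast<double>(v[i]);
    norm_u += static_cast<double>(u[i]) * static_cast<double>(u[i]);
    norm_v += static_cast<double>(v[i]) * static_cast<double>(v[i]);
  }
  if (norm_u <= 0.0 || norm_v <= 0.0) {
    return absl::InvalidArgumentError(
        "Cannot compute cosine similarity on embedding with L2-norm 0.");
  }
  return dot_product / std::sqrt(norm_u * norm_v);
}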
@@ -17,22 +17,20 @@ limitations under the License.
 #define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_COSINE_SIMILARITY_H_
 
 #include "absl/status/statusor.h"
-#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 
 namespace mediapipe {
 namespace tasks {
 namespace components {
 namespace utils {
 
-// Utility function to compute cosine similarity [1] between two embedding
-// entries. May return an InvalidArgumentError if e.g. the feature vectors are
-// of different types (quantized vs. float), have different sizes, or have a
-// an L2-norm of 0.
+// Utility function to compute cosine similarity [1] between two embeddings. May
+// return an InvalidArgumentError if e.g. the embeddings are of different types
+// (quantized vs. float), have different sizes, or have an L2-norm of 0.
 //
 // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
-absl::StatusOr<double> CosineSimilarity(
-    const containers::proto::EmbeddingEntry& u,
-    const containers::proto::EmbeddingEntry& v);
+absl::StatusOr<double> CosineSimilarity(const containers::Embedding& u,
+                                        const containers::Embedding& v);
 
 }  // namespace utils
 }  // namespace components
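A caller-side view of the new header, as a minimal sketch; the values are arbitrary, and only the float_embedding field (visible in the tests below) is assumed from embedding_result.h:

#include <iostream>

#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"

int main() {
  using ::mediapipe::tasks::components::containers::Embedding;
  // Build two plain structs; no proto mutation API is involved anymore.
  Embedding u, v;
  u.float_embedding = {1.0f, 0.0f, 0.0f, 0.0f};
  v.float_embedding = {0.5f, 0.5f, 0.5f, 0.5f};
  auto similarity =
      mediapipe::tasks::components::utils::CosineSimilarity(u, v);
  if (similarity.ok()) {
    std::cout << "cosine similarity: " << *similarity << std::endl;  // 0.5
  }
  return 0;
}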
@@ -22,7 +22,7 @@ limitations under the License.
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/status_matchers.h"
-#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 
 namespace mediapipe {
 namespace tasks {

@@ -30,29 +30,27 @@ namespace components {
 namespace utils {
 namespace {
 
-using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
+using ::mediapipe::tasks::components::containers::Embedding;
 using ::testing::HasSubstr;
 
-// Helper function to generate float EmbeddingEntry.
-EmbeddingEntry BuildFloatEntry(std::vector<float> values) {
-  EmbeddingEntry entry;
-  for (const float value : values) {
-    entry.mutable_float_embedding()->add_values(value);
-  }
-  return entry;
+// Helper function to generate float Embedding.
+Embedding BuildFloatEmbedding(std::vector<float> values) {
+  Embedding embedding;
+  embedding.float_embedding = values;
+  return embedding;
 }
 
-// Helper function to generate quantized EmbeddingEntry.
-EmbeddingEntry BuildQuantizedEntry(std::vector<int8_t> values) {
-  EmbeddingEntry entry;
-  entry.mutable_quantized_embedding()->set_values(
-      reinterpret_cast<uint8_t*>(values.data()), values.size());
-  return entry;
+// Helper function to generate quantized Embedding.
+Embedding BuildQuantizedEmbedding(std::vector<int8_t> values) {
+  Embedding embedding;
+  uint8_t* data = reinterpret_cast<uint8_t*>(values.data());
+  embedding.quantized_embedding = {data, data + values.size()};
+  return embedding;
 }
 
 TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
-  auto u = BuildFloatEntry({0.1, 0.2});
-  auto v = BuildQuantizedEntry({0, 1});
+  auto u = BuildFloatEmbedding({0.1, 0.2});
+  auto v = BuildQuantizedEmbedding({0, 1});
 
   auto status = CosineSimilarity(u, v);
 

@@ -63,8 +61,8 @@ TEST(CosineSimilarity, FailsWithQuantizedAndFloatEmbeddings) {
 }
 
 TEST(CosineSimilarity, FailsWithZeroNorm) {
-  auto u = BuildFloatEntry({0.1, 0.2});
-  auto v = BuildFloatEntry({0.0, 0.0});
+  auto u = BuildFloatEmbedding({0.1, 0.2});
+  auto v = BuildFloatEmbedding({0.0, 0.0});
 
   auto status = CosineSimilarity(u, v);
 

@@ -75,8 +73,8 @@ TEST(CosineSimilarity, FailsWithZeroNorm) {
 }
 
 TEST(CosineSimilarity, FailsWithDifferentSizes) {
-  auto u = BuildFloatEntry({0.1, 0.2});
-  auto v = BuildFloatEntry({0.1, 0.2, 0.3});
+  auto u = BuildFloatEmbedding({0.1, 0.2});
+  auto v = BuildFloatEmbedding({0.1, 0.2, 0.3});
 
   auto status = CosineSimilarity(u, v);
 

@@ -87,8 +85,8 @@ TEST(CosineSimilarity, FailsWithDifferentSizes) {
 }
 
 TEST(CosineSimilarity, SucceedsWithFloatEntries) {
-  auto u = BuildFloatEntry({1.0, 0.0, 0.0, 0.0});
-  auto v = BuildFloatEntry({0.5, 0.5, 0.5, 0.5});
+  auto u = BuildFloatEmbedding({1.0, 0.0, 0.0, 0.0});
+  auto v = BuildFloatEmbedding({0.5, 0.5, 0.5, 0.5});
 
   MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
 

@@ -96,8 +94,8 @@ TEST(CosineSimilarity, SucceedsWithFloatEntries) {
 }
 
 TEST(CosineSimilarity, SucceedsWithQuantizedEntries) {
-  auto u = BuildQuantizedEntry({127, 0, 0, 0});
-  auto v = BuildQuantizedEntry({-128, 0, 0, 0});
+  auto u = BuildQuantizedEmbedding({127, 0, 0, 0});
+  auto v = BuildQuantizedEmbedding({-128, 0, 0, 0});
 
   MP_ASSERT_OK_AND_ASSIGN(auto result, CosineSimilarity(u, v));
 
@@ -273,11 +273,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
         hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
             kHandGesturesTag)];
 
-    return {{.gesture = hand_gestures,
-             .handedness = handedness,
-             .hand_landmarks = hand_landmarks,
-             .hand_world_landmarks = hand_world_landmarks,
-             .image = hand_landmarker_graph[Output<Image>(kImageTag)]}};
+    return GestureRecognizerOutputs{
+        /*gesture=*/hand_gestures,
+        /*handedness=*/handedness,
+        /*hand_landmarks=*/hand_landmarks,
+        /*hand_world_landmarks=*/hand_world_landmarks,
+        /*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
   }
 };
 
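This hunk (and the similar ones below) swaps designated initializers for ordered aggregate initialization with /*field=*/ comments, presumably because designated initializers were only standardized in C++20 and are otherwise a compiler extension. A self-contained illustration of the two styles; the Outputs struct is a stand-in, not MediaPipe code:

// Stand-in for structs like GestureRecognizerOutputs.
struct Outputs {
  int gesture;
  int handedness;
};

Outputs MakeOutputsBefore() {
  // Designated initializers: C++20, or an extension on older compilers.
  return {.gesture = 1, .handedness = 2};
}

Outputs MakeOutputsAfter() {
  // Ordered aggregate initialization; comments keep the call sites readable.
  return Outputs{/*gesture=*/1, /*handedness=*/2};
}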
@@ -144,10 +144,10 @@ Rect CalculateBound(const NormalizedLandmarkList& list) {
   }
 
   // Populate normalized non rotated face bounding box
-  return {.left = bounding_box_left,
-          .top = bounding_box_top,
-          .right = bounding_box_right,
-          .bottom = bounding_box_bottom};
+  return Rect{/*left=*/bounding_box_left,
+              /*top=*/bounding_box_top,
+              /*right=*/bounding_box_right,
+              /*bottom=*/bounding_box_bottom};
 }
 
 // Uses IoU and distance of some corresponding hand landmarks to detect
@@ -26,12 +26,12 @@ cc_library(
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
-        "//mediapipe/tasks/cc/components:embedding_postprocessing_graph",
         "//mediapipe/tasks/cc/components:image_preprocessing",
         "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
-        "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph",
+        "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
         "@com_google_absl//absl/status:statusor",

@@ -49,9 +49,10 @@ cc_library(
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/tool:options_map",
-        "//mediapipe/tasks/cc/components:embedder_options",
+        "//mediapipe/tasks/cc/components/containers:embedding_result",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
-        "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors:embedder_options",
+        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
         "//mediapipe/tasks/cc/components/utils:cosine_similarity",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:task_runner",
@@ -21,9 +21,10 @@ limitations under the License.
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/tool/options_map.h"
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
-#include "mediapipe/tasks/cc/components/embedder_options.h"
-#include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
 #include "mediapipe/tasks/cc/components/utils/cosine_similarity.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"

@@ -41,8 +42,8 @@ namespace image_embedder {
 
 namespace {
 
-constexpr char kEmbeddingResultStreamName[] = "embedding_result_out";
-constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
+constexpr char kEmbeddingsStreamName[] = "embeddings_out";
+constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
 constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";

@@ -53,7 +54,7 @@ constexpr char kGraphTypeName[] =
     "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
 
-using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry;
+using ::mediapipe::tasks::components::containers::ConvertToEmbeddingResult;
 using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
 using ::mediapipe::tasks::core::PacketMap;
 using ::mediapipe::tasks::vision::image_embedder::proto::

@@ -71,13 +72,13 @@ CalculatorGraphConfig CreateGraphConfig(
   graph.In(kNormRectTag).SetName(kNormRectStreamName);
   auto& task_graph = graph.AddNode(kGraphTypeName);
   task_graph.GetOptions<ImageEmbedderGraphOptions>().Swap(options_proto.get());
-  task_graph.Out(kEmbeddingResultTag).SetName(kEmbeddingResultStreamName) >>
-      graph.Out(kEmbeddingResultTag);
+  task_graph.Out(kEmbeddingsTag).SetName(kEmbeddingsStreamName) >>
+      graph.Out(kEmbeddingsTag);
   task_graph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {
     return tasks::core::AddFlowLimiterCalculator(
-        graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingResultTag);
+        graph, task_graph, {kImageTag, kNormRectTag}, kEmbeddingsTag);
   }
   graph.In(kImageTag) >> task_graph.In(kImageTag);
   graph.In(kNormRectTag) >> task_graph.In(kNormRectTag);

@@ -95,8 +96,8 @@ std::unique_ptr<ImageEmbedderGraphOptions> ConvertImageEmbedderOptionsToProto(
   options_proto->mutable_base_options()->set_use_stream_mode(
       options->running_mode != core::RunningMode::IMAGE);
   auto embedder_options_proto =
-      std::make_unique<tasks::components::proto::EmbedderOptions>(
-          components::ConvertEmbedderOptionsToProto(
+      std::make_unique<components::processors::proto::EmbedderOptions>(
+          components::processors::ConvertEmbedderOptionsToProto(
              &(options->embedder_options)));
   options_proto->mutable_embedder_options()->Swap(embedder_options_proto.get());
   return options_proto;

@@ -121,9 +122,10 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
           return;
         }
         Packet embedding_result_packet =
-            status_or_packets.value()[kEmbeddingResultStreamName];
+            status_or_packets.value()[kEmbeddingsStreamName];
         Packet image_packet = status_or_packets.value()[kImageOutStreamName];
-        result_callback(embedding_result_packet.Get<EmbeddingResult>(),
+        result_callback(ConvertToEmbeddingResult(
+                            embedding_result_packet.Get<EmbeddingResult>()),
                         image_packet.Get<Image>(),
                         embedding_result_packet.Timestamp().Value() /
                             kMicroSecondsPerMilliSecond);

@@ -138,7 +140,7 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
       std::move(packets_callback));
 }
 
-absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
+absl::StatusOr<ImageEmbedderResult> ImageEmbedder::Embed(
     Image image,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {

@@ -155,10 +157,11 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
       {{kImageInStreamName, MakePacket<Image>(std::move(image))},
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))}}));
-  return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
+  return ConvertToEmbeddingResult(
+      output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
 }
 
-absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
+absl::StatusOr<ImageEmbedderResult> ImageEmbedder::EmbedForVideo(
     Image image, int64 timestamp_ms,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {

@@ -178,7 +181,8 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))
             .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
-  return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
+  return ConvertToEmbeddingResult(
+      output_packets[kEmbeddingsStreamName].Get<EmbeddingResult>());
 }
 
 absl::Status ImageEmbedder::EmbedAsync(

@@ -202,7 +206,8 @@ absl::Status ImageEmbedder::EmbedAsync(
 }
 
 absl::StatusOr<double> ImageEmbedder::CosineSimilarity(
-    const EmbeddingEntry& u, const EmbeddingEntry& v) {
+    const components::containers::Embedding& u,
+    const components::containers::Embedding& v) {
  return components::utils::CosineSimilarity(u, v);
 }
 
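Putting the renamed streams and the new result type together, a sketch of live-stream usage against this file's API; the model path is a placeholder, and int64_t is assumed to match MediaPipe's int64 typedef:

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h"

using ::mediapipe::tasks::vision::image_embedder::ImageEmbedder;
using ::mediapipe::tasks::vision::image_embedder::ImageEmbedderOptions;
using ::mediapipe::tasks::vision::image_embedder::ImageEmbedderResult;

std::unique_ptr<ImageEmbedder> MakeLiveStreamEmbedder() {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      "/path/to/embedder.tflite";  // placeholder
  options->running_mode =
      mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  // The callback now receives the plain ImageEmbedderResult struct rather
  // than the EmbeddingResult proto.
  options->result_callback = [](absl::StatusOr<ImageEmbedderResult> result,
                                const mediapipe::Image& image,
                                int64_t timestamp_ms) {
    if (result.ok() && !result->embeddings.empty()) {
      // Inspect result->embeddings[0].float_embedding here.
    }
  };
  return ImageEmbedder::Create(std::move(options)).value();
}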
@@ -21,8 +21,8 @@ limitations under the License.
 
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/image.h"
-#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
-#include "mediapipe/tasks/cc/components/embedder_options.h"
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
+#include "mediapipe/tasks/cc/components/processors/embedder_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"

@@ -33,6 +33,10 @@ namespace tasks {
 namespace vision {
 namespace image_embedder {
 
+// Alias the shared EmbeddingResult struct as result type.
+using ImageEmbedderResult =
+    ::mediapipe::tasks::components::containers::EmbeddingResult;
+
 // The options for configuring a MediaPipe image embedder task.
 struct ImageEmbedderOptions {
   // Base options for configuring MediaPipe Tasks, such as specifying the model

@@ -50,14 +54,12 @@ struct ImageEmbedderOptions {
 
   // Options for configuring the embedder behavior, such as L2-normalization or
   // scalar-quantization.
-  components::EmbedderOptions embedder_options;
+  components::processors::EmbedderOptions embedder_options;
 
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
-  std::function<void(
-      absl::StatusOr<components::containers::proto::EmbeddingResult>,
-      const Image&, int64)>
+  std::function<void(absl::StatusOr<ImageEmbedderResult>, const Image&, int64)>
       result_callback = nullptr;
 };
 

@@ -104,7 +106,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
   // running mode.
   //
   // The image can be of any size with format RGB or RGBA.
-  absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
+  absl::StatusOr<ImageEmbedderResult> Embed(
       mediapipe::Image image,
       std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

@@ -127,7 +129,7 @@ class ImageEmbedder : core::BaseVisionTaskApi {
   // The image can be of any size with format RGB or RGBA. It's required to
   // provide the video frame's timestamp (in milliseconds). The input timestamps
   // must be monotonically increasing.
-  absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
+  absl::StatusOr<ImageEmbedderResult> EmbedForVideo(
      mediapipe::Image image, int64 timestamp_ms,
       std::optional<core::ImageProcessingOptions> image_processing_options =
           std::nullopt);

@@ -168,15 +170,15 @@ class ImageEmbedder : core::BaseVisionTaskApi {
   // Shuts down the ImageEmbedder when all works are done.
   absl::Status Close() { return runner_->Close(); }
 
-  // Utility function to compute cosine similarity [1] between two embedding
-  // entries. May return an InvalidArgumentError if e.g. the feature vectors are
-  // of different types (quantized vs. float), have different sizes, or have a
-  // an L2-norm of 0.
+  // Utility function to compute cosine similarity [1] between two embeddings.
+  // May return an InvalidArgumentError if e.g. the embeddings are of different
+  // types (quantized vs. float), have different sizes, or have an L2-norm of
+  // 0.
   //
   // [1]: https://en.wikipedia.org/wiki/Cosine_similarity
   static absl::StatusOr<double> CosineSimilarity(
-      const components::containers::proto::EmbeddingEntry& u,
-      const components::containers::proto::EmbeddingEntry& v);
+      const components::containers::Embedding& u,
+      const components::containers::Embedding& v);
 };
 
 }  // namespace image_embedder
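For the synchronous path, a sketch of image-mode usage with the refactored header; paths are placeholders, and error handling is spelled out instead of the test-only MP_* macros:

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/vision/image_embedder/image_embedder.h"

using ::mediapipe::tasks::vision::image_embedder::ImageEmbedder;
using ::mediapipe::tasks::vision::image_embedder::ImageEmbedderOptions;

absl::StatusOr<double> CompareImages(mediapipe::Image a, mediapipe::Image b) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      "/path/to/mobilenet_v3_embedder.tflite";  // placeholder
  auto embedder_or = ImageEmbedder::Create(std::move(options));
  if (!embedder_or.ok()) return embedder_or.status();
  std::unique_ptr<ImageEmbedder> embedder = std::move(*embedder_or);

  auto result_a = embedder->Embed(a);
  if (!result_a.ok()) return result_a.status();
  auto result_b = embedder->Embed(b);
  if (!result_b.ok()) return result_b.status();

  // Embeddings are plain structs now: operator[] instead of proto accessors,
  // and no intermediate entries() level.
  return ImageEmbedder::CosineSimilarity(result_a->embeddings[0],
                                         result_b->embeddings[0]);
}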
@@ -20,10 +20,10 @@ limitations under the License.
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
-#include "mediapipe/tasks/cc/components/embedding_postprocessing_graph.h"
 #include "mediapipe/tasks/cc/components/image_preprocessing.h"
 #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
-#include "mediapipe/tasks/cc/components/proto/embedding_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
+#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_task_graph.h"
 #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"

@@ -40,10 +40,8 @@ using ::mediapipe::api2::builder::GenericNode;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
 using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
-using ::mediapipe::tasks::components::proto::
-    EmbeddingPostprocessingGraphOptions;
 
-constexpr char kEmbeddingResultTag[] = "EMBEDDING_RESULT";
+constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";

@@ -67,7 +65,7 @@ struct ImageEmbedderOutputStreams {
 //     Describes region of image to perform embedding extraction on.
 //     @Optional: rect covering the whole image is used if not specified.
 // Outputs:
-//   EMBEDDING_RESULT - EmbeddingResult
+//   EMBEDDINGS - EmbeddingResult
 //     The embedding result.
 //   IMAGE - Image
 //     The image that embedding extraction runs on.

@@ -76,7 +74,7 @@ struct ImageEmbedderOutputStreams {
 // node {
 //   calculator: "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph"
 //   input_stream: "IMAGE:image_in"
-//   output_stream: "EMBEDDING_RESULT:embedding_result_out"
+//   output_stream: "EMBEDDINGS:embedding_result_out"
 //   output_stream: "IMAGE:image_out"
 //   options {
 //     [mediapipe.tasks.vision.image_embedder.proto.ImageEmbedderOptions.ext]

@@ -107,7 +105,7 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
             graph[Input<Image>(kImageTag)],
             graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
     output_streams.embedding_result >>
-        graph[Output<EmbeddingResult>(kEmbeddingResultTag)];
+        graph[Output<EmbeddingResult>(kEmbeddingsTag)];
     output_streams.image >> graph[Output<Image>(kImageTag)];
     return graph.GetConfig();
   }

@@ -152,16 +150,17 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
     // Adds postprocessing calculators and connects its input stream to the
     // inference results.
     auto& postprocessing = graph.AddNode(
-        "mediapipe.tasks.components.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::ConfigureEmbeddingPostprocessing(
+        "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
+    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
         model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions<EmbeddingPostprocessingGraphOptions>()));
+        &postprocessing.GetOptions<components::processors::proto::
+                                       EmbeddingPostprocessingGraphOptions>()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
 
     // Outputs the embedding results.
     return ImageEmbedderOutputStreams{
         /*embedding_result=*/postprocessing[Output<EmbeddingResult>(
-            kEmbeddingResultTag)],
+            kEmbeddingsTag)],
         /*image=*/preprocessing[Output<Image>(kImageTag)]};
   }
 };
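For graph authors, the practical effect of the tag rename: configs that consumed EMBEDDING_RESULT from this subgraph need to switch to EMBEDDINGS. A sketch using the same api2 builder calls that appear in the file above:

#include "mediapipe/framework/api2/builder.h"

mediapipe::CalculatorGraphConfig BuildEmbedderGraphSketch() {
  mediapipe::api2::builder::Graph graph;
  auto& embedder = graph.AddNode(
      "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph");
  graph.In("IMAGE") >> embedder.In("IMAGE");
  // Previously: embedder.Out("EMBEDDING_RESULT") >> graph.Out("EMBEDDING_RESULT");
  embedder.Out("EMBEDDINGS") >> graph.Out("EMBEDDINGS");
  return graph.GetConfig();
}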
@@ -26,7 +26,7 @@ limitations under the License.
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/status_matchers.h"
-#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 #include "tensorflow/lite/core/api/op_resolver.h"

@@ -42,7 +42,6 @@ namespace {
 
 using ::mediapipe::file::JoinPath;
 using ::mediapipe::tasks::components::containers::Rect;
-using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;

@@ -54,18 +53,14 @@ constexpr double kSimilarityTolerancy = 1e-6;
 
 // Utility function to check the sizes, head_index and head_names of a result
 // produced by kMobileNetV3Embedder.
-void CheckMobileNetV3Result(const EmbeddingResult& result, bool quantized) {
-  EXPECT_EQ(result.embeddings().size(), 1);
-  EXPECT_EQ(result.embeddings(0).head_index(), 0);
-  EXPECT_EQ(result.embeddings(0).head_name(), "feature");
-  EXPECT_EQ(result.embeddings(0).entries().size(), 1);
+void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) {
+  EXPECT_EQ(result.embeddings.size(), 1);
+  EXPECT_EQ(result.embeddings[0].head_index, 0);
+  EXPECT_EQ(result.embeddings[0].head_name, "feature");
   if (quantized) {
-    EXPECT_EQ(
-        result.embeddings(0).entries(0).quantized_embedding().values().size(),
-        1024);
+    EXPECT_EQ(result.embeddings[0].quantized_embedding.size(), 1024);
   } else {
-    EXPECT_EQ(result.embeddings(0).entries(0).float_embedding().values().size(),
-              1024);
+    EXPECT_EQ(result.embeddings[0].float_embedding.size(), 1024);
   }
 }
 
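The helper above condenses the whole proto-to-struct migration: result.embeddings(0).entries(0).float_embedding().values() becomes result.embeddings[0].float_embedding. As a standalone sketch against the public struct (a hypothetical function, not part of the commit):

#include <cstddef>

#include "mediapipe/tasks/cc/components/containers/embedding_result.h"

// Returns the dimension of the first float embedding, or 0 if there is none.
size_t FirstEmbeddingSize(
    const mediapipe::tasks::components::containers::EmbeddingResult& result) {
  if (result.embeddings.empty()) return 0;
  return result.embeddings[0].float_embedding.size();
}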
@@ -154,7 +149,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
     options->base_options.model_asset_path =
         JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
     options->running_mode = running_mode;
-    options->result_callback = [](absl::StatusOr<EmbeddingResult>,
+    options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
                                   const Image& image, int64 timestamp_ms) {};
 
     auto image_embedder = ImageEmbedder::Create(std::move(options));

@@ -231,19 +226,18 @@ TEST_F(ImageModeTest, SucceedsWithoutL2Normalization) {
       JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
 
   // Extract both embeddings.
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
                           image_embedder->Embed(image));
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
                           image_embedder->Embed(crop));
 
   // Check results.
   CheckMobileNetV3Result(image_result, false);
   CheckMobileNetV3Result(crop_result, false);
   // CheckCosineSimilarity.
-  MP_ASSERT_OK_AND_ASSIGN(
-      double similarity,
-      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
-                                      crop_result.embeddings(0).entries(0)));
+  MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
+                                                 image_result.embeddings[0],
+                                                 crop_result.embeddings[0]));
   double expected_similarity = 0.925519;
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }

@@ -264,19 +258,18 @@ TEST_F(ImageModeTest, SucceedsWithL2Normalization) {
       JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
 
   // Extract both embeddings.
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
                           image_embedder->Embed(image));
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
                           image_embedder->Embed(crop));
 
   // Check results.
   CheckMobileNetV3Result(image_result, false);
   CheckMobileNetV3Result(crop_result, false);
   // CheckCosineSimilarity.
-  MP_ASSERT_OK_AND_ASSIGN(
-      double similarity,
-      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
-                                      crop_result.embeddings(0).entries(0)));
+  MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
+                                                 image_result.embeddings[0],
+                                                 crop_result.embeddings[0]));
   double expected_similarity = 0.925519;
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }

@@ -297,19 +290,18 @@ TEST_F(ImageModeTest, SucceedsWithQuantization) {
       JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
 
   // Extract both embeddings.
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
                           image_embedder->Embed(image));
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
                           image_embedder->Embed(crop));
 
   // Check results.
   CheckMobileNetV3Result(image_result, true);
   CheckMobileNetV3Result(crop_result, true);
   // CheckCosineSimilarity.
-  MP_ASSERT_OK_AND_ASSIGN(
-      double similarity,
-      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
-                                      crop_result.embeddings(0).entries(0)));
+  MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
+                                                 image_result.embeddings[0],
+                                                 crop_result.embeddings[0]));
   double expected_similarity = 0.926791;
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }

@@ -333,19 +325,18 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
 
   // Extract both embeddings.
   MP_ASSERT_OK_AND_ASSIGN(
-      const EmbeddingResult& image_result,
+      const ImageEmbedderResult& image_result,
       image_embedder->Embed(image, image_processing_options));
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
                           image_embedder->Embed(crop));
 
   // Check results.
   CheckMobileNetV3Result(image_result, false);
   CheckMobileNetV3Result(crop_result, false);
   // CheckCosineSimilarity.
-  MP_ASSERT_OK_AND_ASSIGN(
-      double similarity,
-      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
-                                      crop_result.embeddings(0).entries(0)));
+  MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
+                                                 image_result.embeddings[0],
+                                                 crop_result.embeddings[0]));
   double expected_similarity = 0.999931;
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }

@@ -367,20 +358,19 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
   image_processing_options.rotation_degrees = -90;
 
   // Extract both embeddings.
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& image_result,
                           image_embedder->Embed(image));
   MP_ASSERT_OK_AND_ASSIGN(
-      const EmbeddingResult& rotated_result,
+      const ImageEmbedderResult& rotated_result,
       image_embedder->Embed(rotated, image_processing_options));
 
   // Check results.
   CheckMobileNetV3Result(image_result, false);
   CheckMobileNetV3Result(rotated_result, false);
   // CheckCosineSimilarity.
-  MP_ASSERT_OK_AND_ASSIGN(
-      double similarity,
-      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
-                                      rotated_result.embeddings(0).entries(0)));
+  MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
+                                                 image_result.embeddings[0],
+                                                 rotated_result.embeddings[0]));
   double expected_similarity = 0.572265;
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }

@@ -403,20 +393,19 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
       /*rotation_degrees=*/-90};
 
   // Extract both embeddings.
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
+  MP_ASSERT_OK_AND_ASSIGN(const ImageEmbedderResult& crop_result,
                           image_embedder->Embed(crop));
   MP_ASSERT_OK_AND_ASSIGN(
-      const EmbeddingResult& rotated_result,
+      const ImageEmbedderResult& rotated_result,
       image_embedder->Embed(rotated, image_processing_options));
 
   // Check results.
   CheckMobileNetV3Result(crop_result, false);
   CheckMobileNetV3Result(rotated_result, false);
   // CheckCosineSimilarity.
-  MP_ASSERT_OK_AND_ASSIGN(
-      double similarity,
-      ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
-                                      rotated_result.embeddings(0).entries(0)));
+  MP_ASSERT_OK_AND_ASSIGN(double similarity, ImageEmbedder::CosineSimilarity(
+                                                 crop_result.embeddings[0],
+                                                 rotated_result.embeddings[0]));
   double expected_similarity = 0.62838;
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }
@@ -487,16 +476,16 @@ TEST_F(VideoModeTest, Succeeds) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                           ImageEmbedder::Create(std::move(options)));
 
-  EmbeddingResult previous_results;
+  ImageEmbedderResult previous_results;
   for (int i = 0; i < iterations; ++i) {
     MP_ASSERT_OK_AND_ASSIGN(auto results,
                             image_embedder->EmbedForVideo(image, i));
     CheckMobileNetV3Result(results, false);
     if (i > 0) {
-      MP_ASSERT_OK_AND_ASSIGN(double similarity,
-                              ImageEmbedder::CosineSimilarity(
-                                  results.embeddings(0).entries(0),
-                                  previous_results.embeddings(0).entries(0)));
+      MP_ASSERT_OK_AND_ASSIGN(
+          double similarity,
+          ImageEmbedder::CosineSimilarity(results.embeddings[0],
+                                          previous_results.embeddings[0]));
       double expected_similarity = 1.000000;
       EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
     }

@@ -515,7 +504,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback = [](absl::StatusOr<EmbeddingResult>,
+  options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
                                 const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                           ImageEmbedder::Create(std::move(options)));

@@ -546,7 +535,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback = [](absl::StatusOr<EmbeddingResult>,
+  options->result_callback = [](absl::StatusOr<ImageEmbedderResult>,
                                 const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                           ImageEmbedder::Create(std::move(options)));

@@ -564,7 +553,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
 }
 
 struct LiveStreamModeResults {
-  EmbeddingResult embedding_result;
+  ImageEmbedderResult embedding_result;
   std::pair<int, int> image_size;
   int64 timestamp_ms;
 };

@@ -580,7 +569,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
       JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
   options->running_mode = core::RunningMode::LIVE_STREAM;
   options->result_callback =
-      [&results](absl::StatusOr<EmbeddingResult> embedding_result,
+      [&results](absl::StatusOr<ImageEmbedderResult> embedding_result,
                  const Image& image, int64 timestamp_ms) {
         MP_ASSERT_OK(embedding_result.status());
         results.push_back(

@@ -612,8 +601,8 @@ TEST_F(LiveStreamModeTest, Succeeds) {
     MP_ASSERT_OK_AND_ASSIGN(
         double similarity,
         ImageEmbedder::CosineSimilarity(
-            result.embedding_result.embeddings(0).entries(0),
-            results[i - 1].embedding_result.embeddings(0).entries(0)));
+            result.embedding_result.embeddings[0],
+            results[i - 1].embedding_result.embeddings[0]));
     double expected_similarity = 1.000000;
     EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
   }
@@ -24,7 +24,7 @@ mediapipe_proto_library(
     deps = [
         "//mediapipe/framework:calculator_options_proto",
         "//mediapipe/framework:calculator_proto",
-        "//mediapipe/tasks/cc/components/proto:embedder_options_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:embedder_options_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_proto",
     ],
 )
@@ -18,7 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.image_embedder.proto;
 
 import "mediapipe/framework/calculator.proto";
-import "mediapipe/tasks/cc/components/proto/embedder_options.proto";
+import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
 message ImageEmbedderGraphOptions {

@@ -31,5 +31,5 @@ message ImageEmbedderGraphOptions {
 
   // Options for configuring the embedder behavior, such as normalization or
   // quantization.
-  optional components.proto.EmbedderOptions embedder_options = 2;
+  optional components.processors.proto.EmbedderOptions embedder_options = 2;
 }
@@ -287,10 +287,9 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
             tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
-    return {{
-        .segmented_masks = segmented_masks,
-        .image = preprocessing[Output<Image>(kImageTag)],
-    }};
+    return ImageSegmenterOutputs{
+        /*segmented_masks=*/segmented_masks,
+        /*image=*/preprocessing[Output<Image>(kImageTag)]};
   }
 };
 
@@ -23,6 +23,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
     "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
     "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
     "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
+    "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite",
     "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
     "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
     "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
@@ -18,6 +18,11 @@ package(default_visibility = ["//mediapipe/tasks:internal"])

 licenses(["notice"])

+py_library(
+    name = "audio_data",
+    srcs = ["audio_data.py"],
+)
+
 py_library(
     name = "bounding_box",
     srcs = ["bounding_box.py"],
@@ -36,6 +41,29 @@ py_library(
     ],
 )

+py_library(
+    name = "landmark",
+    srcs = ["landmark.py"],
+    deps = [
+        "//mediapipe/framework/formats:landmark_py_pb2",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)
+
+py_library(
+    name = "landmark_detection_result",
+    srcs = ["landmark_detection_result.py"],
+    deps = [
+        ":landmark",
+        ":rect",
+        "//mediapipe/framework/formats:classification_py_pb2",
+        "//mediapipe/framework/formats:landmark_py_pb2",
+        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+    ],
+)
+
 py_library(
     name = "category",
     srcs = ["category.py"],
mediapipe/tasks/python/components/containers/audio_data.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe audio data."""
+
+import dataclasses
+from typing import Optional
+
+import numpy as np
+
+
+@dataclasses.dataclass
+class AudioFormat:
+  """Audio format metadata.
+
+  Attributes:
+    num_channels: the number of channels of the audio data.
+    sample_rate: the audio sample rate.
+  """
+  num_channels: int = 1
+  sample_rate: Optional[float] = None
+
+
+class AudioData(object):
+  """MediaPipe Tasks' audio container."""
+
+  def __init__(
+      self, buffer_length: int,
+      audio_format: AudioFormat = AudioFormat()) -> None:
+    """Initializes the `AudioData` object.
+
+    Args:
+      buffer_length: the length of the audio buffer.
+      audio_format: the audio format metadata.
+    """
+    self._audio_format = audio_format
+    self._buffer = np.zeros([buffer_length, self._audio_format.num_channels],
+                            dtype=np.float32)
+
+  def clear(self):
+    """Clears the internal buffer and fills it with zeros."""
+    self._buffer.fill(0)
+
+  def load_from_array(self,
+                      src: np.ndarray,
+                      offset: int = 0,
+                      size: int = -1) -> None:
+    """Loads the audio data from a NumPy array.
+
+    Args:
+      src: A NumPy source array containing the input audio.
+      offset: An optional offset for loading a slice of the `src` array to the
+        buffer.
+      size: An optional size parameter denoting the number of samples to load
+        from the `src` array.
+
+    Raises:
+      ValueError: If the input array has an incorrect shape or if
+        `offset` + `size` exceeds the length of the `src` array.
+    """
+    if src.shape[1] != self._audio_format.num_channels:
+      raise ValueError(f"Input audio contains an invalid number of channels. "
+                       f"Expect {self._audio_format.num_channels}.")
+
+    if size < 0:
+      size = len(src)
+
+    if offset + size > len(src):
+      raise ValueError(
+          f"Index out of range. offset {offset} + size {size} should be <= "
+          f"src's length: {len(src)}")
+
+    if len(src) >= len(self._buffer):
+      # If the internal buffer is shorter than the load target (src), copy
+      # values from the end of the src array to the internal buffer.
+      new_offset = offset + size - len(self._buffer)
+      new_size = len(self._buffer)
+      self._buffer = src[new_offset:new_offset + new_size].copy()
+    else:
+      # Shift the internal buffer backward and add the incoming data to the
+      # end of the buffer.
+      shift = size
+      self._buffer = np.roll(self._buffer, -shift, axis=0)
+      self._buffer[-shift:, :] = src[offset:offset + size].copy()
+
+  @property
+  def audio_format(self) -> AudioFormat:
+    """Gets the audio format of the audio."""
+    return self._audio_format
+
+  @property
+  def buffer_length(self) -> int:
+    """Gets the sample count of the audio."""
+    return self._buffer.shape[0]
+
+  @property
+  def buffer(self) -> np.ndarray:
+    """Gets the internal buffer."""
+    return self._buffer
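A minimal usage sketch (not part of the diff; values are illustrative) for the new `AudioData` container, showing the ring-buffer behavior of `load_from_array`:

    import numpy as np

    audio = AudioData(buffer_length=4)  # mono by default (num_channels=1)
    audio.load_from_array(np.ones([4, 1], dtype=np.float32))
    audio.load_from_array(np.zeros([2, 1], dtype=np.float32))
    # The buffer was shifted back by two samples and the new zeros appended:
    print(audio.buffer[:, 0])  # [1. 1. 0. 0.]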
@@ -14,7 +14,7 @@
 """Category data class."""

 import dataclasses
-from typing import Any
+from typing import Any, Optional

 from mediapipe.tasks.cc.components.containers.proto import category_pb2
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@@ -39,10 +39,10 @@ class Category:
     category_name: The label of this category object.
   """

-  index: int
-  score: float
-  display_name: str
-  category_name: str
+  index: Optional[int] = None
+  score: Optional[float] = None
+  display_name: Optional[str] = None
+  category_name: Optional[str] = None

   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _CategoryProto:
mediapipe/tasks/python/components/containers/landmark.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Landmark data class."""
+
+import dataclasses
+from typing import Optional
+
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+_LandmarkProto = landmark_pb2.Landmark
+_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
+
+
+@dataclasses.dataclass
+class Landmark:
+  """A landmark that can have 1 to 3 dimensions.
+
+  Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.
+
+  Attributes:
+    x: The x coordinate.
+    y: The y coordinate.
+    z: The z coordinate.
+    visibility: Landmark visibility. Should stay unset if not supported. Float
+      score of whether the landmark is visible or occluded by other objects.
+      A landmark is also considered invisible if it is not present on the
+      screen (out of scene bounds). Depending on the model, the visibility
+      value is either a sigmoid or an argument of a sigmoid.
+    presence: Landmark presence. Should stay unset if not supported. Float
+      score of whether the landmark is present on the scene (located within
+      scene bounds). Depending on the model, the presence value is either a
+      result of a sigmoid or an argument of a sigmoid function to get the
+      landmark presence probability.
+  """
+
+  x: Optional[float] = None
+  y: Optional[float] = None
+  z: Optional[float] = None
+  visibility: Optional[float] = None
+  presence: Optional[float] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _LandmarkProto:
+    """Generates a Landmark protobuf object."""
+    return _LandmarkProto(
+        x=self.x,
+        y=self.y,
+        z=self.z,
+        visibility=self.visibility,
+        presence=self.presence)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _LandmarkProto) -> 'Landmark':
+    """Creates a `Landmark` object from the given protobuf object."""
+    return Landmark(
+        x=pb2_obj.x,
+        y=pb2_obj.y,
+        z=pb2_obj.z,
+        visibility=pb2_obj.visibility,
+        presence=pb2_obj.presence)
+
+
+@dataclasses.dataclass
+class NormalizedLandmark:
+  """A normalized version of the above Landmark proto.
+
+  All coordinates should be within [0, 1].
+
+  Attributes:
+    x: The normalized x coordinate.
+    y: The normalized y coordinate.
+    z: The normalized z coordinate.
+    visibility: Landmark visibility. Should stay unset if not supported. Float
+      score of whether the landmark is visible or occluded by other objects.
+      A landmark is also considered invisible if it is not present on the
+      screen (out of scene bounds). Depending on the model, the visibility
+      value is either a sigmoid or an argument of a sigmoid.
+    presence: Landmark presence. Should stay unset if not supported. Float
+      score of whether the landmark is present on the scene (located within
+      scene bounds). Depending on the model, the presence value is either a
+      result of a sigmoid or an argument of a sigmoid function to get the
+      landmark presence probability.
+  """
+
+  x: Optional[float] = None
+  y: Optional[float] = None
+  z: Optional[float] = None
+  visibility: Optional[float] = None
+  presence: Optional[float] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _NormalizedLandmarkProto:
+    """Generates a NormalizedLandmark protobuf object."""
+    return _NormalizedLandmarkProto(
+        x=self.x,
+        y=self.y,
+        z=self.z,
+        visibility=self.visibility,
+        presence=self.presence)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(
+      cls, pb2_obj: _NormalizedLandmarkProto) -> 'NormalizedLandmark':
+    """Creates a `NormalizedLandmark` object from the given protobuf object."""
+    return NormalizedLandmark(
+        x=pb2_obj.x,
+        y=pb2_obj.y,
+        z=pb2_obj.z,
+        visibility=pb2_obj.visibility,
+        presence=pb2_obj.presence)
mediapipe/tasks/python/components/containers/landmark_detection_result.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Landmarks Detection Result data class."""
+
+import dataclasses
+from typing import Optional, List
+
+from mediapipe.framework.formats import classification_pb2
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
+from mediapipe.tasks.python.components.containers import category as category_module
+from mediapipe.tasks.python.components.containers import landmark as landmark_module
+from mediapipe.tasks.python.components.containers import rect as rect_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+
+_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
+_ClassificationProto = classification_pb2.Classification
+_ClassificationListProto = classification_pb2.ClassificationList
+_LandmarkListProto = landmark_pb2.LandmarkList
+_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
+_NormalizedRect = rect_module.NormalizedRect
+_Category = category_module.Category
+_NormalizedLandmark = landmark_module.NormalizedLandmark
+_Landmark = landmark_module.Landmark
+
+
+@dataclasses.dataclass
+class LandmarksDetectionResult:
+  """Represents the landmarks detection result.
+
+  Attributes:
+    landmarks: A list of `NormalizedLandmark` objects.
+    categories: A list of `Category` objects.
+    world_landmarks: A list of `Landmark` objects.
+    rect: A `NormalizedRect` object.
+  """
+
+  landmarks: Optional[List[_NormalizedLandmark]]
+  categories: Optional[List[_Category]]
+  world_landmarks: Optional[List[_Landmark]]
+  rect: _NormalizedRect
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _LandmarksDetectionResultProto:
+    """Generates a LandmarksDetectionResult protobuf object."""
+
+    classifications = _ClassificationListProto()
+    for category in self.categories:
+      classifications.classification.append(
+          _ClassificationProto(
+              index=category.index,
+              score=category.score,
+              label=category.category_name,
+              display_name=category.display_name))
+
+    return _LandmarksDetectionResultProto(
+        landmarks=_NormalizedLandmarkListProto(self.landmarks),
+        classifications=classifications,
+        world_landmarks=_LandmarkListProto(self.world_landmarks),
+        rect=self.rect.to_pb2())
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(
+      cls,
+      pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult':
+    """Creates a `LandmarksDetectionResult` object from the given protobuf object."""
+    categories = []
+    for classification in pb2_obj.classifications.classification:
+      categories.append(
+          category_module.Category(
+              score=classification.score,
+              index=classification.index,
+              category_name=classification.label,
+              display_name=classification.display_name))
+    return LandmarksDetectionResult(
+        landmarks=[
+            _NormalizedLandmark.create_from_pb2(landmark)
+            for landmark in pb2_obj.landmarks.landmark
+        ],
+        categories=categories,
+        world_landmarks=[
+            _Landmark.create_from_pb2(landmark)
+            for landmark in pb2_obj.world_landmarks.landmark
+        ],
+        rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
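A small round-trip sketch (not part of the diff; field values are made up) showing the proto-to-dataclass conversion, using the same proto module imported at the top of this file:

    from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2

    proto = landmarks_detection_result_pb2.LandmarksDetectionResult()
    lm = proto.landmarks.landmark.add()
    lm.x, lm.y, lm.z = 0.1, 0.2, 0.3
    # Each proto landmark/classification becomes a Python dataclass instance.
    result = LandmarksDetectionResult.create_from_pb2(proto)
    print(result.landmarks[0].x)  # ~0.1 (float32)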
@@ -19,75 +19,44 @@ from typing import Any, Optional
 from mediapipe.framework.formats import rect_pb2
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls

-_RectProto = rect_pb2.Rect
 _NormalizedRectProto = rect_pb2.NormalizedRect


 @dataclasses.dataclass
 class Rect:
-  """A rectangle with rotation in image coordinates.
-
-  Attributes: x_center : The X coordinate of the top-left corner, in pixels.
-    y_center : The Y coordinate of the top-left corner, in pixels.
-    width: The width of the rectangle, in pixels.
-    height: The height of the rectangle, in pixels.
-    rotation: Rotation angle is clockwise in radians.
-    rect_id: Optional unique id to help associate different rectangles to each
-      other.
-  """
-
-  x_center: int
-  y_center: int
-  width: int
-  height: int
-  rotation: Optional[float] = 0.0
-  rect_id: Optional[int] = None
-
-  @doc_controls.do_not_generate_docs
-  def to_pb2(self) -> _RectProto:
-    """Generates a Rect protobuf object."""
-    return _RectProto(
-        x_center=self.x_center,
-        y_center=self.y_center,
-        width=self.width,
-        height=self.height,
-    )
-
-  @classmethod
-  @doc_controls.do_not_generate_docs
-  def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
-    """Creates a `Rect` object from the given protobuf object."""
-    return Rect(
-        x_center=pb2_obj.x_center,
-        y_center=pb2_obj.y_center,
-        width=pb2_obj.width,
-        height=pb2_obj.height)
-
-  def __eq__(self, other: Any) -> bool:
-    """Checks if this object is equal to the given object.
-
-    Args:
-      other: The object to be compared with.
-
-    Returns:
-      True if the objects are equal.
-    """
-    if not isinstance(other, Rect):
-      return False
-
-    return self.to_pb2().__eq__(other.to_pb2())
+  """A rectangle, used as part of detection results or as input region-of-interest.
+
+  The coordinates are normalized wrt the image dimensions, i.e. generally in
+  [0,1] but they may exceed these bounds if describing a region overlapping the
+  image. The origin is on the top-left corner of the image.
+
+  Attributes:
+    left: The X coordinate of the left side of the rectangle.
+    top: The Y coordinate of the top of the rectangle.
+    right: The X coordinate of the right side of the rectangle.
+    bottom: The Y coordinate of the bottom of the rectangle.
+  """
+
+  left: float
+  top: float
+  right: float
+  bottom: float


 @dataclasses.dataclass
 class NormalizedRect:
   """A rectangle with rotation in normalized coordinates.

-  The values of box
-  center location and size are within [0, 1].
+  Location of the center of the rectangle in image coordinates. The (0.0, 0.0)
+  point is at the (top, left) corner.
+
+  The values of box center location and size are within [0, 1].

-  Attributes: x_center : The X normalized coordinate of the top-left corner.
-    y_center : The Y normalized coordinate of the top-left corner.
+  Attributes:
+    x_center: The normalized X coordinate of the rectangle, in image
+      coordinates.
+    y_center: The normalized Y coordinate of the rectangle, in image
+      coordinates.
     width: The width of the rectangle.
     height: The height of the rectangle.
     rotation: Rotation angle is clockwise in radians.
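For reference (not part of the diff): a corner-form `Rect` maps to the center-form `NormalizedRect` by taking midpoints and extents. Checking this against the region-of-interest values the image classifier tests below switch to:

    left, top, right, bottom = 0.45, 0.3075, 0.614, 0.7345
    x_center = (left + right) / 2.0  # 0.532
    y_center = (top + bottom) / 2.0  # 0.521
    width = right - left             # 0.164 (up to float rounding)
    height = bottom - top            # 0.427 (up to float rounding)

which reproduces the NormalizedRect(x_center=0.532, y_center=0.521, width=0.164, height=0.427) those tests used before this change.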
mediapipe/tasks/python/components/processors/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -24,7 +24,7 @@ py_library(
     srcs = ["test_utils.py"],
     srcs_version = "PY3",
     visibility = [
-        "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
+        "//mediapipe/model_maker/python:__subpackages__",
         "//mediapipe/tasks:internal",
     ],
     deps = [
@@ -53,6 +53,7 @@ py_test(
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/test:test_utils",
         "//mediapipe/tasks/python/vision:image_classifier",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )
@@ -30,9 +30,10 @@ from mediapipe.tasks.python.components.processors import classifier_options
 from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.test import test_utils
 from mediapipe.tasks.python.vision import image_classifier
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 from mediapipe.tasks.python.vision.core import vision_task_running_mode

-_NormalizedRect = rect.NormalizedRect
+_Rect = rect.Rect
 _BaseOptions = base_options_module.BaseOptions
 _ClassifierOptions = classifier_options.ClassifierOptions
 _Category = category.Category
@@ -43,6 +44,7 @@ _Image = image.Image
 _ImageClassifier = image_classifier.ImageClassifier
 _ImageClassifierOptions = image_classifier.ImageClassifierOptions
 _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

 _MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
 _IMAGE_FILE = 'burger.jpg'
@@ -227,11 +229,11 @@ class ImageClassifierTest(parameterized.TestCase):
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path(
             os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
-    # NormalizedRect around the soccer ball.
-    roi = _NormalizedRect(
-        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
+    # Region-of-interest around the soccer ball.
+    roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
+    image_processing_options = _ImageProcessingOptions(roi)
     # Performs image classification on the input.
-    image_result = classifier.classify(test_image, roi)
+    image_result = classifier.classify(test_image, image_processing_options)
     # Comparing results.
     test_utils.assert_proto_equals(self, image_result.to_pb2(),
                                    _generate_soccer_ball_results(0).to_pb2())
@@ -417,12 +419,12 @@ class ImageClassifierTest(parameterized.TestCase):
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path(
             os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
-    # NormalizedRect around the soccer ball.
-    roi = _NormalizedRect(
-        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
+    # Region-of-interest around the soccer ball.
+    roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
+    image_processing_options = _ImageProcessingOptions(roi)
     for timestamp in range(0, 300, 30):
       classification_result = classifier.classify_for_video(
-          test_image, timestamp, roi)
+          test_image, timestamp, image_processing_options)
       test_utils.assert_proto_equals(
           self, classification_result.to_pb2(),
           _generate_soccer_ball_results(timestamp).to_pb2())
@@ -491,9 +493,9 @@ class ImageClassifierTest(parameterized.TestCase):
     test_image = _Image.create_from_file(
         test_utils.get_test_data_path(
             os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
-    # NormalizedRect around the soccer ball.
-    roi = _NormalizedRect(
-        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
+    # Region-of-interest around the soccer ball.
+    roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
+    image_processing_options = _ImageProcessingOptions(roi)
     observed_timestamp_ms = -1

     def check_result(result: _ClassificationResult, output_image: _Image,
@@ -514,7 +516,8 @@ class ImageClassifierTest(parameterized.TestCase):
         result_callback=check_result)
     with _ImageClassifier.create_from_options(options) as classifier:
       for timestamp in range(0, 300, 30):
-        classifier.classify_async(test_image, timestamp, roi)
+        classifier.classify_async(test_image, timestamp,
+                                  image_processing_options)


 if __name__ == '__main__':
@@ -33,8 +33,8 @@ from mediapipe.tasks.python.vision.core import vision_task_running_mode
 _BaseOptions = base_options_module.BaseOptions
 _Image = image_module.Image
 _ImageFormat = image_frame.ImageFormat
-_OutputType = image_segmenter.OutputType
-_Activation = image_segmenter.Activation
+_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
+_Activation = image_segmenter.ImageSegmenterOptions.Activation
 _ImageSegmenter = image_segmenter.ImageSegmenter
 _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
 _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
@@ -55,6 +55,7 @@ py_library(
         "//mediapipe/tasks/python/core:optional_dependencies",
         "//mediapipe/tasks/python/core:task_info",
         "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )
@@ -77,3 +78,27 @@ py_library(
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )
+
+py_library(
+    name = "gesture_recognizer",
+    srcs = [
+        "gesture_recognizer.py",
+    ],
+    deps = [
+        "//mediapipe/framework/formats:classification_py_pb2",
+        "//mediapipe/framework/formats:landmark_py_pb2",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/components/containers:landmark",
+        "//mediapipe/tasks/python/components/processors:classifier_options",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+    ],
+)
@@ -15,17 +15,25 @@
 """MediaPipe Tasks Vision API."""

 import mediapipe.tasks.python.vision.core
+import mediapipe.tasks.python.vision.gesture_recognizer
 import mediapipe.tasks.python.vision.image_classifier
+import mediapipe.tasks.python.vision.image_segmenter
 import mediapipe.tasks.python.vision.object_detector

+GestureRecognizer = gesture_recognizer.GestureRecognizer
+GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
 ImageClassifier = image_classifier.ImageClassifier
 ImageClassifierOptions = image_classifier.ImageClassifierOptions
+ImageSegmenter = image_segmenter.ImageSegmenter
+ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
 ObjectDetector = object_detector.ObjectDetector
 ObjectDetectorOptions = object_detector.ObjectDetectorOptions
 RunningMode = core.vision_task_running_mode.VisionTaskRunningMode

 # Remove unnecessary modules to avoid duplication in API docs.
 del core
+del gesture_recognizer
 del image_classifier
+del image_segmenter
 del object_detector
 del mediapipe
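With these exports the new tasks are reachable from the package root; a short sketch (not part of the diff; the model path is a placeholder and requires a real model file at runtime):

    from mediapipe.tasks.python import vision

    recognizer = vision.GestureRecognizer.create_from_model_path(
        'gesture_recognizer.task')  # placeholder path
    print(vision.RunningMode.IMAGE)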
@@ -23,15 +23,25 @@ py_library(
     srcs = ["vision_task_running_mode.py"],
 )

+py_library(
+    name = "image_processing_options",
+    srcs = ["image_processing_options.py"],
+    deps = [
+        "//mediapipe/tasks/python/components/containers:rect",
+    ],
+)
+
 py_library(
     name = "base_vision_task_api",
     srcs = [
         "base_vision_task_api.py",
     ],
     deps = [
+        ":image_processing_options",
         ":vision_task_running_mode",
         "//mediapipe/framework:calculator_py_pb2",
         "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/components/containers:rect",
         "//mediapipe/tasks/python/core:optional_dependencies",
     ],
 )
@@ -13,17 +13,22 @@
 # limitations under the License.
 """MediaPipe vision task base api."""

+import math
 from typing import Callable, Mapping, Optional

 from mediapipe.framework import calculator_pb2
 from mediapipe.python._framework_bindings import packet as packet_module
 from mediapipe.python._framework_bindings import task_runner as task_runner_module
+from mediapipe.tasks.python.components.containers import rect as rect_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

 _TaskRunner = task_runner_module.TaskRunner
 _Packet = packet_module.Packet
+_NormalizedRect = rect_module.NormalizedRect
 _RunningMode = running_mode_module.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions


 class BaseVisionTaskApi(object):
@@ -122,6 +127,49 @@ class BaseVisionTaskApi(object):
                        + self._running_mode.name)
     self._runner.send(inputs)

+  def convert_to_normalized_rect(self,
+                                 options: _ImageProcessingOptions,
+                                 roi_allowed: bool = True) -> _NormalizedRect:
+    """Converts from ImageProcessingOptions to NormalizedRect, performing sanity checks on-the-fly.
+
+    If the input ImageProcessingOptions is not present, returns a default
+    NormalizedRect covering the whole image with rotation set to 0. If
+    'roi_allowed' is false, an error will be returned if the input
+    ImageProcessingOptions has its 'region_of_interest' field set.
+
+    Args:
+      options: Options for image processing.
+      roi_allowed: Indicates if the `region_of_interest` field is allowed to be
+        set. By default, it's set to True.
+
+    Returns:
+      A normalized rect proto that represents the image processing options.
+    """
+    normalized_rect = _NormalizedRect(
+        rotation=0, x_center=0.5, y_center=0.5, width=1, height=1)
+    if options is None:
+      return normalized_rect
+
+    if options.rotation_degrees % 90 != 0:
+      raise ValueError('Expected rotation to be a multiple of 90°.')
+
+    # Convert to radians counter-clockwise.
+    normalized_rect.rotation = -options.rotation_degrees * math.pi / 180.0
+
+    if options.region_of_interest:
+      if not roi_allowed:
+        raise ValueError("This task doesn't support region-of-interest.")
+      roi = options.region_of_interest
+      if roi.left >= roi.right or roi.top >= roi.bottom:
+        raise ValueError('Expected Rect with left < right and top < bottom.')
+      if roi.left < 0 or roi.top < 0 or roi.right > 1 or roi.bottom > 1:
+        raise ValueError('Expected Rect values to be in [0,1].')
+      normalized_rect.x_center = (roi.left + roi.right) / 2.0
+      normalized_rect.y_center = (roi.top + roi.bottom) / 2.0
+      normalized_rect.width = roi.right - roi.left
+      normalized_rect.height = roi.bottom - roi.top
+    return normalized_rect
+
   def close(self) -> None:
     """Shuts down the mediapipe vision task instance.
mediapipe/tasks/python/vision/core/image_processing_options.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe vision options for image processing."""
+
+import dataclasses
+from typing import Optional
+
+from mediapipe.tasks.python.components.containers import rect as rect_module
+
+
+@dataclasses.dataclass
+class ImageProcessingOptions:
+  """Options for image processing.
+
+  If both region-of-interest and rotation are specified, the crop around the
+  region-of-interest is extracted first, then the specified rotation is
+  applied to the crop.
+
+  Attributes:
+    region_of_interest: The optional region-of-interest to crop from the
+      image. If not specified, the full image is used. Coordinates must be in
+      [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+    rotation_degrees: The rotation to apply to the image (or cropped
+      region-of-interest), in degrees clockwise. The rotation must be a
+      multiple (positive or negative) of 90°.
+  """
+  region_of_interest: Optional[rect_module.Rect] = None
+  rotation_degrees: int = 0
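A usage sketch (not part of the diff; values are illustrative) tying `ImageProcessingOptions` to the `convert_to_normalized_rect` helper added above:

    from mediapipe.tasks.python.components.containers import rect

    roi = rect.Rect(left=0.25, top=0.25, right=0.75, bottom=0.75)
    options = ImageProcessingOptions(region_of_interest=roi,
                                     rotation_degrees=90)
    # Inside a vision task this becomes NormalizedRect(x_center=0.5,
    # y_center=0.5, width=0.5, height=0.5) with rotation -pi/2 radians
    # (90 degrees clockwise expressed as counter-clockwise radians).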
mediapipe/tasks/python/vision/gesture_recognizer.py (new file, 426 lines)
@@ -0,0 +1,426 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MediaPipe gesture recognizer task."""
+
+import dataclasses
+from typing import Callable, Mapping, Optional, List
+
+from mediapipe.framework.formats import classification_pb2
+from mediapipe.framework.formats import landmark_pb2
+from mediapipe.python import packet_creator
+from mediapipe.python import packet_getter
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.python._framework_bindings import packet as packet_module
+from mediapipe.tasks.cc.vision.gesture_recognizer.proto import gesture_recognizer_graph_options_pb2
+from mediapipe.tasks.python.components.containers import category as category_module
+from mediapipe.tasks.python.components.containers import landmark as landmark_module
+from mediapipe.tasks.python.components.processors import classifier_options
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.core import task_info as task_info_module
+from mediapipe.tasks.python.core.optional_dependencies import doc_controls
+from mediapipe.tasks.python.vision.core import base_vision_task_api
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
+from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
+
+_BaseOptions = base_options_module.BaseOptions
+_GestureRecognizerGraphOptionsProto = gesture_recognizer_graph_options_pb2.GestureRecognizerGraphOptions
+_ClassifierOptions = classifier_options.ClassifierOptions
+_RunningMode = running_mode_module.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
+_TaskInfo = task_info_module.TaskInfo
+
+_IMAGE_IN_STREAM_NAME = 'image_in'
+_IMAGE_OUT_STREAM_NAME = 'image_out'
+_IMAGE_TAG = 'IMAGE'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
+_NORM_RECT_TAG = 'NORM_RECT'
+_HAND_GESTURE_STREAM_NAME = 'hand_gestures'
+_HAND_GESTURE_TAG = 'HAND_GESTURES'
+_HANDEDNESS_STREAM_NAME = 'handedness'
+_HANDEDNESS_TAG = 'HANDEDNESS'
+_HAND_LANDMARKS_STREAM_NAME = 'landmarks'
+_HAND_LANDMARKS_TAG = 'LANDMARKS'
+_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
+_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
+_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'
+_MICRO_SECONDS_PER_MILLISECOND = 1000
+_GESTURE_DEFAULT_INDEX = -1
+
+
+@dataclasses.dataclass
+class GestureRecognitionResult:
+  """The gesture recognition result from GestureRecognizer, where each vector element represents a single hand detected in the image.
+
+  Attributes:
+    gestures: Recognized hand gestures of detected hands. Note that the index
+      of the gesture is always -1, because the raw indices from multiple
+      gesture classifiers cannot be consolidated into a meaningful index.
+    handedness: Classification of handedness.
+    hand_landmarks: Detected hand landmarks in normalized image coordinates.
+    hand_world_landmarks: Detected hand landmarks in world coordinates.
+  """
+
+  gestures: List[List[category_module.Category]]
+  handedness: List[List[category_module.Category]]
+  hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
+  hand_world_landmarks: List[List[landmark_module.Landmark]]
+
+
+def _build_recognition_result(
+    output_packets: Mapping[str,
+                            packet_module.Packet]) -> GestureRecognitionResult:
+  """Constructs a `GestureRecognitionResult` from output packets."""
+  gestures_proto_list = packet_getter.get_proto_list(
+      output_packets[_HAND_GESTURE_STREAM_NAME])
+  handedness_proto_list = packet_getter.get_proto_list(
+      output_packets[_HANDEDNESS_STREAM_NAME])
+  hand_landmarks_proto_list = packet_getter.get_proto_list(
+      output_packets[_HAND_LANDMARKS_STREAM_NAME])
+  hand_world_landmarks_proto_list = packet_getter.get_proto_list(
+      output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
+
+  gesture_results = []
+  for proto in gestures_proto_list:
+    gesture_categories = []
+    gesture_classifications = classification_pb2.ClassificationList()
+    gesture_classifications.MergeFrom(proto)
+    for gesture in gesture_classifications.classification:
+      gesture_categories.append(
+          category_module.Category(
+              index=_GESTURE_DEFAULT_INDEX,
+              score=gesture.score,
+              display_name=gesture.display_name,
+              category_name=gesture.label))
+    gesture_results.append(gesture_categories)
+
+  handedness_results = []
+  for proto in handedness_proto_list:
+    handedness_categories = []
+    handedness_classifications = classification_pb2.ClassificationList()
+    handedness_classifications.MergeFrom(proto)
+    for handedness in handedness_classifications.classification:
+      handedness_categories.append(
+          category_module.Category(
+              index=handedness.index,
+              score=handedness.score,
+              display_name=handedness.display_name,
+              category_name=handedness.label))
+    handedness_results.append(handedness_categories)
+
+  hand_landmarks_results = []
+  for proto in hand_landmarks_proto_list:
+    hand_landmarks = landmark_pb2.NormalizedLandmarkList()
+    hand_landmarks.MergeFrom(proto)
+    hand_landmarks_results.append([
+        landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
+        for hand_landmark in hand_landmarks.landmark
+    ])
+
+  hand_world_landmarks_results = []
+  for proto in hand_world_landmarks_proto_list:
+    hand_world_landmarks = landmark_pb2.LandmarkList()
+    hand_world_landmarks.MergeFrom(proto)
+    hand_world_landmarks_results.append([
+        landmark_module.Landmark.create_from_pb2(hand_world_landmark)
+        for hand_world_landmark in hand_world_landmarks.landmark
+    ])
+
+  return GestureRecognitionResult(gesture_results, handedness_results,
+                                  hand_landmarks_results,
+                                  hand_world_landmarks_results)
+
+
+@dataclasses.dataclass
+class GestureRecognizerOptions:
+  """Options for the gesture recognizer task.
+
+  Attributes:
+    base_options: Base options for the hand gesture recognizer task.
+    running_mode: The running mode of the task. Default to the image mode.
+      Gesture recognizer task has three running modes: 1) The image mode for
+      recognizing hand gestures on single image inputs. 2) The video mode for
+      recognizing hand gestures on the decoded frames of a video. 3) The live
+      stream mode for recognizing hand gestures on a live stream of input
+      data, such as from a camera.
+    num_hands: The maximum number of hands that can be detected by the
+      recognizer.
+    min_hand_detection_confidence: The minimum confidence score for the hand
+      detection to be considered successful.
+    min_hand_presence_confidence: The minimum confidence score of hand
+      presence score in the hand landmark detection.
+    min_tracking_confidence: The minimum confidence score for the hand
+      tracking to be considered successful.
+    canned_gesture_classifier_options: Options for configuring the canned
+      gestures classifier, such as score threshold, allow list and deny list
+      of gestures. The categories for canned gesture classifiers are: ["None",
+      "Closed_Fist", "Open_Palm", "Pointing_Up", "Thumb_Down", "Thumb_Up",
+      "Victory", "ILoveYou"]. Note this option is subject to change.
+    custom_gesture_classifier_options: Options for configuring the custom
+      gestures classifier, such as score threshold, allow list and deny list
+      of gestures. Note this option is subject to change.
+    result_callback: The user-defined result callback for processing live
+      stream data. The result callback should only be specified when the
+      running mode is set to the live stream mode.
+  """
+  base_options: _BaseOptions
+  running_mode: _RunningMode = _RunningMode.IMAGE
+  num_hands: Optional[int] = 1
+  min_hand_detection_confidence: Optional[float] = 0.5
+  min_hand_presence_confidence: Optional[float] = 0.5
+  min_tracking_confidence: Optional[float] = 0.5
+  canned_gesture_classifier_options: Optional[
+      _ClassifierOptions] = _ClassifierOptions()
+  custom_gesture_classifier_options: Optional[
+      _ClassifierOptions] = _ClassifierOptions()
+  result_callback: Optional[Callable[
+      [GestureRecognitionResult, image_module.Image, int], None]] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
+    """Generates a GestureRecognizerOptions protobuf object."""
+    base_options_proto = self.base_options.to_pb2()
+    base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
+
+    # Initialize gesture recognizer options from base options.
+    gesture_recognizer_options_proto = _GestureRecognizerGraphOptionsProto(
+        base_options=base_options_proto)
+    # Configure hand detector and hand landmarker options.
+    hand_landmarker_options_proto = gesture_recognizer_options_proto.hand_landmarker_graph_options
+    hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
+    hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
+    hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
+    hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
+
+    # Configure hand gesture recognizer options.
+    hand_gesture_recognizer_options_proto = gesture_recognizer_options_proto.hand_gesture_recognizer_graph_options
+    hand_gesture_recognizer_options_proto.canned_gesture_classifier_graph_options.classifier_options.CopyFrom(
+        self.canned_gesture_classifier_options.to_pb2())
+    hand_gesture_recognizer_options_proto.custom_gesture_classifier_graph_options.classifier_options.CopyFrom(
+        self.custom_gesture_classifier_options.to_pb2())
+
+    return gesture_recognizer_options_proto
+
+
+class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
+  """Class that performs gesture recognition on images."""
+
+  @classmethod
+  def create_from_model_path(cls, model_path: str) -> 'GestureRecognizer':
+    """Creates a `GestureRecognizer` object from a TensorFlow Lite model and the default `GestureRecognizerOptions`.
+
+    Note that the created `GestureRecognizer` instance is in image mode, for
+    recognizing hand gestures on single image inputs.
+
+    Args:
+      model_path: Path to the model.
+
+    Returns:
+      `GestureRecognizer` object that's created from the model file and the
+      default `GestureRecognizerOptions`.
+
+    Raises:
+      ValueError: If failed to create `GestureRecognizer` object from the
+        provided file such as invalid file path.
+      RuntimeError: If other types of error occurred.
+    """
+    base_options = _BaseOptions(model_asset_path=model_path)
+    options = GestureRecognizerOptions(
+        base_options=base_options, running_mode=_RunningMode.IMAGE)
+    return cls.create_from_options(options)
+
+  @classmethod
+  def create_from_options(
+      cls, options: GestureRecognizerOptions) -> 'GestureRecognizer':
+    """Creates the `GestureRecognizer` object from gesture recognizer options.
+
+    Args:
+      options: Options for the gesture recognizer task.
+
+    Returns:
+      `GestureRecognizer` object that's created from `options`.
+
+    Raises:
+      ValueError: If failed to create `GestureRecognizer` object from
+        `GestureRecognizerOptions` such as missing the model.
+      RuntimeError: If other types of error occurred.
+    """
+
+    def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
+      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
+        return
+
+      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
+
+      if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
+        empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
+        options.result_callback(
+            GestureRecognitionResult([], [], [], []), image,
+            empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
+        return
+
+      gesture_recognition_result = _build_recognition_result(output_packets)
+      timestamp = output_packets[_HAND_GESTURE_STREAM_NAME].timestamp
+      options.result_callback(gesture_recognition_result, image,
+                              timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
+
+    task_info = _TaskInfo(
+        task_graph=_TASK_GRAPH_NAME,
+        input_streams=[
+            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
+            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
+        ],
+        output_streams=[
+            ':'.join([_HAND_GESTURE_TAG, _HAND_GESTURE_STREAM_NAME]),
+            ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
+            ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
+            ':'.join([_HAND_WORLD_LANDMARKS_TAG,
+                      _HAND_WORLD_LANDMARKS_STREAM_NAME]),
+            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
+        ],
+        task_options=options)
+    return cls(
+        task_info.generate_graph_config(
+            enable_flow_limiting=options.running_mode ==
+            _RunningMode.LIVE_STREAM), options.running_mode,
+        packets_callback if options.result_callback else None)
+
+  def recognize(
+      self,
+      image: image_module.Image,
+      image_processing_options: Optional[_ImageProcessingOptions] = None
+  ) -> GestureRecognitionResult:
+    """Performs hand gesture recognition on the given image.
+
+    Only use this method when the GestureRecognizer is created with the image
+    running mode.
+
+    The image can be of any size with format RGB or RGBA.
+    TODO: Describe how the input image will be preprocessed after the yuv
+    support is implemented.
+
+    Args:
+      image: MediaPipe Image.
+      image_processing_options: Options for image processing.
+
+    Returns:
+      The hand gesture recognition results.
+
+    Raises:
+      ValueError: If any of the input arguments is invalid.
+      RuntimeError: If gesture recognition failed to run.
+    """
+    normalized_rect = self.convert_to_normalized_rect(
+        image_processing_options, roi_allowed=False)
+    output_packets = self._process_image_data({
+        _IMAGE_IN_STREAM_NAME:
+            packet_creator.create_image(image),
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2())
+    })
+
+    if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
+      return GestureRecognitionResult([], [], [], [])
+
+    return _build_recognition_result(output_packets)
+
+  def recognize_for_video(
+      self,
+      image: image_module.Image,
+      timestamp_ms: int,
+      image_processing_options: Optional[_ImageProcessingOptions] = None
+  ) -> GestureRecognitionResult:
+    """Performs gesture recognition on the provided video frame.
+
+    Only use this method when the GestureRecognizer is created with the video
+    running mode.
+
+    Only use this method when the GestureRecognizer is created with the video
|
||||||
|
running mode. It's required to provide the video frame's timestamp (in
|
||||||
|
milliseconds) along with the video frame. The input timestamps should be
|
||||||
|
monotonically increasing for adjacent calls of this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: MediaPipe Image.
|
||||||
|
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||||
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The hand gesture recognition results.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If gesture recognition failed to run.
|
||||||
|
"""
|
||||||
|
normalized_rect = self.convert_to_normalized_rect(
|
||||||
|
image_processing_options, roi_allowed=False)
|
||||||
|
output_packets = self._process_video_data({
|
||||||
|
_IMAGE_IN_STREAM_NAME:
|
||||||
|
packet_creator.create_image(image).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
|
_NORM_RECT_STREAM_NAME:
|
||||||
|
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
})
|
||||||
|
|
||||||
|
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
|
||||||
|
return GestureRecognitionResult([], [], [], [])
|
||||||
|
|
||||||
|
return _build_recognition_result(output_packets)
|
||||||
|
|
||||||
|
def recognize_async(
|
||||||
|
self,
|
||||||
|
image: image_module.Image,
|
||||||
|
timestamp_ms: int,
|
||||||
|
image_processing_options: Optional[_ImageProcessingOptions] = None
|
||||||
|
) -> None:
|
||||||
|
"""Sends live image data to perform gesture recognition.
|
||||||
|
|
||||||
|
The results will be available via the "result_callback" provided in the
|
||||||
|
GestureRecognizerOptions. Only use this method when the GestureRecognizer
|
||||||
|
is created with the live stream running mode.
|
||||||
|
|
||||||
|
Only use this method when the GestureRecognizer is created with the live
|
||||||
|
stream running mode. The input timestamps should be monotonically increasing
|
||||||
|
for adjacent calls of this method. This method will return immediately after
|
||||||
|
the input image is accepted. The results will be available via the
|
||||||
|
`result_callback` provided in the `GestureRecognizerOptions`. The
|
||||||
|
`recognize_async` method is designed to process live stream data such as
|
||||||
|
camera input. To lower the overall latency, gesture recognizer may drop the
|
||||||
|
input images if needed. In other words, it's not guaranteed to have output
|
||||||
|
per input image.
|
||||||
|
|
||||||
|
The `result_callback` provides:
|
||||||
|
- The hand gesture recognition results.
|
||||||
|
- The input image that the gesture recognizer runs on.
|
||||||
|
- The input timestamp in milliseconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: MediaPipe Image.
|
||||||
|
timestamp_ms: The timestamp of the input image in milliseconds.
|
||||||
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the current input timestamp is smaller than what the
|
||||||
|
gesture recognizer has already processed.
|
||||||
|
"""
|
||||||
|
normalized_rect = self.convert_to_normalized_rect(
|
||||||
|
image_processing_options, roi_allowed=False)
|
||||||
|
self._send_live_stream_data({
|
||||||
|
_IMAGE_IN_STREAM_NAME:
|
||||||
|
packet_creator.create_image(image).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
|
_NORM_RECT_STREAM_NAME:
|
||||||
|
packet_creator.create_proto(normalized_rect.to_pb2()).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
})
|
|
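For orientation, a minimal usage sketch of the API above (not part of the diff). The model and image file names are placeholders, and re-exporting GestureRecognizer from mediapipe.tasks.python.vision is an assumption:

# Sketch only: 'gesture_recognizer.task' and 'hand.jpg' are hypothetical local
# files; the vision re-export is an assumption, not part of the diff.
import mediapipe as mp
from mediapipe.tasks.python import vision

# Image mode: create_from_model_path defaults to _RunningMode.IMAGE.
recognizer = vision.GestureRecognizer.create_from_model_path(
    'gesture_recognizer.task')
image = mp.Image.create_from_file('hand.jpg')
result = recognizer.recognize(image)
for hand_gestures in result.gestures:
  print([(g.category_name, g.score) for g in hand_gestures])

For video and live-stream modes, the timestamp passed as `timestamp_ms` is converted internally with `_MICRO_SECONDS_PER_MILLISECOND`, so callers always supply plain milliseconds.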
mediapipe/tasks/python/vision/image_classifier.py

@@ -30,6 +30,7 @@ from mediapipe.tasks.python.core import base_options as base_options_module
 from mediapipe.tasks.python.core import task_info as task_info_module
 from mediapipe.tasks.python.core.optional_dependencies import doc_controls
 from mediapipe.tasks.python.vision.core import base_vision_task_api
+from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
 from mediapipe.tasks.python.vision.core import vision_task_running_mode
 
 _NormalizedRect = rect.NormalizedRect
@@ -37,6 +38,7 @@ _BaseOptions = base_options_module.BaseOptions
 _ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
 _ClassifierOptions = classifier_options.ClassifierOptions
 _RunningMode = vision_task_running_mode.VisionTaskRunningMode
+_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 _TaskInfo = task_info_module.TaskInfo
 
 _CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
@@ -44,17 +46,12 @@ _CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
 _IMAGE_IN_STREAM_NAME = 'image_in'
 _IMAGE_OUT_STREAM_NAME = 'image_out'
 _IMAGE_TAG = 'IMAGE'
-_NORM_RECT_NAME = 'norm_rect_in'
+_NORM_RECT_STREAM_NAME = 'norm_rect_in'
 _NORM_RECT_TAG = 'NORM_RECT'
 _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
 _MICRO_SECONDS_PER_MILLISECOND = 1000
 
 
-def _build_full_image_norm_rect() -> _NormalizedRect:
-  # Builds a NormalizedRect covering the entire image.
-  return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
-
-
 @dataclasses.dataclass
 class ImageClassifierOptions:
   """Options for the image classifier task.
@@ -156,7 +153,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         task_graph=_TASK_GRAPH_NAME,
         input_streams=[
             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
-            ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
+            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
         ],
         output_streams=[
             ':'.join([
@@ -171,17 +168,16 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
             _RunningMode.LIVE_STREAM), options.running_mode,
         packets_callback if options.result_callback else None)
 
-  # TODO: Replace _NormalizedRect with ImageProcessingOption
   def classify(
       self,
       image: image_module.Image,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
   ) -> classifications.ClassificationResult:
     """Performs image classification on the provided MediaPipe Image.
 
     Args:
       image: MediaPipe Image.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.
 
     Returns:
       A classification result object that contains a list of classifications.
@@ -190,10 +186,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If image classification failed to run.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     output_packets = self._process_image_data({
-        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
-        _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
+        _IMAGE_IN_STREAM_NAME:
+            packet_creator.create_image(image),
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2())
     })
 
     classification_result_proto = classifications_pb2.ClassificationResult()
@@ -210,7 +208,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       self,
       image: image_module.Image,
       timestamp_ms: int,
-      roi: Optional[_NormalizedRect] = None
+      image_processing_options: Optional[_ImageProcessingOptions] = None
   ) -> classifications.ClassificationResult:
     """Performs image classification on the provided video frames.
 
@@ -222,7 +220,7 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input video frame in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.
 
     Returns:
       A classification result object that contains a list of classifications.
@@ -231,13 +229,13 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
       ValueError: If any of the input arguments is invalid.
       RuntimeError: If image classification failed to run.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     output_packets = self._process_video_data({
         _IMAGE_IN_STREAM_NAME:
             packet_creator.create_image(image).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
-        _NORM_RECT_NAME:
-            packet_creator.create_proto(norm_rect.to_pb2()).at(
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
     })
 
@@ -251,10 +249,12 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
         for classification in classification_result_proto.classifications
     ])
 
-  def classify_async(self,
-                     image: image_module.Image,
-                     timestamp_ms: int,
-                     roi: Optional[_NormalizedRect] = None) -> None:
+  def classify_async(
+      self,
+      image: image_module.Image,
+      timestamp_ms: int,
+      image_processing_options: Optional[_ImageProcessingOptions] = None
+  ) -> None:
     """Sends live image data (an Image with a unique timestamp) to perform image classification.
 
     Only use this method when the ImageClassifier is created with the live
@@ -275,18 +275,18 @@ class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
     Args:
       image: MediaPipe Image.
       timestamp_ms: The timestamp of the input image in milliseconds.
-      roi: The region of interest.
+      image_processing_options: Options for image processing.
 
     Raises:
       ValueError: If the current input timestamp is smaller than what the image
         classifier has already processed.
     """
-    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
+    normalized_rect = self.convert_to_normalized_rect(image_processing_options)
     self._send_live_stream_data({
         _IMAGE_IN_STREAM_NAME:
             packet_creator.create_image(image).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
-        _NORM_RECT_NAME:
-            packet_creator.create_proto(norm_rect.to_pb2()).at(
+        _NORM_RECT_STREAM_NAME:
+            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                 timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
     })
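A short sketch of calling the reworked `classify` (not part of the diff). The `Rect` field names (left/top/right/bottom, normalized coordinates) are an assumption based on the containers module, and `classifier`/`image` are assumed to exist as in the gesture example above:

# Sketch only: ImageProcessingOptions replaces the old raw `roi` argument.
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.vision.core import (
    image_processing_options as image_processing_options_module)

# Classify only the center crop of the image; previously this required
# hand-building a NormalizedRect and passing it as `roi`.
options = image_processing_options_module.ImageProcessingOptions(
    region_of_interest=rect.Rect(left=0.25, top=0.25, right=0.75, bottom=0.75))
result = classifier.classify(image, options)

Omitting `image_processing_options` now yields the old default behavior: `convert_to_normalized_rect(None)` produces a rect covering the full image, which is why the module-level `_build_full_image_norm_rect` helper could be deleted.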
|
@ -44,18 +44,6 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.ImageSegmenterGraph'
|
||||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||||
|
|
||||||
|
|
||||||
class OutputType(enum.Enum):
|
|
||||||
UNSPECIFIED = 0
|
|
||||||
CATEGORY_MASK = 1
|
|
||||||
CONFIDENCE_MASK = 2
|
|
||||||
|
|
||||||
|
|
||||||
class Activation(enum.Enum):
|
|
||||||
NONE = 0
|
|
||||||
SIGMOID = 1
|
|
||||||
SOFTMAX = 2
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ImageSegmenterOptions:
|
class ImageSegmenterOptions:
|
||||||
"""Options for the image segmenter task.
|
"""Options for the image segmenter task.
|
||||||
|
@ -74,6 +62,17 @@ class ImageSegmenterOptions:
|
||||||
data. The result callback should only be specified when the running mode
|
data. The result callback should only be specified when the running mode
|
||||||
is set to the live stream mode.
|
is set to the live stream mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class OutputType(enum.Enum):
|
||||||
|
UNSPECIFIED = 0
|
||||||
|
CATEGORY_MASK = 1
|
||||||
|
CONFIDENCE_MASK = 2
|
||||||
|
|
||||||
|
class Activation(enum.Enum):
|
||||||
|
NONE = 0
|
||||||
|
SIGMOID = 1
|
||||||
|
SOFTMAX = 2
|
||||||
|
|
||||||
base_options: _BaseOptions
|
base_options: _BaseOptions
|
||||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
|
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
|
||||||
|
|
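Because the enums are now nested, callers reference them through the options class. A hedged sketch (the model path is a placeholder, and the `create_from_options` factory is assumed to follow the same pattern as the other vision tasks):

# Sketch only: nested enums are addressed via ImageSegmenterOptions.
options = ImageSegmenterOptions(
    base_options=_BaseOptions(model_asset_path='deeplab_v3.tflite'),
    output_type=ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
segmenter = ImageSegmenter.create_from_options(options)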
mediapipe/tasks/web/components/containers/BUILD (new file, 21 lines)

# This package contains data containers shared by all MediaPipe Tasks for Web.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

mediapipe_ts_library(
    name = "category",
    srcs = ["category.d.ts"],
)

mediapipe_ts_library(
    name = "classifications",
    srcs = ["classifications.d.ts"],
    deps = [":category"],
)

mediapipe_ts_library(
    name = "landmark",
    srcs = ["landmark.d.ts"],
)
mediapipe/tasks/web/components/containers/category.d.ts (new file, 38 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/** A classification category. */
export interface Category {
  /** The probability score of this label category. */
  score: number;

  /** The index of the category in the corresponding label file. */
  index: number;

  /**
   * The label of this category object. Defaults to an empty string if there is
   * no category.
   */
  categoryName: string;

  /**
   * The display name of the label, which may be translated for different
   * locales. For example, a label, "apple", may be translated into Spanish for
   * display purposes, so that the `displayName` is "manzana". Defaults to an
   * empty string if there is no display name.
   */
  displayName: string;
}
mediapipe/tasks/web/components/containers/classifications.d.ts (new file, 51 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {Category} from '../../../../tasks/web/components/containers/category';

/** List of predicted categories with an optional timestamp. */
export interface ClassificationEntry {
  /**
   * The array of predicted categories, usually sorted by descending scores,
   * e.g., from high to low probability.
   */
  categories: Category[];

  /**
   * The optional timestamp (in milliseconds) associated with the
   * classification entry. This is useful for time series use cases, e.g.,
   * audio classification.
   */
  timestampMs?: number;
}

/** Classifications for a given classifier head. */
export interface Classifications {
  /** A list of classification entries. */
  entries: ClassificationEntry[];

  /**
   * The index of the classifier head these categories refer to. This is
   * useful for multi-head models.
   */
  headIndex: number;

  /**
   * The name of the classifier head, which is the corresponding tensor
   * metadata name.
   */
  headName: string;
}
mediapipe/tasks/web/components/containers/landmark.d.ts (new file, 35 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/**
 * Landmark represents a point in 3D space with x, y, z coordinates. If
 * normalized is true, the landmark coordinates are normalized with respect to
 * the image dimensions, and the coordinate values are in the range [0, 1].
 * Otherwise, it represents a point in world coordinates.
 */
export class Landmark {
  /** The x coordinate of the landmark. */
  x: number;

  /** The y coordinate of the landmark. */
  y: number;

  /** The z coordinate of the landmark. */
  z: number;

  /** Whether this landmark is normalized with respect to the image size. */
  normalized: boolean;
}
mediapipe/tasks/web/components/processors/BUILD (new file, 33 lines)

# This package contains processors shared by all MediaPipe Tasks for Web.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

mediapipe_ts_library(
    name = "classifier_options",
    srcs = ["classifier_options.ts"],
    deps = [
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_jspb_proto",
        "//mediapipe/tasks/web/core:classifier_options",
    ],
)

mediapipe_ts_library(
    name = "classifier_result",
    srcs = ["classifier_result.ts"],
    deps = [
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/components/containers:classifications",
    ],
)

mediapipe_ts_library(
    name = "base_options",
    srcs = ["base_options.ts"],
    deps = [
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
        "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
        "//mediapipe/tasks/web/core",
    ],
)
mediapipe/tasks/web/components/processors/base_options.ts (new file, 50 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/**
 * Converts a BaseOptions API object to its Protobuf representation.
 * @throws If neither a model asset path nor a buffer is provided
 */
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
    Promise<BaseOptionsProto> {
  if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
    throw new Error(
        'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
  }
  if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
    throw new Error(
        'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
  }

  let modelAssetBuffer = baseOptions.modelAssetBuffer;
  if (!modelAssetBuffer) {
    const response = await fetch(baseOptions.modelAssetPath!.toString());
    modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
  }

  const proto = new BaseOptionsProto();
  const externalFile = new ExternalFile();
  externalFile.setFileContent(modelAssetBuffer);
  proto.setModelAsset(externalFile);
  return proto;
}
mediapipe/tasks/web/components/processors/classifier_options.ts (new file, 62 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {ClassifierOptions as ClassifierOptionsProto} from '../../../../tasks/cc/components/processors/proto/classifier_options_pb';
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';

/**
 * Converts a ClassifierOptions object to its Proto representation, optionally
 * based on an existing definition.
 * @param options The options object to convert to a Proto. Only options that
 *     are explicitly provided are set.
 * @param baseOptions A base object that options can be merged into.
 */
export function convertClassifierOptionsToProto(
    options: ClassifierOptions,
    baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto {
  const classifierOptions =
      baseOptions ? baseOptions.clone() : new ClassifierOptionsProto();
  if (options.displayNamesLocale) {
    classifierOptions.setDisplayNamesLocale(options.displayNamesLocale);
  } else if (options.displayNamesLocale === undefined) {
    classifierOptions.clearDisplayNamesLocale();
  }

  if (options.maxResults) {
    classifierOptions.setMaxResults(options.maxResults);
  } else if ('maxResults' in options) {  // Check for undefined
    classifierOptions.clearMaxResults();
  }

  if (options.scoreThreshold) {
    classifierOptions.setScoreThreshold(options.scoreThreshold);
  } else if ('scoreThreshold' in options) {  // Check for undefined
    classifierOptions.clearScoreThreshold();
  }

  if (options.categoryAllowlist) {
    classifierOptions.setCategoryAllowlistList(options.categoryAllowlist);
  } else if ('categoryAllowlist' in options) {  // Check for undefined
    classifierOptions.clearCategoryAllowlistList();
  }

  if (options.categoryDenylist) {
    classifierOptions.setCategoryDenylistList(options.categoryDenylist);
  } else if ('categoryDenylist' in options) {  // Check for undefined
    classifierOptions.clearCategoryDenylistList();
  }
  return classifierOptions;
}
mediapipe/tasks/web/components/processors/classifier_result.ts (new file, 61 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';

const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0;

/**
 * Converts a ClassificationEntry proto to the ClassificationEntry result
 * type.
 */
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
    ClassificationEntry {
  const categories = source.getCategoriesList().map(category => {
    return {
      index: category.getIndex() ?? DEFAULT_INDEX,
      score: category.getScore() ?? DEFAULT_SCORE,
      displayName: category.getDisplayName() ?? '',
      categoryName: category.getCategoryName() ?? '',
    };
  });

  return {
    categories,
    timestampMs: source.getTimestampMs(),
  };
}

/**
 * Converts a ClassificationResult proto to a list of classifications.
 */
export function convertFromClassificationResultProto(
    classificationResult: ClassificationResult): Classifications[] {
  const result: Classifications[] = [];
  for (const classificationsProto of
           classificationResult.getClassificationsList()) {
    const classifications: Classifications = {
      entries: classificationsProto.getEntriesList().map(
          entry => convertFromClassificationEntryProto(entry)),
      headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
      headName: classificationsProto.getHeadName() ?? '',
    };
    result.push(classifications);
  }
  return result;
}
mediapipe/tasks/web/core/BUILD (new file, 21 lines)

# This package contains options shared by all MediaPipe Tasks for Web.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

mediapipe_ts_library(
    name = "core",
    srcs = [
        "base_options.d.ts",
        "wasm_loader_options.d.ts",
    ],
)

mediapipe_ts_library(
    name = "classifier_options",
    srcs = [
        "classifier_options.d.ts",
    ],
    deps = [":core"],
)
mediapipe/tasks/web/core/base_options.d.ts (new file, 31 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Placeholder for internal dependency on trusted resource url

/** Options to configure MediaPipe Tasks in general. */
export interface BaseOptions {
  /**
   * The path to the model asset file. Only one of `modelAssetPath` or
   * `modelAssetBuffer` can be set.
   */
  modelAssetPath?: string;
  /**
   * A buffer containing the model asset. Only one of `modelAssetPath` or
   * `modelAssetBuffer` can be set.
   */
  modelAssetBuffer?: Uint8Array;
}
mediapipe/tasks/web/core/classifier_options.d.ts (new file, 52 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {BaseOptions} from '../../../tasks/web/core/base_options';

/** Options to configure the Mediapipe Classifier Task. */
export interface ClassifierOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: BaseOptions;

  /**
   * The locale to use for display names specified through the TFLite Model
   * Metadata, if any. Defaults to English.
   */
  displayNamesLocale?: string|undefined;

  /** The maximum number of top-scored detection results to return. */
  maxResults?: number|undefined;

  /**
   * Overrides the value provided in the model metadata. Results below this
   * value are rejected.
   */
  scoreThreshold?: number|undefined;

  /**
   * Allowlist of category names. If non-empty, detection results whose category
   * name is not in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryDenylist`.
   */
  categoryAllowlist?: string[]|undefined;

  /**
   * Denylist of category names. If non-empty, detection results whose category
   * name is in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryAllowlist`.
   */
  categoryDenylist?: string[]|undefined;
}
mediapipe/tasks/web/core/wasm_loader_options.d.ts (new file, 25 lines)

/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Placeholder for internal dependency on trusted resource url

/** An object containing the locations of all Wasm assets */
export interface WasmLoaderOptions {
  /** The path to the Wasm loader script. */
  wasmLoaderPath: string;
  /** The path to the Wasm binary. */
  wasmBinaryPath: string;
}
package.json (new file, 15 lines)

{
  "name": "mediapipe-dev",
  "version": "0.0.0-alpha",
  "description": "MediaPipe GitHub repo",
  "devDependencies": {
    "@bazel/typescript": "^5.7.1",
    "@types/google-protobuf": "^3.15.6",
    "@types/offscreencanvas": "^2019.7.0",
    "google-protobuf": "^3.21.2",
    "protobufjs": "^7.1.2",
    "protobufjs-cli": "^1.0.2",
    "ts-protoc-gen": "^0.15.0",
    "typescript": "^4.8.4"
  }
}
requirements.txt

@@ -1,5 +1,6 @@
 absl-py
 attrs>=19.1.0
+flatbuffers>=2.0
 matplotlib
 numpy
 opencv-contrib-python
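The new flatbuffers pin exists because the metadata schema modules generated by the GenerateMetadataSchema step in setup.py below import the FlatBuffers runtime. A small sketch of the dependency (the Builder call is the standard flatbuffers Python API, not code from this diff):

# Sketch only: generated mediapipe/tasks/metadata/*_generated.py modules rely
# on the flatbuffers runtime to build and parse metadata buffers.
import flatbuffers

builder = flatbuffers.Builder(1024)  # pre-sized buffer builder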
setup.py (51 lines changed)

@@ -129,6 +129,15 @@ def _add_mp_init_files():
   mp_dir_init_file.close()
 
 
+def _copy_to_build_lib_dir(build_lib, file):
+  """Copy a file from bazel-bin to the build lib dir."""
+  dst = os.path.join(build_lib + '/', file)
+  dst_dir = os.path.dirname(dst)
+  if not os.path.exists(dst_dir):
+    os.makedirs(dst_dir)
+  shutil.copyfile(os.path.join('bazel-bin/', file), dst)
+
+
 class GeneratePyProtos(build_ext.build_ext):
   """Generate MediaPipe Python protobuf files by Protocol Compiler."""
 
@@ -259,7 +268,7 @@ class BuildModules(build_ext.build_ext):
       ]
       if subprocess.call(fetch_model_command) != 0:
         sys.exit(-1)
-      self._copy_to_build_lib_dir(external_file)
+      _copy_to_build_lib_dir(self.build_lib, external_file)
 
   def _generate_binary_graph(self, binary_graph_target):
     """Generate binary graph for a particular MediaPipe binary graph target."""
@@ -277,15 +286,27 @@ class BuildModules(build_ext.build_ext):
       bazel_command.append('--define=OPENCV=source')
     if subprocess.call(bazel_command) != 0:
       sys.exit(-1)
-    self._copy_to_build_lib_dir(binary_graph_target + '.binarypb')
+    _copy_to_build_lib_dir(self.build_lib, binary_graph_target + '.binarypb')
 
-  def _copy_to_build_lib_dir(self, file):
-    """Copy a file from bazel-bin to the build lib dir."""
-    dst = os.path.join(self.build_lib + '/', file)
-    dst_dir = os.path.dirname(dst)
-    if not os.path.exists(dst_dir):
-      os.makedirs(dst_dir)
-    shutil.copyfile(os.path.join('bazel-bin/', file), dst)
+
+class GenerateMetadataSchema(build_ext.build_ext):
+  """Generate metadata python schema files."""
+
+  def run(self):
+    for target in ['metadata_schema_py', 'schema_py']:
+      bazel_command = [
+          'bazel',
+          'build',
+          '--compilation_mode=opt',
+          '--define=MEDIAPIPE_DISABLE_GPU=1',
+          '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
+          '//mediapipe/tasks/metadata:' + target,
+      ]
+      if subprocess.call(bazel_command) != 0:
+        sys.exit(-1)
+      _copy_to_build_lib_dir(
+          self.build_lib,
+          'mediapipe/tasks/metadata/' + target + '_generated.py')
 
 
 class BazelExtension(setuptools.Extension):
@@ -375,6 +396,7 @@ class BuildPy(build_py.build_py):
     build_ext_obj = self.distribution.get_command_obj('build_ext')
     build_ext_obj.link_opencv = self.link_opencv
     self.run_command('gen_protos')
+    self.run_command('generate_metadata_schema')
     self.run_command('build_modules')
     self.run_command('build_ext')
     build_py.build_py.run(self)
@@ -434,18 +456,25 @@ setuptools.setup(
     author_email='mediapipe@google.com',
     long_description=_get_long_description(),
     long_description_content_type='text/markdown',
-    packages=setuptools.find_packages(exclude=['mediapipe.examples.desktop.*']),
+    packages=setuptools.find_packages(
+        exclude=['mediapipe.examples.desktop.*', 'mediapipe.model_maker.*']),
     install_requires=_parse_requirements('requirements.txt'),
     cmdclass={
         'build_py': BuildPy,
-        'gen_protos': GeneratePyProtos,
         'build_modules': BuildModules,
         'build_ext': BuildExtension,
+        'generate_metadata_schema': GenerateMetadataSchema,
+        'gen_protos': GeneratePyProtos,
         'install': Install,
         'restore': Restore,
     },
     ext_modules=[
         BazelExtension('//mediapipe/python:_framework_bindings'),
+        BazelExtension(
+            '//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version'),
+        BazelExtension(
+            '//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers'
+        ),
     ],
     zip_safe=False,
     include_package_data=True,
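The new command runs automatically from BuildPy.run(), but because it is registered in cmdclass it can also be invoked on its own. A sketch (assumes a MediaPipe checkout with Bazel on PATH; not part of the diff):

# Sketch only: equivalent to the `generate_metadata_schema` step that BuildPy
# now runs between `gen_protos` and `build_modules`.
import subprocess
import sys

subprocess.check_call(
    [sys.executable, 'setup.py', 'generate_metadata_schema'])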
tsconfig.json (new file, 47 lines)

{
  "compilerOptions": {
    "target": "es2017",
    "module": "commonjs",
    "lib": ["ES2017", "dom"],
    "declaration": true,
    "moduleResolution": "node",
    "esModuleInterop": true,
    "noImplicitAny": true,
    "inlineSourceMap": true,
    "inlineSources": true,
    "strict": true,
    "types": ["@types/offscreencanvas"],
    "rootDirs": [
      ".",
      "./bazel-out/host/bin",
      "./bazel-out/darwin-dbg/bin",
      "./bazel-out/darwin-fastbuild/bin",
      "./bazel-out/darwin-opt/bin",
      "./bazel-out/darwin_arm64-dbg/bin",
      "./bazel-out/darwin_arm64-fastbuild/bin",
      "./bazel-out/darwin_arm64-opt/bin",
      "./bazel-out/k8-dbg/bin",
      "./bazel-out/k8-fastbuild/bin",
      "./bazel-out/k8-opt/bin",
      "./bazel-out/x64_windows-dbg/bin",
      "./bazel-out/x64_windows-fastbuild/bin",
      "./bazel-out/x64_windows-opt/bin",
      "./bazel-out/k8-fastbuild-ST-4a519fd6d3e4/bin"
    ]
  },
  "exclude": [
    "./_bazel_bin",
    "./_bazel_buildbot",
    "./_bazel_out",
    "./_bazel_testlogs"
  ]
}
|
626
yarn.lock
Normal file
626
yarn.lock
Normal file
|
@ -0,0 +1,626 @@
|
||||||
|
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
|
||||||
|
# yarn lockfile v1
|
||||||
|
|
||||||
|
|
||||||
|
"@babel/parser@^7.9.4":
|
||||||
|
version "7.20.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.20.1.tgz#3e045a92f7b4623cafc2425eddcb8cf2e54f9cc5"
|
||||||
|
integrity sha512-hp0AYxaZJhxULfM1zyp7Wgr+pSUKBcP3M+PHnSzWGdXOzg/kHWIgiUWARvubhUKGOEw3xqY4x+lyZ9ytBVcELw==
|
||||||
|
|
||||||
|
"@bazel/typescript@^5.7.1":
|
||||||
|
version "5.7.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/@bazel/typescript/-/typescript-5.7.1.tgz#e585bcdc54a4ccb23d99c3e1206abf4853cf0682"
|
||||||
|
integrity sha512-MAnAtFxA2znadm81+rbYXcyWX1DEF/urzZ1F4LBq+w27EQ4PGyqIqCM5om7JcoSZJwjjMoBJc3SflRsMrZZ6+g==
|
||||||
|
dependencies:
|
||||||
|
"@bazel/worker" "5.7.1"
|
||||||
|
semver "5.6.0"
|
||||||
|
source-map-support "0.5.9"
|
||||||
|
tsutils "3.21.0"
|
||||||
|
|
||||||
|
"@bazel/worker@5.7.1":
|
||||||
|
version "5.7.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/@bazel/worker/-/worker-5.7.1.tgz#2c4a9bd0e0ef75e496aec9599ff64a87307e7dad"
|
||||||
|
integrity sha512-UndmQVRqK0t0NMNl8I1P5XmxzdPvMA0X6jufszpfwy5gyzjOxeiOIzmC0ALCOx78CuJqOB/8WOI1pwTRmhd0tg==
|
||||||
|
dependencies:
|
||||||
|
google-protobuf "^3.6.1"
|
||||||
|
|
||||||
|
"@protobufjs/aspromise@^1.1.1", "@protobufjs/aspromise@^1.1.2":
|
||||||
|
version "1.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz#9b8b0cc663d669a7d8f6f5d0893a14d348f30fbf"
|
||||||
|
integrity sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==
|
||||||
|
|
||||||
|
"@protobufjs/base64@^1.1.2":
|
||||||
|
version "1.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/base64/-/base64-1.1.2.tgz#4c85730e59b9a1f1f349047dbf24296034bb2735"
|
||||||
|
integrity sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==
|
||||||
|
|
||||||
|
"@protobufjs/codegen@^2.0.4":
|
||||||
|
version "2.0.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/codegen/-/codegen-2.0.4.tgz#7ef37f0d010fb028ad1ad59722e506d9262815cb"
|
||||||
|
integrity sha512-YyFaikqM5sH0ziFZCN3xDC7zeGaB/d0IUb9CATugHWbd1FRFwWwt4ld4OYMPWu5a3Xe01mGAULCdqhMlPl29Jg==
|
||||||
|
|
||||||
|
"@protobufjs/eventemitter@^1.1.0":
|
||||||
|
version "1.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz#355cbc98bafad5978f9ed095f397621f1d066b70"
|
||||||
|
integrity sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==
|
||||||
|
|
||||||
|
"@protobufjs/fetch@^1.1.0":
|
||||||
|
version "1.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/fetch/-/fetch-1.1.0.tgz#ba99fb598614af65700c1619ff06d454b0d84c45"
|
||||||
|
integrity sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==
|
||||||
|
dependencies:
|
||||||
|
"@protobufjs/aspromise" "^1.1.1"
|
||||||
|
"@protobufjs/inquire" "^1.1.0"
|
||||||
|
|
||||||
|
"@protobufjs/float@^1.0.2":
|
||||||
|
version "1.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/float/-/float-1.0.2.tgz#5e9e1abdcb73fc0a7cb8b291df78c8cbd97b87d1"
|
||||||
|
integrity sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==
|
||||||
|
|
||||||
|
"@protobufjs/inquire@^1.1.0":
|
||||||
|
version "1.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/inquire/-/inquire-1.1.0.tgz#ff200e3e7cf2429e2dcafc1140828e8cc638f089"
|
||||||
|
integrity sha512-kdSefcPdruJiFMVSbn801t4vFK7KB/5gd2fYvrxhuJYg8ILrmn9SKSX2tZdV6V+ksulWqS7aXjBcRXl3wHoD9Q==
|
||||||
|
|
||||||
|
"@protobufjs/path@^1.1.2":
|
||||||
|
version "1.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/path/-/path-1.1.2.tgz#6cc2b20c5c9ad6ad0dccfd21ca7673d8d7fbf68d"
|
||||||
|
integrity sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==
|
||||||
|
|
||||||
|
"@protobufjs/pool@^1.1.0":
|
||||||
|
version "1.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/pool/-/pool-1.1.0.tgz#09fd15f2d6d3abfa9b65bc366506d6ad7846ff54"
|
||||||
|
integrity sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==
|
||||||
|
|
||||||
|
"@protobufjs/utf8@^1.1.0":
|
||||||
|
version "1.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@protobufjs/utf8/-/utf8-1.1.0.tgz#a777360b5b39a1a2e5106f8e858f2fd2d060c570"
|
||||||
|
integrity sha512-Vvn3zZrhQZkkBE8LSuW3em98c0FwgO4nxzv6OdSxPKJIEKY2bGbHn+mhGIPerzI4twdxaP8/0+06HBpwf345Lw==
|
||||||
|
|
||||||
|
"@types/google-protobuf@^3.15.6":
|
||||||
|
version "3.15.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/@types/google-protobuf/-/google-protobuf-3.15.6.tgz#674a69493ef2c849b95eafe69167ea59079eb504"
|
||||||
|
integrity sha512-pYVNNJ+winC4aek+lZp93sIKxnXt5qMkuKmaqS3WGuTq0Bw1ZDYNBgzG5kkdtwcv+GmYJGo3yEg6z2cKKAiEdw==
|
||||||
|
|
||||||
|
"@types/linkify-it@*":
|
||||||
|
version "3.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/@types/linkify-it/-/linkify-it-3.0.2.tgz#fd2cd2edbaa7eaac7e7f3c1748b52a19143846c9"
|
||||||
|
integrity sha512-HZQYqbiFVWufzCwexrvh694SOim8z2d+xJl5UNamcvQFejLY/2YUtzXHYi3cHdI7PMlS8ejH2slRAOJQ32aNbA==
|
||||||
|
|
||||||
|
"@types/markdown-it@^12.2.3":
|
||||||
|
version "12.2.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/@types/markdown-it/-/markdown-it-12.2.3.tgz#0d6f6e5e413f8daaa26522904597be3d6cd93b51"
|
||||||
|
integrity sha512-GKMHFfv3458yYy+v/N8gjufHO6MSZKCOXpZc5GXIWWy8uldwfmPn98vp81gZ5f9SVw8YYBctgfJ22a2d7AOMeQ==
|
||||||
|
dependencies:
|
||||||
|
"@types/linkify-it" "*"
|
||||||
|
"@types/mdurl" "*"
|
||||||
|
|
||||||
|
"@types/mdurl@*":
|
||||||
|
version "1.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/@types/mdurl/-/mdurl-1.0.2.tgz#e2ce9d83a613bacf284c7be7d491945e39e1f8e9"
|
||||||
|
integrity sha512-eC4U9MlIcu2q0KQmXszyn5Akca/0jrQmwDRgpAMJai7qBWq4amIQhZyNau4VYGtCeALvW1/NtjzJJ567aZxfKA==
|
||||||
|
|
||||||
|
"@types/node@>=13.7.0":
|
||||||
|
version "18.11.9"
|
||||||
|
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.9.tgz#02d013de7058cea16d36168ef2fc653464cfbad4"
|
||||||
|
integrity sha512-CRpX21/kGdzjOpFsZSkcrXMGIBWMGNIHXXBVFSH+ggkftxg+XYP20TESbh+zFvFj3EQOl5byk0HTRn1IL6hbqg==
|
||||||
|
|
||||||
|
"@types/offscreencanvas@^2019.7.0":
|
||||||
|
version "2019.7.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d"
|
||||||
|
integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg==
|
||||||
|
|
||||||
|
acorn-jsx@^5.3.2:
|
||||||
|
version "5.3.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.3.2.tgz#7ed5bb55908b3b2f1bc55c6af1653bada7f07937"
|
||||||
|
integrity sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==
|
||||||
|
|
||||||
|
acorn@^8.8.0:
|
||||||
|
version "8.8.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/acorn/-/acorn-8.8.1.tgz#0a3f9cbecc4ec3bea6f0a80b66ae8dd2da250b73"
|
||||||
|
integrity sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==
|
||||||
|
|
||||||
|
ansi-styles@^4.1.0:
|
||||||
|
version "4.3.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/ansi-styles/-/ansi-styles-4.3.0.tgz#edd803628ae71c04c85ae7a0906edad34b648937"
|
||||||
|
integrity sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==
|
||||||
|
dependencies:
|
||||||
|
color-convert "^2.0.1"
|
||||||
|
|
||||||
|
argparse@^2.0.1:
|
||||||
|
version "2.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/argparse/-/argparse-2.0.1.tgz#246f50f3ca78a3240f6c997e8a9bd1eac49e4b38"
|
||||||
|
integrity sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==
|
||||||
|
|
||||||
|
balanced-match@^1.0.0:
|
||||||
|
version "1.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee"
|
||||||
|
integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==
|
||||||
|
|
||||||
|
bluebird@^3.7.2:
|
||||||
|
version "3.7.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/bluebird/-/bluebird-3.7.2.tgz#9f229c15be272454ffa973ace0dbee79a1b0c36f"
|
||||||
|
integrity sha512-XpNj6GDQzdfW+r2Wnn7xiSAd7TM3jzkxGXBGTtWKuSXv1xUV+azxAm8jdWZN06QTQk+2N2XB9jRDkvbmQmcRtg==
|
||||||
|
|
||||||
|
brace-expansion@^1.1.7:
|
||||||
|
version "1.1.11"
|
||||||
|
resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-1.1.11.tgz#3c7fcbf529d87226f3d2f52b966ff5271eb441dd"
|
||||||
|
integrity sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==
|
||||||
|
dependencies:
|
||||||
|
balanced-match "^1.0.0"
|
||||||
|
concat-map "0.0.1"
|
||||||
|
|
||||||
|
brace-expansion@^2.0.1:
|
||||||
|
version "2.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/brace-expansion/-/brace-expansion-2.0.1.tgz#1edc459e0f0c548486ecf9fc99f2221364b9a0ae"
|
||||||
|
integrity sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==
|
||||||
|
dependencies:
|
||||||
|
balanced-match "^1.0.0"
|
||||||
|
|
||||||
|
buffer-from@^1.0.0:
|
||||||
|
version "1.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/buffer-from/-/buffer-from-1.1.2.tgz#2b146a6fd72e80b4f55d255f35ed59a3a9a41bd5"
|
||||||
|
integrity sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==
|
||||||
|
|
||||||
|
catharsis@^0.9.0:
|
||||||
|
version "0.9.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/catharsis/-/catharsis-0.9.0.tgz#40382a168be0e6da308c277d3a2b3eb40c7d2121"
|
||||||
|
integrity sha512-prMTQVpcns/tzFgFVkVp6ak6RykZyWb3gu8ckUpd6YkTlacOd3DXGJjIpD4Q6zJirizvaiAjSSHlOsA+6sNh2A==
|
||||||
|
dependencies:
|
||||||
|
lodash "^4.17.15"
|
||||||
|
|
||||||
|
chalk@^4.0.0:
|
||||||
|
version "4.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/chalk/-/chalk-4.1.2.tgz#aac4e2b7734a740867aeb16bf02aad556a1e7a01"
|
||||||
|
integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==
|
||||||
|
dependencies:
|
||||||
|
ansi-styles "^4.1.0"
|
||||||
|
supports-color "^7.1.0"
|
||||||
|
|
||||||
|
color-convert@^2.0.1:
|
||||||
|
version "2.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/color-convert/-/color-convert-2.0.1.tgz#72d3a68d598c9bdb3af2ad1e84f21d896abd4de3"
|
||||||
|
integrity sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==
|
||||||
|
dependencies:
|
||||||
|
color-name "~1.1.4"
|
||||||
|
|
||||||
|
color-name@~1.1.4:
|
||||||
|
version "1.1.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/color-name/-/color-name-1.1.4.tgz#c2a09a87acbde69543de6f63fa3995c826c536a2"
|
||||||
|
integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==
|
||||||
|
|
||||||
|
concat-map@0.0.1:
|
||||||
|
version "0.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/concat-map/-/concat-map-0.0.1.tgz#d8a96bd77fd68df7793a73036a3ba0d5405d477b"
|
||||||
|
integrity sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==
|
||||||
|
|
||||||
|
deep-is@~0.1.3:
|
||||||
|
version "0.1.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831"
|
||||||
|
integrity sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==
|
||||||
|
|
||||||
|
entities@~2.1.0:
|
||||||
|
version "2.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/entities/-/entities-2.1.0.tgz#992d3129cf7df6870b96c57858c249a120f8b8b5"
|
||||||
|
integrity sha512-hCx1oky9PFrJ611mf0ifBLBRW8lUUVRlFolb5gWRfIELabBlbp9xZvrqZLZAs+NxFnbfQoeGd8wDkygjg7U85w==
|
||||||
|
|
||||||
|
escape-string-regexp@^2.0.0:
|
||||||
|
version "2.0.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/escape-string-regexp/-/escape-string-regexp-2.0.0.tgz#a30304e99daa32e23b2fd20f51babd07cffca344"
|
||||||
|
integrity sha512-UpzcLCXolUWcNu5HtVMHYdXJjArjsF9C0aNnquZYY4uW/Vu0miy5YoWvbV345HauVvcAUnpRuhMMcqTcGOY2+w==
|
||||||
|
|
||||||
|
escodegen@^1.13.0:
|
||||||
|
version "1.14.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/escodegen/-/escodegen-1.14.3.tgz#4e7b81fba61581dc97582ed78cab7f0e8d63f503"
|
||||||
|
integrity sha512-qFcX0XJkdg+PB3xjZZG/wKSuT1PnQWx57+TVSjIMmILd2yC/6ByYElPwJnslDsuWuSAp4AwJGumarAAmJch5Kw==
|
||||||
|
dependencies:
|
||||||
|
esprima "^4.0.1"
|
||||||
|
estraverse "^4.2.0"
|
||||||
|
esutils "^2.0.2"
|
||||||
|
optionator "^0.8.1"
|
||||||
|
optionalDependencies:
|
||||||
|
source-map "~0.6.1"
|
||||||
|
|
||||||
|
eslint-visitor-keys@^3.3.0:
|
||||||
|
version "3.3.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/eslint-visitor-keys/-/eslint-visitor-keys-3.3.0.tgz#f6480fa6b1f30efe2d1968aa8ac745b862469826"
|
||||||
|
integrity sha512-mQ+suqKJVyeuwGYHAdjMFqjCyfl8+Ldnxuyp3ldiMBFKkvytrXUZWaiPCEav8qDHKty44bD+qV1IP4T+w+xXRA==
|
||||||
|
|
||||||
|
espree@^9.0.0:
|
||||||
|
version "9.4.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/espree/-/espree-9.4.0.tgz#cd4bc3d6e9336c433265fc0aa016fc1aaf182f8a"
|
||||||
|
integrity sha512-DQmnRpLj7f6TgN/NYb0MTzJXL+vJF9h3pHy4JhCIs3zwcgez8xmGg3sXHcEO97BrmO2OSvCwMdfdlyl+E9KjOw==
|
||||||
|
dependencies:
|
||||||
|
acorn "^8.8.0"
|
||||||
|
acorn-jsx "^5.3.2"
|
||||||
|
eslint-visitor-keys "^3.3.0"
|
||||||
|
|
||||||
|
esprima@^4.0.1:
|
||||||
|
version "4.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/esprima/-/esprima-4.0.1.tgz#13b04cdb3e6c5d19df91ab6987a8695619b0aa71"
|
||||||
|
integrity sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==
|
||||||
|
|
||||||
|
estraverse@^4.2.0:
|
||||||
|
version "4.3.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-4.3.0.tgz#398ad3f3c5a24948be7725e83d11a7de28cdbd1d"
|
||||||
|
integrity sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==
|
||||||
|
|
||||||
|
estraverse@^5.1.0:
|
||||||
|
version "5.3.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/estraverse/-/estraverse-5.3.0.tgz#2eea5290702f26ab8fe5370370ff86c965d21123"
|
||||||
|
integrity sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==
|
||||||
|
|
||||||
|
esutils@^2.0.2:
|
||||||
|
version "2.0.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/esutils/-/esutils-2.0.3.tgz#74d2eb4de0b8da1293711910d50775b9b710ef64"
|
||||||
|
integrity sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==
|
||||||
|
|
||||||
|
fast-levenshtein@~2.0.6:
|
||||||
|
version "2.0.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917"
|
||||||
|
integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==
|
||||||
|
|
||||||
|
fs.realpath@^1.0.0:
|
||||||
|
version "1.0.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/fs.realpath/-/fs.realpath-1.0.0.tgz#1504ad2523158caa40db4a2787cb01411994ea4f"
|
||||||
|
integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==
|
||||||
|
|
||||||
|
glob@^7.1.3:
|
||||||
|
version "7.2.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/glob/-/glob-7.2.3.tgz#b8df0fb802bbfa8e89bd1d938b4e16578ed44f2b"
|
||||||
|
integrity sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==
|
||||||
|
dependencies:
|
||||||
|
fs.realpath "^1.0.0"
|
||||||
|
inflight "^1.0.4"
|
||||||
|
inherits "2"
|
||||||
|
minimatch "^3.1.1"
|
||||||
|
once "^1.3.0"
|
||||||
|
path-is-absolute "^1.0.0"
|
||||||
|
|
||||||
|
glob@^8.0.0:
|
||||||
|
version "8.0.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/glob/-/glob-8.0.3.tgz#415c6eb2deed9e502c68fa44a272e6da6eeca42e"
|
||||||
|
integrity sha512-ull455NHSHI/Y1FqGaaYFaLGkNMMJbavMrEGFXG/PGrg6y7sutWHUHrz6gy6WEBH6akM1M414dWKCNs+IhKdiQ==
|
||||||
|
dependencies:
|
||||||
|
fs.realpath "^1.0.0"
|
||||||
|
inflight "^1.0.4"
|
||||||
|
inherits "2"
|
||||||
|
minimatch "^5.0.1"
|
||||||
|
once "^1.3.0"
|
||||||
|
|
||||||
|
google-protobuf@^3.15.5, google-protobuf@^3.21.2, google-protobuf@^3.6.1:
|
||||||
|
version "3.21.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/google-protobuf/-/google-protobuf-3.21.2.tgz#4580a2bea8bbb291ee579d1fefb14d6fa3070ea4"
|
||||||
|
integrity sha512-3MSOYFO5U9mPGikIYCzK0SaThypfGgS6bHqrUGXG3DPHCrb+txNqeEcns1W0lkGfk0rCyNXm7xB9rMxnCiZOoA==
|
||||||
|
|
||||||
|
graceful-fs@^4.1.9:
|
||||||
|
version "4.2.10"
|
||||||
|
resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.10.tgz#147d3a006da4ca3ce14728c7aefc287c367d7a6c"
|
||||||
|
integrity sha512-9ByhssR2fPVsNZj478qUUbKfmL0+t5BDVyjShtyZZLiK7ZDAArFFfopyOTj0M05wE2tJPisA4iTnnXl2YoPvOA==
|
||||||
|
|
||||||
|
has-flag@^4.0.0:
|
||||||
|
version "4.0.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/has-flag/-/has-flag-4.0.0.tgz#944771fd9c81c81265c4d6941860da06bb59479b"
|
||||||
|
integrity sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==
|
||||||
|
|
||||||
|
inflight@^1.0.4:
|
||||||
|
version "1.0.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/inflight/-/inflight-1.0.6.tgz#49bd6331d7d02d0c09bc910a1075ba8165b56df9"
|
||||||
|
integrity sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==
|
||||||
|
dependencies:
|
||||||
|
once "^1.3.0"
|
||||||
|
wrappy "1"
|
||||||
|
|
||||||
|
inherits@2:
|
||||||
|
version "2.0.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.4.tgz#0fa2c64f932917c3433a0ded55363aae37416b7c"
|
||||||
|
integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==
|
||||||
|
|
||||||
|
js2xmlparser@^4.0.2:
|
||||||
|
version "4.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/js2xmlparser/-/js2xmlparser-4.0.2.tgz#2a1fdf01e90585ef2ae872a01bc169c6a8d5e60a"
|
||||||
|
integrity sha512-6n4D8gLlLf1n5mNLQPRfViYzu9RATblzPEtm1SthMX1Pjao0r9YI9nw7ZIfRxQMERS87mcswrg+r/OYrPRX6jA==
|
||||||
|
dependencies:
|
||||||
|
xmlcreate "^2.0.4"
|
||||||
|
|
||||||
|
jsdoc@^3.6.3:
|
||||||
|
version "3.6.11"
|
||||||
|
resolved "https://registry.yarnpkg.com/jsdoc/-/jsdoc-3.6.11.tgz#8bbb5747e6f579f141a5238cbad4e95e004458ce"
|
||||||
|
integrity sha512-8UCU0TYeIYD9KeLzEcAu2q8N/mx9O3phAGl32nmHlE0LpaJL71mMkP4d+QE5zWfNt50qheHtOZ0qoxVrsX5TUg==
|
||||||
|
dependencies:
|
||||||
|
"@babel/parser" "^7.9.4"
|
||||||
|
"@types/markdown-it" "^12.2.3"
|
||||||
|
bluebird "^3.7.2"
|
||||||
|
catharsis "^0.9.0"
|
||||||
|
escape-string-regexp "^2.0.0"
|
||||||
|
js2xmlparser "^4.0.2"
|
||||||
|
klaw "^3.0.0"
|
||||||
|
markdown-it "^12.3.2"
|
||||||
|
markdown-it-anchor "^8.4.1"
|
||||||
|
marked "^4.0.10"
|
||||||
|
mkdirp "^1.0.4"
|
||||||
|
requizzle "^0.2.3"
|
||||||
|
strip-json-comments "^3.1.0"
|
||||||
|
taffydb "2.6.2"
|
||||||
|
underscore "~1.13.2"
|
||||||
|
|
||||||
|
klaw@^3.0.0:
|
||||||
|
version "3.0.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/klaw/-/klaw-3.0.0.tgz#b11bec9cf2492f06756d6e809ab73a2910259146"
|
||||||
|
integrity sha512-0Fo5oir+O9jnXu5EefYbVK+mHMBeEVEy2cmctR1O1NECcCkPRreJKrS6Qt/j3KC2C148Dfo9i3pCmCMsdqGr0g==
|
||||||
|
dependencies:
|
||||||
|
graceful-fs "^4.1.9"
|
||||||
|
|
||||||
|
levn@~0.3.0:
|
||||||
|
version "0.3.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/levn/-/levn-0.3.0.tgz#3b09924edf9f083c0490fdd4c0bc4421e04764ee"
|
||||||
|
integrity sha512-0OO4y2iOHix2W6ujICbKIaEQXvFQHue65vUG3pb5EUomzPI90z9hsA1VsO/dbIIpC53J8gxM9Q4Oho0jrCM/yA==
|
||||||
|
dependencies:
|
||||||
|
prelude-ls "~1.1.2"
|
||||||
|
type-check "~0.3.2"
|
||||||
|
|
||||||
|
linkify-it@^3.0.1:
|
||||||
|
version "3.0.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/linkify-it/-/linkify-it-3.0.3.tgz#a98baf44ce45a550efb4d49c769d07524cc2fa2e"
|
||||||
|
integrity sha512-ynTsyrFSdE5oZ/O9GEf00kPngmOfVwazR5GKDq6EYfhlpFug3J2zybX56a2PRRpc9P+FuSoGNAwjlbDs9jJBPQ==
|
||||||
|
dependencies:
|
||||||
|
uc.micro "^1.0.1"
|
||||||
|
|
||||||
|
lodash@^4.17.14, lodash@^4.17.15:
|
||||||
|
version "4.17.21"
|
||||||
|
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
|
||||||
|
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
|
||||||
|
|
||||||
|
long@^5.0.0:
|
||||||
|
version "5.2.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/long/-/long-5.2.0.tgz#2696dadf4b4da2ce3f6f6b89186085d94d52fd61"
|
||||||
|
integrity sha512-9RTUNjK60eJbx3uz+TEGF7fUr29ZDxR5QzXcyDpeSfeH28S9ycINflOgOlppit5U+4kNTe83KQnMEerw7GmE8w==
|
||||||
|
|
||||||
|
lru-cache@^6.0.0:
|
||||||
|
version "6.0.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/lru-cache/-/lru-cache-6.0.0.tgz#6d6fe6570ebd96aaf90fcad1dafa3b2566db3a94"
|
||||||
|
integrity sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==
|
||||||
|
dependencies:
|
||||||
|
yallist "^4.0.0"
|
||||||
|
|
||||||
|
markdown-it-anchor@^8.4.1:
|
||||||
|
version "8.6.5"
|
||||||
|
resolved "https://registry.yarnpkg.com/markdown-it-anchor/-/markdown-it-anchor-8.6.5.tgz#30c4bc5bbff327f15ce3c429010ec7ba75e7b5f8"
|
||||||
|
integrity sha512-PI1qEHHkTNWT+X6Ip9w+paonfIQ+QZP9sCeMYi47oqhH+EsW8CrJ8J7CzV19QVOj6il8ATGbK2nTECj22ZHGvQ==
|
||||||
|
|
||||||
|
markdown-it@^12.3.2:
|
||||||
|
version "12.3.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/markdown-it/-/markdown-it-12.3.2.tgz#bf92ac92283fe983fe4de8ff8abfb5ad72cd0c90"
|
||||||
|
integrity sha512-TchMembfxfNVpHkbtriWltGWc+m3xszaRD0CZup7GFFhzIgQqxIfn3eGj1yZpfuflzPvfkt611B2Q/Bsk1YnGg==
|
||||||
|
dependencies:
|
||||||
|
argparse "^2.0.1"
|
||||||
|
entities "~2.1.0"
|
||||||
|
linkify-it "^3.0.1"
|
||||||
|
mdurl "^1.0.1"
|
||||||
|
uc.micro "^1.0.5"
|
||||||
|
|
||||||
|
marked@^4.0.10:
|
||||||
|
version "4.2.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/marked/-/marked-4.2.1.tgz#eaa32594e45b4e58c02e4d118531fd04345de3b4"
|
||||||
|
integrity sha512-VK1/jNtwqDLvPktNpL0Fdg3qoeUZhmRsuiIjPEy/lHwXW4ouLoZfO4XoWd4ClDt+hupV1VLpkZhEovjU0W/kqA==
|
||||||
|
|
||||||
|
mdurl@^1.0.1:
|
||||||
|
version "1.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/mdurl/-/mdurl-1.0.1.tgz#fe85b2ec75a59037f2adfec100fd6c601761152e"
|
||||||
|
integrity sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g==
|
||||||
|
|
||||||
|
minimatch@^3.1.1:
|
||||||
|
version "3.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-3.1.2.tgz#19cd194bfd3e428f049a70817c038d89ab4be35b"
|
||||||
|
integrity sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==
|
||||||
|
dependencies:
|
||||||
|
brace-expansion "^1.1.7"
|
||||||
|
|
||||||
|
minimatch@^5.0.1:
|
||||||
|
version "5.1.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/minimatch/-/minimatch-5.1.0.tgz#1717b464f4971b144f6aabe8f2d0b8e4511e09c7"
|
||||||
|
integrity sha512-9TPBGGak4nHfGZsPBohm9AWg6NoT7QTCehS3BIJABslyZbzxfV78QM2Y6+i741OPZIafFAaiiEMh5OyIrJPgtg==
|
||||||
|
dependencies:
|
||||||
|
brace-expansion "^2.0.1"
|
||||||
|
|
||||||
|
minimist@^1.2.0:
|
||||||
|
version "1.2.7"
|
||||||
|
resolved "https://registry.yarnpkg.com/minimist/-/minimist-1.2.7.tgz#daa1c4d91f507390437c6a8bc01078e7000c4d18"
|
||||||
|
integrity sha512-bzfL1YUZsP41gmu/qjrEk0Q6i2ix/cVeAhbCbqH9u3zYutS1cLg00qhrD0M2MVdCcx4Sc0UpP2eBWo9rotpq6g==
|
||||||
|
|
||||||
|
mkdirp@^1.0.4:
|
||||||
|
version "1.0.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-1.0.4.tgz#3eb5ed62622756d79a5f0e2a221dfebad75c2f7e"
|
||||||
|
integrity sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==
|
||||||
|
|
||||||
|
once@^1.3.0:
|
||||||
|
version "1.4.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/once/-/once-1.4.0.tgz#583b1aa775961d4b113ac17d9c50baef9dd76bd1"
|
||||||
|
integrity sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==
|
||||||
|
dependencies:
|
||||||
|
wrappy "1"
|
||||||
|
|
||||||
|
optionator@^0.8.1:
|
||||||
|
version "0.8.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/optionator/-/optionator-0.8.3.tgz#84fa1d036fe9d3c7e21d99884b601167ec8fb495"
|
||||||
|
integrity sha512-+IW9pACdk3XWmmTXG8m3upGUJst5XRGzxMRjXzAuJ1XnIFNvfhjjIuYkDvysnPQ7qzqVzLt78BCruntqRhWQbA==
|
||||||
|
dependencies:
|
||||||
|
deep-is "~0.1.3"
|
||||||
|
fast-levenshtein "~2.0.6"
|
||||||
|
levn "~0.3.0"
|
||||||
|
prelude-ls "~1.1.2"
|
||||||
|
type-check "~0.3.2"
|
||||||
|
word-wrap "~1.2.3"
|
||||||
|
|
||||||
|
path-is-absolute@^1.0.0:
|
||||||
|
version "1.0.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/path-is-absolute/-/path-is-absolute-1.0.1.tgz#174b9268735534ffbc7ace6bf53a5a9e1b5c5f5f"
|
||||||
|
integrity sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==
|
||||||
|
|
||||||
|
prelude-ls@~1.1.2:
|
||||||
|
version "1.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.1.2.tgz#21932a549f5e52ffd9a827f570e04be62a97da54"
|
||||||
|
integrity sha512-ESF23V4SKG6lVSGZgYNpbsiaAkdab6ZgOxe52p7+Kid3W3u3bxR4Vfd/o21dmN7jSt0IwgZ4v5MUd26FEtXE9w==
|
||||||
|
|
||||||
|
protobufjs-cli@^1.0.2:
|
||||||
|
version "1.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/protobufjs-cli/-/protobufjs-cli-1.0.2.tgz#905fc49007cf4aaf3c45d5f250eb294eedeea062"
|
||||||
|
integrity sha512-cz9Pq9p/Zs7okc6avH20W7QuyjTclwJPgqXG11jNaulfS3nbVisID8rC+prfgq0gbZE0w9LBFd1OKFF03kgFzg==
|
||||||
|
dependencies:
|
||||||
|
chalk "^4.0.0"
|
||||||
|
escodegen "^1.13.0"
|
||||||
|
espree "^9.0.0"
|
||||||
|
estraverse "^5.1.0"
|
||||||
|
glob "^8.0.0"
|
||||||
|
jsdoc "^3.6.3"
|
||||||
|
minimist "^1.2.0"
|
||||||
|
semver "^7.1.2"
|
||||||
|
tmp "^0.2.1"
|
||||||
|
uglify-js "^3.7.7"
|
||||||
|
|
||||||
|
protobufjs@^7.1.2:
|
||||||
|
version "7.1.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/protobufjs/-/protobufjs-7.1.2.tgz#a0cf6aeaf82f5625bffcf5a38b7cd2a7de05890c"
|
||||||
|
integrity sha512-4ZPTPkXCdel3+L81yw3dG6+Kq3umdWKh7Dc7GW/CpNk4SX3hK58iPCWeCyhVTDrbkNeKrYNZ7EojM5WDaEWTLQ==
|
||||||
|
dependencies:
|
||||||
|
"@protobufjs/aspromise" "^1.1.2"
|
||||||
|
"@protobufjs/base64" "^1.1.2"
|
||||||
|
"@protobufjs/codegen" "^2.0.4"
|
||||||
|
"@protobufjs/eventemitter" "^1.1.0"
|
||||||
|
"@protobufjs/fetch" "^1.1.0"
|
||||||
|
"@protobufjs/float" "^1.0.2"
|
||||||
|
"@protobufjs/inquire" "^1.1.0"
|
||||||
|
"@protobufjs/path" "^1.1.2"
|
||||||
|
"@protobufjs/pool" "^1.1.0"
|
||||||
|
"@protobufjs/utf8" "^1.1.0"
|
||||||
|
"@types/node" ">=13.7.0"
|
||||||
|
long "^5.0.0"
|
||||||
|
|
||||||
|
requizzle@^0.2.3:
|
||||||
|
version "0.2.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/requizzle/-/requizzle-0.2.3.tgz#4675c90aacafb2c036bd39ba2daa4a1cb777fded"
|
||||||
|
integrity sha512-YanoyJjykPxGHii0fZP0uUPEXpvqfBDxWV7s6GKAiiOsiqhX6vHNyW3Qzdmqp/iq/ExbhaGbVrjB4ruEVSM4GQ==
|
||||||
|
dependencies:
|
||||||
|
lodash "^4.17.14"
|
||||||
|
|
||||||
|
rimraf@^3.0.0:
|
||||||
|
version "3.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-3.0.2.tgz#f1a5402ba6220ad52cc1282bac1ae3aa49fd061a"
|
||||||
|
integrity sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==
|
||||||
|
dependencies:
|
||||||
|
glob "^7.1.3"
|
||||||
|
|
||||||
|
semver@5.6.0:
|
||||||
|
version "5.6.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/semver/-/semver-5.6.0.tgz#7e74256fbaa49c75aa7c7a205cc22799cac80004"
|
||||||
|
integrity sha512-RS9R6R35NYgQn++fkDWaOmqGoj4Ek9gGs+DPxNUZKuwE183xjJroKvyo1IzVFeXvUrvmALy6FWD5xrdJT25gMg==
|
||||||
|
|
||||||
|
semver@^7.1.2:
|
||||||
|
version "7.3.8"
|
||||||
|
resolved "https://registry.yarnpkg.com/semver/-/semver-7.3.8.tgz#07a78feafb3f7b32347d725e33de7e2a2df67798"
|
||||||
|
integrity sha512-NB1ctGL5rlHrPJtFDVIVzTyQylMLu9N9VICA6HSFJo8MCGVTMW6gfpicwKmmK/dAjTOrqu5l63JJOpDSrAis3A==
|
||||||
|
dependencies:
|
||||||
|
lru-cache "^6.0.0"
|
||||||
|
|
||||||
|
source-map-support@0.5.9:
|
||||||
|
version "0.5.9"
|
||||||
|
resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.9.tgz#41bc953b2534267ea2d605bccfa7bfa3111ced5f"
|
||||||
|
integrity sha512-gR6Rw4MvUlYy83vP0vxoVNzM6t8MUXqNuRsuBmBHQDu1Fh6X015FrLdgoDKcNdkwGubozq0P4N0Q37UyFVr1EA==
|
||||||
|
dependencies:
|
||||||
|
buffer-from "^1.0.0"
|
||||||
|
source-map "^0.6.0"
|
||||||
|
|
||||||
|
source-map@^0.6.0, source-map@~0.6.1:
|
||||||
|
version "0.6.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/source-map/-/source-map-0.6.1.tgz#74722af32e9614e9c287a8d0bbde48b5e2f1a263"
|
||||||
|
integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==
|
||||||
|
|
||||||
|
strip-json-comments@^3.1.0:
|
||||||
|
version "3.1.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.1.1.tgz#31f1281b3832630434831c310c01cccda8cbe006"
|
||||||
|
integrity sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==
|
||||||
|
|
||||||
|
supports-color@^7.1.0:
|
||||||
|
version "7.2.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/supports-color/-/supports-color-7.2.0.tgz#1b7dcdcb32b8138801b3e478ba6a51caa89648da"
|
||||||
|
integrity sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==
|
||||||
|
dependencies:
|
||||||
|
has-flag "^4.0.0"
|
||||||
|
|
||||||
|
taffydb@2.6.2:
|
||||||
|
version "2.6.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/taffydb/-/taffydb-2.6.2.tgz#7cbcb64b5a141b6a2efc2c5d2c67b4e150b2a268"
|
||||||
|
integrity sha512-y3JaeRSplks6NYQuCOj3ZFMO3j60rTwbuKCvZxsAraGYH2epusatvZ0baZYA01WsGqJBq/Dl6vOrMUJqyMj8kA==
|
||||||
|
|
||||||
|
tmp@^0.2.1:
|
||||||
|
version "0.2.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/tmp/-/tmp-0.2.1.tgz#8457fc3037dcf4719c251367a1af6500ee1ccf14"
|
||||||
|
integrity sha512-76SUhtfqR2Ijn+xllcI5P1oyannHNHByD80W1q447gU3mp9G9PSpGdWmjUOHRDPiHYacIk66W7ubDTuPF3BEtQ==
|
||||||
|
dependencies:
|
||||||
|
rimraf "^3.0.0"
|
||||||
|
|
||||||
|
ts-protoc-gen@^0.15.0:
|
||||||
|
version "0.15.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/ts-protoc-gen/-/ts-protoc-gen-0.15.0.tgz#2fec5930b46def7dcc9fa73c060d770b7b076b7b"
|
||||||
|
integrity sha512-TycnzEyrdVDlATJ3bWFTtra3SCiEP0W0vySXReAuEygXCUr1j2uaVyL0DhzjwuUdQoW5oXPwk6oZWeA0955V+g==
|
||||||
|
dependencies:
|
||||||
|
google-protobuf "^3.15.5"
|
||||||
|
|
||||||
|
tslib@^1.8.1:
|
||||||
|
version "1.14.1"
|
||||||
|
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.14.1.tgz#cf2d38bdc34a134bcaf1091c41f6619e2f672d00"
|
||||||
|
integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==
|
||||||
|
|
||||||
|
tsutils@3.21.0:
|
||||||
|
version "3.21.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.21.0.tgz#b48717d394cea6c1e096983eed58e9d61715b623"
|
||||||
|
integrity sha512-mHKK3iUXL+3UF6xL5k0PEhKRUBKPBCv/+RkEOpjRWxxx27KKRBmmA60A9pgOUvMi8GKhRMPEmjBRPzs2W7O1OA==
|
||||||
|
dependencies:
|
||||||
|
tslib "^1.8.1"
|
||||||
|
|
||||||
|
type-check@~0.3.2:
|
||||||
|
version "0.3.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/type-check/-/type-check-0.3.2.tgz#5884cab512cf1d355e3fb784f30804b2b520db72"
|
||||||
|
integrity sha512-ZCmOJdvOWDBYJlzAoFkC+Q0+bUyEOS1ltgp1MGU03fqHG+dbi9tBFU2Rd9QKiDZFAYrhPh2JUf7rZRIuHRKtOg==
|
||||||
|
dependencies:
|
||||||
|
prelude-ls "~1.1.2"
|
||||||
|
|
||||||
|
typescript@^4.8.4:
|
||||||
|
version "4.8.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.8.4.tgz#c464abca159669597be5f96b8943500b238e60e6"
|
||||||
|
integrity sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ==
|
||||||
|
|
||||||
|
uc.micro@^1.0.1, uc.micro@^1.0.5:
|
||||||
|
version "1.0.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/uc.micro/-/uc.micro-1.0.6.tgz#9c411a802a409a91fc6cf74081baba34b24499ac"
|
||||||
|
integrity sha512-8Y75pvTYkLJW2hWQHXxoqRgV7qb9B+9vFEtidML+7koHUFapnVJAZ6cKs+Qjz5Aw3aZWHMC6u0wJE3At+nSGwA==
|
||||||
|
|
||||||
|
uglify-js@^3.7.7:
|
||||||
|
version "3.17.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/uglify-js/-/uglify-js-3.17.4.tgz#61678cf5fa3f5b7eb789bb345df29afb8257c22c"
|
||||||
|
integrity sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==
|
||||||
|
|
||||||
|
underscore@~1.13.2:
|
||||||
|
version "1.13.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/underscore/-/underscore-1.13.6.tgz#04786a1f589dc6c09f761fc5f45b89e935136441"
|
||||||
|
integrity sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==
|
||||||
|
|
||||||
|
word-wrap@~1.2.3:
|
||||||
|
version "1.2.3"
|
||||||
|
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
|
||||||
|
integrity sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==
|
||||||
|
|
||||||
|
wrappy@1:
|
||||||
|
version "1.0.2"
|
||||||
|
resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f"
|
||||||
|
integrity sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==
|
||||||
|
|
||||||
|
xmlcreate@^2.0.4:
|
||||||
|
version "2.0.4"
|
||||||
|
resolved "https://registry.yarnpkg.com/xmlcreate/-/xmlcreate-2.0.4.tgz#0c5ab0f99cdd02a81065fa9cd8f8ae87624889be"
|
||||||
|
integrity sha512-nquOebG4sngPmGPICTS5EnxqhKbCmz5Ox5hsszI2T6U5qdrJizBc+0ilYSEjTSzU0yZcmvppztXe/5Al5fUwdg==
|
||||||
|
|
||||||
|
yallist@^4.0.0:
|
||||||
|
version "4.0.0"
|
||||||
|
resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72"
|
||||||
|
integrity sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==
|