Merge branch 'google:master' into face-stylizer-python-add-tests
This commit is contained in:
commit
a5716c9225
10
WORKSPACE
10
WORKSPACE
|
@ -239,6 +239,16 @@ http_archive(
|
||||||
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
|
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "darts_clone",
|
||||||
|
build_file = "@//third_party:darts_clone.BUILD",
|
||||||
|
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
|
||||||
|
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
|
||||||
|
urls = [
|
||||||
|
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "org_tensorflow_text",
|
name = "org_tensorflow_text",
|
||||||
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
|
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
|
||||||
|
|
|
@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
||||||
} else if (packet_options.has_string_value()) {
|
} else if (packet_options.has_string_value()) {
|
||||||
packet.Set<std::string>();
|
packet.Set<std::string>();
|
||||||
} else if (packet_options.has_uint64_value()) {
|
} else if (packet_options.has_uint64_value()) {
|
||||||
packet.Set<uint64>();
|
packet.Set<uint64_t>();
|
||||||
} else if (packet_options.has_classification_list_value()) {
|
} else if (packet_options.has_classification_list_value()) {
|
||||||
packet.Set<ClassificationList>();
|
packet.Set<ClassificationList>();
|
||||||
} else if (packet_options.has_landmark_list_value()) {
|
} else if (packet_options.has_landmark_list_value()) {
|
||||||
|
@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
||||||
} else if (packet_options.has_string_value()) {
|
} else if (packet_options.has_string_value()) {
|
||||||
packet.Set(MakePacket<std::string>(packet_options.string_value()));
|
packet.Set(MakePacket<std::string>(packet_options.string_value()));
|
||||||
} else if (packet_options.has_uint64_value()) {
|
} else if (packet_options.has_uint64_value()) {
|
||||||
packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
|
packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
|
||||||
} else if (packet_options.has_classification_list_value()) {
|
} else if (packet_options.has_classification_list_value()) {
|
||||||
packet.Set(MakePacket<ClassificationList>(
|
packet.Set(MakePacket<ClassificationList>(
|
||||||
packet_options.classification_list_value()));
|
packet_options.classification_list_value()));
|
||||||
|
|
|
@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use this when ALLOW/DISALLOW input is provided as a side packet.
|
// Use this when ALLOW/DISALLOW input is provided as a side packet.
|
||||||
void RunTimeStep(int64 timestamp, bool stream_payload) {
|
void RunTimeStep(int64_t timestamp, bool stream_payload) {
|
||||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||||
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
|
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
|
||||||
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use this when ALLOW/DISALLOW input is provided as an input stream.
|
// Use this when ALLOW/DISALLOW input is provided as an input stream.
|
||||||
void RunTimeStep(int64 timestamp, const std::string& control_tag,
|
void RunTimeStep(int64_t timestamp, const std::string& control_tag,
|
||||||
bool control) {
|
bool control) {
|
||||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||||
MakePacket<bool>(true).At(Timestamp(timestamp)));
|
MakePacket<bool>(true).At(Timestamp(timestamp)));
|
||||||
|
@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
|
||||||
}
|
}
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
|
||||||
}
|
}
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
|
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
|
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
|
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
|
||||||
)");
|
)");
|
||||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
|
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, true);
|
RunTimeStep(kTimestampValue0, true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, false);
|
RunTimeStep(kTimestampValue1, false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, "ALLOW", false);
|
RunTimeStep(kTimestampValue1, "ALLOW", false);
|
||||||
constexpr int64 kTimestampValue2 = 44;
|
constexpr int64_t kTimestampValue2 = 44;
|
||||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||||
constexpr int64 kTimestampValue3 = 45;
|
constexpr int64_t kTimestampValue3 = 45;
|
||||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
|
||||||
output_stream: "test_output"
|
output_stream: "test_output"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||||
constexpr int64 kTimestampValue2 = 44;
|
constexpr int64_t kTimestampValue2 = 44;
|
||||||
RunTimeStep(kTimestampValue2, "DISALLOW", true);
|
RunTimeStep(kTimestampValue2, "DISALLOW", true);
|
||||||
constexpr int64 kTimestampValue3 = 45;
|
constexpr int64_t kTimestampValue3 = 45;
|
||||||
RunTimeStep(kTimestampValue3, "DISALLOW", false);
|
RunTimeStep(kTimestampValue3, "DISALLOW", false);
|
||||||
|
|
||||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||||
|
@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
|
||||||
output_stream: "STATE_CHANGE:state_changed"
|
output_stream: "STATE_CHANGE:state_changed"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, "ALLOW", false);
|
RunTimeStep(kTimestampValue0, "ALLOW", false);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, "ALLOW", true);
|
RunTimeStep(kTimestampValue1, "ALLOW", true);
|
||||||
constexpr int64 kTimestampValue2 = 44;
|
constexpr int64_t kTimestampValue2 = 44;
|
||||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||||
constexpr int64 kTimestampValue3 = 45;
|
constexpr int64_t kTimestampValue3 = 45;
|
||||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||||
|
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
|
@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
|
||||||
output_stream: "STATE_CHANGE:state_changed"
|
output_stream: "STATE_CHANGE:state_changed"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||||
constexpr int64 kTimestampValue1 = 43;
|
constexpr int64_t kTimestampValue1 = 43;
|
||||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||||
constexpr int64 kTimestampValue2 = 44;
|
constexpr int64_t kTimestampValue2 = 44;
|
||||||
RunTimeStep(kTimestampValue2, "DISALLOW", false);
|
RunTimeStep(kTimestampValue2, "DISALLOW", false);
|
||||||
constexpr int64 kTimestampValue3 = 45;
|
constexpr int64_t kTimestampValue3 = 45;
|
||||||
RunTimeStep(kTimestampValue3, "DISALLOW", true);
|
RunTimeStep(kTimestampValue3, "DISALLOW", true);
|
||||||
|
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
|
@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
|
||||||
output_stream: "STATE_CHANGE:state_changed"
|
output_stream: "STATE_CHANGE:state_changed"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, "DISALLOW", false);
|
RunTimeStep(kTimestampValue0, "DISALLOW", false);
|
||||||
|
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
|
@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
||||||
output_stream: "STATE_CHANGE:state_changed"
|
output_stream: "STATE_CHANGE:state_changed"
|
||||||
)");
|
)");
|
||||||
|
|
||||||
constexpr int64 kTimestampValue0 = 42;
|
constexpr int64_t kTimestampValue0 = 42;
|
||||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||||
|
|
||||||
const std::vector<Packet>& output =
|
const std::vector<Packet>& output =
|
||||||
|
|
|
@ -35,7 +35,7 @@ class MatrixToVectorCalculatorTest
|
||||||
void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }
|
void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }
|
||||||
|
|
||||||
void AppendInput(const std::vector<float>& column_major_data,
|
void AppendInput(const std::vector<float>& column_major_data,
|
||||||
int64 timestamp) {
|
int64_t timestamp) {
|
||||||
ASSERT_EQ(num_input_samples_ * num_input_channels_,
|
ASSERT_EQ(num_input_samples_ * num_input_channels_,
|
||||||
column_major_data.size());
|
column_major_data.size());
|
||||||
Eigen::Map<const Matrix> data_map(&column_major_data[0],
|
Eigen::Map<const Matrix> data_map(&column_major_data[0],
|
||||||
|
|
|
@ -51,9 +51,9 @@ class SimpleRunner : public CalculatorRunner {
|
||||||
|
|
||||||
virtual ~SimpleRunner() {}
|
virtual ~SimpleRunner() {}
|
||||||
|
|
||||||
void SetInput(const std::vector<int64>& timestamp_list) {
|
void SetInput(const std::vector<int64_t>& timestamp_list) {
|
||||||
MutableInputs()->Index(0).packets.clear();
|
MutableInputs()->Index(0).packets.clear();
|
||||||
for (const int64 ts : timestamp_list) {
|
for (const int64_t ts : timestamp_list) {
|
||||||
MutableInputs()->Index(0).packets.push_back(
|
MutableInputs()->Index(0).packets.push_back(
|
||||||
Adopt(new std::string(absl::StrCat("Frame #", ts)))
|
Adopt(new std::string(absl::StrCat("Frame #", ts)))
|
||||||
.At(Timestamp(ts)));
|
.At(Timestamp(ts)));
|
||||||
|
@ -72,8 +72,8 @@ class SimpleRunner : public CalculatorRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckOutputTimestamps(
|
void CheckOutputTimestamps(
|
||||||
const std::vector<int64>& expected_frames,
|
const std::vector<int64_t>& expected_frames,
|
||||||
const std::vector<int64>& expected_timestamps) const {
|
const std::vector<int64_t>& expected_timestamps) const {
|
||||||
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
|
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
|
||||||
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
|
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
|
||||||
int count = 0;
|
int count = 0;
|
||||||
|
@ -112,7 +112,7 @@ MATCHER_P2(PacketAtTimestamp, payload, timestamp,
|
||||||
*result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
|
*result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int64 actual_payload = arg.template Get<int64>();
|
int64_t actual_payload = arg.template Get<int64_t>();
|
||||||
if (actual_payload != payload) {
|
if (actual_payload != payload) {
|
||||||
*result_listener << "with incorrect payload = " << actual_payload;
|
*result_listener << "with incorrect payload = " << actual_payload;
|
||||||
return false;
|
return false;
|
||||||
|
@ -137,18 +137,18 @@ class ReproducibleJitterWithReflectionStrategyForTesting
|
||||||
//
|
//
|
||||||
// An EXPECT will fail if sequence is less than the number requested during
|
// An EXPECT will fail if sequence is less than the number requested during
|
||||||
// processing.
|
// processing.
|
||||||
static std::vector<uint64> random_sequence;
|
static std::vector<uint64_t> random_sequence;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual uint64 GetNextRandom(uint64 n) {
|
virtual uint64_t GetNextRandom(uint64_t n) {
|
||||||
EXPECT_LT(sequence_index_, random_sequence.size());
|
EXPECT_LT(sequence_index_, random_sequence.size());
|
||||||
return random_sequence[sequence_index_++] % n;
|
return random_sequence[sequence_index_++] % n;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32 sequence_index_ = 0;
|
int32_t sequence_index_ = 0;
|
||||||
};
|
};
|
||||||
std::vector<uint64>
|
std::vector<uint64_t>
|
||||||
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
|
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
|
||||||
|
|
||||||
// PacketResamplerCalculator child class which injects a specified stream
|
// PacketResamplerCalculator child class which injects a specified stream
|
||||||
|
@ -469,7 +469,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
||||||
}
|
}
|
||||||
)pb"));
|
)pb"));
|
||||||
|
|
||||||
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
|
for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) {
|
||||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
|
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
|
||||||
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,8 +43,8 @@ constexpr char kDisallowTag[] = "DISALLOW";
|
||||||
|
|
||||||
// Returns the timestamp values for a vector of Packets.
|
// Returns the timestamp values for a vector of Packets.
|
||||||
// TODO: puth this kind of test util in a common place.
|
// TODO: puth this kind of test util in a common place.
|
||||||
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
std::vector<int64_t> TimestampValues(const std::vector<Packet>& packets) {
|
||||||
std::vector<int64> result;
|
std::vector<int64_t> result;
|
||||||
for (const Packet& packet : packets) {
|
for (const Packet& packet : packets) {
|
||||||
result.push_back(packet.Timestamp().Value());
|
result.push_back(packet.Timestamp().Value());
|
||||||
}
|
}
|
||||||
|
@ -371,7 +371,7 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
|
||||||
for (int main_ts = 0; main_ts < 50; ++main_ts) {
|
for (int main_ts = 0; main_ts < 50; ++main_ts) {
|
||||||
send_packet("in", main_ts);
|
send_packet("in", main_ts);
|
||||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||||
std::vector<int64> ts_values = TimestampValues(outputs);
|
std::vector<int64_t> ts_values = TimestampValues(outputs);
|
||||||
EXPECT_EQ(ts_values.size(), main_ts + 1);
|
EXPECT_EQ(ts_values.size(), main_ts + 1);
|
||||||
for (int j = 0; j < main_ts + 1; ++j) {
|
for (int j = 0; j < main_ts + 1; ++j) {
|
||||||
EXPECT_EQ(ts_values[j], j);
|
EXPECT_EQ(ts_values[j], j);
|
||||||
|
|
|
@ -121,7 +121,7 @@ absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->Outputs().HasTag(kTagAtTimestamp)) {
|
if (cc->Outputs().HasTag(kTagAtTimestamp)) {
|
||||||
RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
|
RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
|
||||||
<< "For AT_TIMESTAMP tag, 2 input side packets are required.";
|
<< "For AT_TIMESTAMP tag, 2 input side packets are required.";
|
||||||
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>();
|
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>();
|
||||||
} else {
|
} else {
|
||||||
RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
|
RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
|
||||||
<< "Same number of input side packets and output streams is required.";
|
<< "Same number of input side packets and output streams is required.";
|
||||||
|
@ -178,8 +178,8 @@ absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
|
||||||
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
|
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
|
||||||
}
|
}
|
||||||
} else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
|
} else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
|
||||||
int64 timestamp =
|
int64_t timestamp =
|
||||||
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>();
|
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64_t>();
|
||||||
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
|
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
|
||||||
cc->Outputs()
|
cc->Outputs()
|
||||||
.Get(output_tag_, i)
|
.Get(output_tag_, i)
|
||||||
|
|
|
@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
|
||||||
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
|
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
|
||||||
REGISTER_CALCULATOR(StringToUintCalculator);
|
REGISTER_CALCULATOR(StringToUintCalculator);
|
||||||
|
|
||||||
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
|
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
|
||||||
REGISTER_CALCULATOR(StringToInt32Calculator);
|
REGISTER_CALCULATOR(StringToInt32Calculator);
|
||||||
|
|
||||||
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
|
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
|
||||||
REGISTER_CALCULATOR(StringToUint32Calculator);
|
REGISTER_CALCULATOR(StringToUint32Calculator);
|
||||||
|
|
||||||
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
|
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
|
||||||
REGISTER_CALCULATOR(StringToInt64Calculator);
|
REGISTER_CALCULATOR(StringToInt64Calculator);
|
||||||
|
|
||||||
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
|
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
|
||||||
REGISTER_CALCULATOR(StringToUint64Calculator);
|
REGISTER_CALCULATOR(StringToUint64Calculator);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
|
||||||
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
|
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
|
||||||
frame_ptr->Height(), frame_ptr->WidthStep(),
|
frame_ptr->Height(), frame_ptr->WidthStep(),
|
||||||
const_cast<uint8_t*>(frame_ptr->PixelData()),
|
const_cast<uint8_t*>(frame_ptr->PixelData()),
|
||||||
[](uint8* data){});
|
[](uint8_t* data){});
|
||||||
ASSIGN_OR_RETURN(auto result,
|
ASSIGN_OR_RETURN(auto result,
|
||||||
runner->Run(image_frame, matrix, size, border_mode));
|
runner->Run(image_frame, matrix, size, border_mode));
|
||||||
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
|
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
|
||||||
|
|
|
@ -401,8 +401,8 @@ cc_library_with_tflite(
|
||||||
hdrs = ["inference_calculator.h"],
|
hdrs = ["inference_calculator.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"//mediapipe/util/tflite:tflite_model_loader",
|
"//mediapipe/util/tflite:tflite_model_loader",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":inference_calculator_cc_proto",
|
":inference_calculator_cc_proto",
|
||||||
|
@ -506,7 +506,7 @@ cc_library_with_tflite(
|
||||||
name = "tflite_delegate_ptr",
|
name = "tflite_delegate_ptr",
|
||||||
hdrs = ["tflite_delegate_ptr.h"],
|
hdrs = ["tflite_delegate_ptr.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -517,8 +517,8 @@ cc_library_with_tflite(
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":tflite_delegate_ptr",
|
":tflite_delegate_ptr",
|
||||||
"//mediapipe/util/tflite:tflite_model_loader",
|
"//mediapipe/util/tflite:tflite_model_loader",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":inference_runner",
|
":inference_runner",
|
||||||
|
@ -546,8 +546,8 @@ cc_library(
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
|
|
|
@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
|
||||||
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
||||||
}
|
}
|
||||||
return PacketAdopting<tflite::OpResolver>(
|
return PacketAdopting<tflite::OpResolver>(
|
||||||
std::make_unique<tflite_shims::ops::builtin::
|
std::make_unique<
|
||||||
BuiltinOpResolverWithoutDefaultDelegates>());
|
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
|
||||||
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
||||||
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
||||||
// migration.
|
// migration.
|
||||||
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
|
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
||||||
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||||
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
||||||
"OP_RESOLVER"};
|
"OP_RESOLVER"};
|
||||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#if defined(MEDIAPIPE_ANDROID)
|
#if defined(MEDIAPIPE_ANDROID)
|
||||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||||
#endif // ANDROID
|
#endif // ANDROID
|
||||||
|
|
|
@ -22,9 +22,9 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
#include "tensorflow/lite/c/c_api_types.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
#include "tensorflow/lite/interpreter.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
|
#include "tensorflow/lite/interpreter_builder.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||||
|
@ -33,8 +33,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using Interpreter = ::tflite_shims::Interpreter;
|
using Interpreter = ::tflite::Interpreter;
|
||||||
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
|
using InterpreterBuilder = ::tflite::InterpreterBuilder;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||||
|
|
|
@ -23,8 +23,8 @@
|
||||||
#include "mediapipe/calculators/tensor/tflite_delegate_ptr.h"
|
#include "mediapipe/calculators/tensor/tflite_delegate_ptr.h"
|
||||||
#include "mediapipe/framework/api2/packet.h"
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_types.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
#include "tensorflow/lite/c/c_api_types.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
|
|
@ -61,12 +61,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
|
||||||
// overload GPU/TPU/...
|
// overload GPU/TPU/...
|
||||||
class SimpleSemaphore {
|
class SimpleSemaphore {
|
||||||
public:
|
public:
|
||||||
explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {}
|
explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {}
|
||||||
SimpleSemaphore(const SimpleSemaphore&) = delete;
|
SimpleSemaphore(const SimpleSemaphore&) = delete;
|
||||||
SimpleSemaphore(SimpleSemaphore&&) = delete;
|
SimpleSemaphore(SimpleSemaphore&&) = delete;
|
||||||
|
|
||||||
// Acquires the semaphore by certain amount.
|
// Acquires the semaphore by certain amount.
|
||||||
void Acquire(uint32 amount) {
|
void Acquire(uint32_t amount) {
|
||||||
mutex_.Lock();
|
mutex_.Lock();
|
||||||
while (count_ < amount) {
|
while (count_ < amount) {
|
||||||
cond_.Wait(&mutex_);
|
cond_.Wait(&mutex_);
|
||||||
|
@ -76,7 +76,7 @@ class SimpleSemaphore {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Releases the semaphore by certain amount.
|
// Releases the semaphore by certain amount.
|
||||||
void Release(uint32 amount) {
|
void Release(uint32_t amount) {
|
||||||
mutex_.Lock();
|
mutex_.Lock();
|
||||||
count_ += amount;
|
count_ += amount;
|
||||||
cond_.SignalAll();
|
cond_.SignalAll();
|
||||||
|
@ -84,7 +84,7 @@ class SimpleSemaphore {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint32 count_;
|
uint32_t count_;
|
||||||
absl::Mutex mutex_;
|
absl::Mutex mutex_;
|
||||||
absl::CondVar cond_;
|
absl::CondVar cond_;
|
||||||
};
|
};
|
||||||
|
@ -488,7 +488,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
// necessary.
|
// necessary.
|
||||||
absl::Status OutputBatch(CalculatorContext* cc,
|
absl::Status OutputBatch(CalculatorContext* cc,
|
||||||
std::unique_ptr<InferenceState> inference_state) {
|
std::unique_ptr<InferenceState> inference_state) {
|
||||||
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
||||||
|
|
||||||
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
||||||
|
@ -544,7 +544,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
get_session_run_throttle(options_.max_concurrent_session_runs());
|
get_session_run_throttle(options_.max_concurrent_session_runs());
|
||||||
session_run_throttle->Acquire(1);
|
session_run_throttle->Acquire(1);
|
||||||
}
|
}
|
||||||
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
tf::Status tf_status;
|
tf::Status tf_status;
|
||||||
{
|
{
|
||||||
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
||||||
|
@ -562,7 +562,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
// informative error message.
|
// informative error message.
|
||||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||||
|
|
||||||
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
||||||
->IncrementBy(run_end_time - run_start_time);
|
->IncrementBy(run_end_time - run_start_time);
|
||||||
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
||||||
|
@ -611,7 +611,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get end time and report.
|
// Get end time and report.
|
||||||
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
cc->GetCounter(kTotalUsecsCounterSuffix)
|
cc->GetCounter(kTotalUsecsCounterSuffix)
|
||||||
->IncrementBy(end_time - start_time);
|
->IncrementBy(end_time - start_time);
|
||||||
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
||||||
|
@ -650,7 +650,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
|
|
||||||
// The static singleton semaphore to throttle concurrent session runs.
|
// The static singleton semaphore to throttle concurrent session runs.
|
||||||
static SimpleSemaphore* get_session_run_throttle(
|
static SimpleSemaphore* get_session_run_throttle(
|
||||||
int32 max_concurrent_session_runs) {
|
int32_t max_concurrent_session_runs) {
|
||||||
static SimpleSemaphore* session_run_throttle =
|
static SimpleSemaphore* session_run_throttle =
|
||||||
new SimpleSemaphore(max_concurrent_session_runs);
|
new SimpleSemaphore(max_concurrent_session_runs);
|
||||||
return session_run_throttle;
|
return session_run_throttle;
|
||||||
|
|
|
@ -197,15 +197,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// timestamp and the associated feature. This information is used in process
|
// timestamp and the associated feature. This information is used in process
|
||||||
// to output batches of packets in order.
|
// to output batches of packets in order.
|
||||||
timestamps_.clear();
|
timestamps_.clear();
|
||||||
int64 last_timestamp_seen = Timestamp::PreStream().Value();
|
int64_t last_timestamp_seen = Timestamp::PreStream().Value();
|
||||||
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
|
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
|
||||||
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
|
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
|
||||||
if (absl::StrContains(map_kv.first, "/timestamp")) {
|
if (absl::StrContains(map_kv.first, "/timestamp")) {
|
||||||
LOG(INFO) << "Found feature timestamps: " << map_kv.first
|
LOG(INFO) << "Found feature timestamps: " << map_kv.first
|
||||||
<< " with size: " << map_kv.second.feature_size();
|
<< " with size: " << map_kv.second.feature_size();
|
||||||
int64 recent_timestamp = Timestamp::PreStream().Value();
|
int64_t recent_timestamp = Timestamp::PreStream().Value();
|
||||||
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
|
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
|
||||||
int64 next_timestamp =
|
int64_t next_timestamp =
|
||||||
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
|
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
|
||||||
RET_CHECK_GT(next_timestamp, recent_timestamp)
|
RET_CHECK_GT(next_timestamp, recent_timestamp)
|
||||||
<< "Timestamps must be sequential. If you're seeing this message "
|
<< "Timestamps must be sequential. If you're seeing this message "
|
||||||
|
@ -361,8 +361,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// any particular call to Process(). At the every end, we output the
|
// any particular call to Process(). At the every end, we output the
|
||||||
// poststream packets. If we only have poststream packets,
|
// poststream packets. If we only have poststream packets,
|
||||||
// last_timestamp_key_ will be empty.
|
// last_timestamp_key_ will be empty.
|
||||||
int64 start_timestamp = 0;
|
int64_t start_timestamp = 0;
|
||||||
int64 end_timestamp = 0;
|
int64_t end_timestamp = 0;
|
||||||
if (last_timestamp_key_.empty() || process_poststream_) {
|
if (last_timestamp_key_.empty() || process_poststream_) {
|
||||||
process_poststream_ = true;
|
process_poststream_ = true;
|
||||||
start_timestamp = Timestamp::PostStream().Value();
|
start_timestamp = Timestamp::PostStream().Value();
|
||||||
|
@ -481,14 +481,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// Store a map from the keys for each stream to the timestamps for each
|
// Store a map from the keys for each stream to the timestamps for each
|
||||||
// key. This allows us to identify which packets to output for each stream
|
// key. This allows us to identify which packets to output for each stream
|
||||||
// for timestamps within a given time window.
|
// for timestamps within a given time window.
|
||||||
std::map<std::string, std::vector<int64>> timestamps_;
|
std::map<std::string, std::vector<int64_t>> timestamps_;
|
||||||
// Store the stream with the latest timestamp in the SequenceExample.
|
// Store the stream with the latest timestamp in the SequenceExample.
|
||||||
std::string last_timestamp_key_;
|
std::string last_timestamp_key_;
|
||||||
// Store the index of the current timestamp. Will be less than
|
// Store the index of the current timestamp. Will be less than
|
||||||
// timestamps_[last_timestamp_key_].size().
|
// timestamps_[last_timestamp_key_].size().
|
||||||
int current_timestamp_index_;
|
int current_timestamp_index_;
|
||||||
// Store the very first timestamp, so we output everything on the first frame.
|
// Store the very first timestamp, so we output everything on the first frame.
|
||||||
int64 first_timestamp_seen_;
|
int64_t first_timestamp_seen_;
|
||||||
// List of keypoint names.
|
// List of keypoint names.
|
||||||
std::vector<std::string> keypoint_names_;
|
std::vector<std::string> keypoint_names_;
|
||||||
// Default keypoint location when missing.
|
// Default keypoint location when missing.
|
||||||
|
|
|
@ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 time = 1234;
|
const int64_t time = 1234;
|
||||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
Adopt(input.release()).At(Timestamp(time)));
|
Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) {
|
||||||
// 2^i can be represented exactly in floating point numbers if 'i' is small.
|
// 2^i can be represented exactly in floating point numbers if 'i' is small.
|
||||||
input->at(i) = static_cast<float>(1 << i);
|
input->at(i) = static_cast<float>(1 << i);
|
||||||
}
|
}
|
||||||
const int64 time = 1234;
|
const int64_t time = 1234;
|
||||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
Adopt(input.release()).At(Timestamp(time)));
|
Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
|
|
@ -28,11 +28,8 @@
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
|
||||||
using mediapipe::Adopt;
|
|
||||||
using mediapipe::CalculatorBase;
|
|
||||||
using mediapipe::ImageFrame;
|
using mediapipe::ImageFrame;
|
||||||
using mediapipe::PacketTypeSet;
|
using mediapipe::PacketTypeSet;
|
||||||
using mediapipe::autoflip::Border;
|
|
||||||
|
|
||||||
constexpr char kDetectedBorders[] = "DETECTED_BORDERS";
|
constexpr char kDetectedBorders[] = "DETECTED_BORDERS";
|
||||||
constexpr int kMinBorderDistance = 5;
|
constexpr int kMinBorderDistance = 5;
|
||||||
|
|
|
@ -28,16 +28,12 @@
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
using mediapipe::Adopt;
|
|
||||||
using mediapipe::CalculatorGraphConfig;
|
using mediapipe::CalculatorGraphConfig;
|
||||||
using mediapipe::CalculatorRunner;
|
using mediapipe::CalculatorRunner;
|
||||||
using mediapipe::ImageFormat;
|
using mediapipe::ImageFormat;
|
||||||
using mediapipe::ImageFrame;
|
using mediapipe::ImageFrame;
|
||||||
using mediapipe::Packet;
|
using mediapipe::Packet;
|
||||||
using mediapipe::PacketTypeSet;
|
using mediapipe::PacketTypeSet;
|
||||||
using mediapipe::ParseTextProtoOrDie;
|
|
||||||
using mediapipe::Timestamp;
|
|
||||||
using mediapipe::autoflip::Border;
|
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace autoflip {
|
namespace autoflip {
|
||||||
|
|
|
@ -31,14 +31,11 @@
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
using mediapipe::Adopt;
|
|
||||||
using mediapipe::CalculatorGraphConfig;
|
using mediapipe::CalculatorGraphConfig;
|
||||||
using mediapipe::CalculatorRunner;
|
using mediapipe::CalculatorRunner;
|
||||||
using mediapipe::ImageFormat;
|
using mediapipe::ImageFormat;
|
||||||
using mediapipe::ImageFrame;
|
using mediapipe::ImageFrame;
|
||||||
using mediapipe::PacketTypeSet;
|
using mediapipe::PacketTypeSet;
|
||||||
using mediapipe::ParseTextProtoOrDie;
|
|
||||||
using mediapipe::Timestamp;
|
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace autoflip {
|
namespace autoflip {
|
||||||
|
|
|
@ -28,8 +28,6 @@
|
||||||
using mediapipe::Packet;
|
using mediapipe::Packet;
|
||||||
using mediapipe::PacketTypeSet;
|
using mediapipe::PacketTypeSet;
|
||||||
using mediapipe::autoflip::DetectionSet;
|
using mediapipe::autoflip::DetectionSet;
|
||||||
using mediapipe::autoflip::SalientRegion;
|
|
||||||
using mediapipe::autoflip::SignalType;
|
|
||||||
|
|
||||||
constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY";
|
constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY";
|
||||||
constexpr char kSignalInputsTag[] = "SIGNAL";
|
constexpr char kSignalInputsTag[] = "SIGNAL";
|
||||||
|
|
|
@ -19,8 +19,6 @@ namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
namespace test {
|
namespace test {
|
||||||
|
|
||||||
using testing::ElementsAre;
|
|
||||||
|
|
||||||
// Returns the packet values for a vector of Packets.
|
// Returns the packet values for a vector of Packets.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) {
|
std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) {
|
||||||
|
|
|
@ -310,7 +310,7 @@ class Scheduler {
|
||||||
absl::Mutex state_mutex_;
|
absl::Mutex state_mutex_;
|
||||||
|
|
||||||
// Current state of the scheduler.
|
// Current state of the scheduler.
|
||||||
std::atomic<State> state_ = ATOMIC_VAR_INIT(STATE_NOT_STARTED);
|
std::atomic<State> state_ = STATE_NOT_STARTED;
|
||||||
|
|
||||||
// True if all graph input streams are closed.
|
// True if all graph input streams are closed.
|
||||||
bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false;
|
bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false;
|
||||||
|
|
|
@ -131,7 +131,7 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
|
||||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
|
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
|
||||||
// Record the most recent first kept timestamp on any stream.
|
// Record the most recent first kept timestamp on any stream.
|
||||||
for (const auto& stream : input_stream_managers_) {
|
for (const auto& stream : input_stream_managers_) {
|
||||||
int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
|
int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
|
||||||
? target_queue_size_
|
? target_queue_size_
|
||||||
: trigger_queue_size_ - 1;
|
: trigger_queue_size_ - 1;
|
||||||
if (stream->QueueSize() > queue_size) {
|
if (stream->QueueSize() > queue_size) {
|
||||||
|
@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32 trigger_queue_size_;
|
int32_t trigger_queue_size_;
|
||||||
int32 target_queue_size_;
|
int32_t target_queue_size_;
|
||||||
bool fixed_min_size_;
|
bool fixed_min_size_;
|
||||||
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and
|
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and
|
||||||
// the corresponding call to FillInputSet has not yet completed.
|
// the corresponding call to FillInputSet has not yet completed.
|
||||||
|
|
|
@ -30,15 +30,15 @@ namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
const int64 kMaxPacketId = 100;
|
const int64_t kMaxPacketId = 100;
|
||||||
const int64 kSlowCalculatorRate = 10;
|
const int64_t kSlowCalculatorRate = 10;
|
||||||
|
|
||||||
// Rate limiter for TestSlowCalculator.
|
// Rate limiter for TestSlowCalculator.
|
||||||
ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit);
|
ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit);
|
||||||
int64 g_source_counter ABSL_GUARDED_BY(g_source_mutex);
|
int64_t g_source_counter ABSL_GUARDED_BY(g_source_mutex);
|
||||||
|
|
||||||
// Rate limiter for TestSourceCalculator.
|
// Rate limiter for TestSourceCalculator.
|
||||||
int64 g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
|
int64_t g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
|
||||||
|
|
||||||
// Flag that indicates that the source is done.
|
// Flag that indicates that the source is done.
|
||||||
bool g_source_done ABSL_GUARDED_BY(g_source_mutex);
|
bool g_source_done ABSL_GUARDED_BY(g_source_mutex);
|
||||||
|
@ -47,7 +47,7 @@ class TestSourceCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
TestSourceCalculator() : current_packet_id_(0) {}
|
TestSourceCalculator() : current_packet_id_(0) {}
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Outputs().Index(0).Set<int64>();
|
cc->Outputs().Index(0).Set<int64_t>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
@ -62,7 +62,7 @@ class TestSourceCalculator : public CalculatorBase {
|
||||||
g_source_done = true;
|
g_source_done = true;
|
||||||
return tool::StatusStop();
|
return tool::StatusStop();
|
||||||
}
|
}
|
||||||
cc->Outputs().Index(0).Add(new int64(0), Timestamp(current_packet_id_));
|
cc->Outputs().Index(0).Add(new int64_t(0), Timestamp(current_packet_id_));
|
||||||
++current_packet_id_;
|
++current_packet_id_;
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&g_source_mutex);
|
absl::MutexLock lock(&g_source_mutex);
|
||||||
|
@ -78,7 +78,7 @@ class TestSourceCalculator : public CalculatorBase {
|
||||||
return g_source_counter <= kSlowCalculatorRate * g_slow_counter ||
|
return g_source_counter <= kSlowCalculatorRate * g_slow_counter ||
|
||||||
g_source_counter <= 1;
|
g_source_counter <= 1;
|
||||||
}
|
}
|
||||||
int64 current_packet_id_;
|
int64_t current_packet_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_CALCULATOR(TestSourceCalculator);
|
REGISTER_CALCULATOR(TestSourceCalculator);
|
||||||
|
@ -87,8 +87,8 @@ class TestSlowCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
TestSlowCalculator() = default;
|
TestSlowCalculator() = default;
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->Inputs().Index(0).Set<int64>();
|
cc->Inputs().Index(0).Set<int64_t>();
|
||||||
cc->Outputs().Index(0).Set<int64>();
|
cc->Outputs().Index(0).Set<int64_t>();
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
@ -97,7 +97,7 @@ class TestSlowCalculator : public CalculatorBase {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
cc->Outputs().Index(0).Add(new int64(0),
|
cc->Outputs().Index(0).Add(new int64_t(0),
|
||||||
cc->Inputs().Index(0).Value().Timestamp());
|
cc->Inputs().Index(0).Value().Timestamp());
|
||||||
{
|
{
|
||||||
absl::MutexLock lock(&g_source_mutex);
|
absl::MutexLock lock(&g_source_mutex);
|
||||||
|
@ -118,8 +118,9 @@ class TestSlowCalculator : public CalculatorBase {
|
||||||
REGISTER_CALCULATOR(TestSlowCalculator);
|
REGISTER_CALCULATOR(TestSlowCalculator);
|
||||||
|
|
||||||
// Return the values of the timestamps of a vector of Packets.
|
// Return the values of the timestamps of a vector of Packets.
|
||||||
static std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
static std::vector<int64_t> TimestampValues(
|
||||||
std::vector<int64> result;
|
const std::vector<Packet>& packets) {
|
||||||
|
std::vector<int64_t> result;
|
||||||
for (const Packet& p : packets) {
|
for (const Packet& p : packets) {
|
||||||
result.push_back(p.Timestamp().Value());
|
result.push_back(p.Timestamp().Value());
|
||||||
}
|
}
|
||||||
|
@ -174,7 +175,7 @@ TEST_P(FixedSizeInputStreamHandlerTest, DropsPackets) {
|
||||||
// consumed. In this way, the TestSlowCalculator consumes and outputs only
|
// consumed. In this way, the TestSlowCalculator consumes and outputs only
|
||||||
// every tenth packet.
|
// every tenth packet.
|
||||||
EXPECT_EQ(output_packets.size(), 11);
|
EXPECT_EQ(output_packets.size(), 11);
|
||||||
std::vector<int64> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
|
std::vector<int64_t> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
|
||||||
EXPECT_THAT(TimestampValues(output_packets),
|
EXPECT_THAT(TimestampValues(output_packets),
|
||||||
testing::ContainerEq(expected_ts));
|
testing::ContainerEq(expected_ts));
|
||||||
}
|
}
|
||||||
|
@ -344,18 +345,18 @@ TEST_P(FixedSizeInputStreamHandlerTest, LateArrivalDrop) {
|
||||||
|
|
||||||
if (GetParam()) {
|
if (GetParam()) {
|
||||||
EXPECT_THAT(TimestampValues(output_packets[0]),
|
EXPECT_THAT(TimestampValues(output_packets[0]),
|
||||||
testing::ContainerEq(std::vector<int64>{1, 2, 3, 4, 5, 6}));
|
testing::ContainerEq(std::vector<int64_t>{1, 2, 3, 4, 5, 6}));
|
||||||
EXPECT_THAT(TimestampValues(output_packets[1]),
|
EXPECT_THAT(TimestampValues(output_packets[1]),
|
||||||
testing::ContainerEq(std::vector<int64>{3, 4, 5, 6, 7}));
|
testing::ContainerEq(std::vector<int64_t>{3, 4, 5, 6, 7}));
|
||||||
EXPECT_THAT(TimestampValues(output_packets[2]),
|
EXPECT_THAT(TimestampValues(output_packets[2]),
|
||||||
testing::ContainerEq(std::vector<int64>{4, 5, 6, 7}));
|
testing::ContainerEq(std::vector<int64_t>{4, 5, 6, 7}));
|
||||||
} else {
|
} else {
|
||||||
EXPECT_THAT(TimestampValues(output_packets[0]),
|
EXPECT_THAT(TimestampValues(output_packets[0]),
|
||||||
testing::ContainerEq(std::vector<int64>{5, 6}));
|
testing::ContainerEq(std::vector<int64_t>{5, 6}));
|
||||||
EXPECT_THAT(TimestampValues(output_packets[1]),
|
EXPECT_THAT(TimestampValues(output_packets[1]),
|
||||||
testing::ContainerEq(std::vector<int64>{5, 6, 7}));
|
testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
|
||||||
EXPECT_THAT(TimestampValues(output_packets[2]),
|
EXPECT_THAT(TimestampValues(output_packets[2]),
|
||||||
testing::ContainerEq(std::vector<int64>{5, 6, 7}));
|
testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,10 +27,6 @@ namespace options_field_util {
|
||||||
|
|
||||||
using ::mediapipe::proto_ns::internal::WireFormatLite;
|
using ::mediapipe::proto_ns::internal::WireFormatLite;
|
||||||
using FieldType = WireFormatLite::FieldType;
|
using FieldType = WireFormatLite::FieldType;
|
||||||
using ::mediapipe::proto_ns::io::ArrayInputStream;
|
|
||||||
using ::mediapipe::proto_ns::io::CodedInputStream;
|
|
||||||
using ::mediapipe::proto_ns::io::CodedOutputStream;
|
|
||||||
using ::mediapipe::proto_ns::io::StringOutputStream;
|
|
||||||
|
|
||||||
// Utility functions for OptionsFieldUtil.
|
// Utility functions for OptionsFieldUtil.
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -454,8 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
||||||
// Number of glFinish calls completed on the GL thread.
|
// Number of glFinish calls completed on the GL thread.
|
||||||
// Changes should be guarded by mutex_. However, we use simple atomic
|
// Changes should be guarded by mutex_. However, we use simple atomic
|
||||||
// loads for efficiency on the fast path.
|
// loads for efficiency on the fast path.
|
||||||
std::atomic<int64_t> gl_finish_count_ = ATOMIC_VAR_INIT(0);
|
std::atomic<int64_t> gl_finish_count_ = 0;
|
||||||
std::atomic<int64_t> gl_finish_count_target_ = ATOMIC_VAR_INIT(0);
|
std::atomic<int64_t> gl_finish_count_target_ = 0;
|
||||||
|
|
||||||
GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr;
|
GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr;
|
||||||
|
|
||||||
|
|
|
@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
|
||||||
// TODO: Investigate this option in more detail, esp. on Safari.
|
// TODO: Investigate this option in more detail, esp. on Safari.
|
||||||
attrs.preserveDrawingBuffer = 0;
|
attrs.preserveDrawingBuffer = 0;
|
||||||
|
|
||||||
// Since the Emscripten canvas target finding function is visible from here,
|
// Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
|
||||||
// we hijack findCanvasEventTarget directly for enforcing old Module.canvas
|
// looks for our #canvas target in Module.canvas, where we expect it to be.
|
||||||
// behavior if the user desires, falling back to the new DOM element CSS
|
// -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
|
||||||
// selector behavior next if that is specified, and finally just allowing the
|
// event target behavior, but it was never supposed to be tapping into our
|
||||||
// lookup to proceed on a null target.
|
// canvas anyways. See b/278155946 for more background.
|
||||||
// TODO: Ensure this works with all options (in particular,
|
EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
|
||||||
// multithreading options, like the special-case combination of USE_PTHREADS
|
|
||||||
// and OFFSCREEN_FRAMEBUFFER)
|
|
||||||
// clang-format off
|
|
||||||
EM_ASM(
|
|
||||||
let init_once = true;
|
|
||||||
if (init_once) {
|
|
||||||
const cachedFindCanvasEventTarget = findCanvasEventTarget;
|
|
||||||
|
|
||||||
if (typeof cachedFindCanvasEventTarget !== 'function') {
|
|
||||||
if (typeof console !== 'undefined') {
|
|
||||||
console.error('Expected Emscripten global function '
|
|
||||||
+ '"findCanvasEventTarget" not found. WebGL context creation '
|
|
||||||
+ 'may fail.');
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
findCanvasEventTarget = function(target) {
|
|
||||||
if (target == 0) {
|
|
||||||
if (Module && Module.canvas) {
|
|
||||||
return Module.canvas;
|
|
||||||
} else if (Module && Module.canvasCssSelector) {
|
|
||||||
return cachedFindCanvasEventTarget(Module.canvasCssSelector);
|
|
||||||
}
|
|
||||||
if (typeof console !== 'undefined') {
|
|
||||||
console.warn('Module properties canvas and canvasCssSelector not ' +
|
|
||||||
'found during WebGL context creation.');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We still go through with the find attempt, although for most use
|
|
||||||
// cases it will not succeed, just in case the user does want to fall-
|
|
||||||
// back.
|
|
||||||
return cachedFindCanvasEventTarget(target);
|
|
||||||
}; // NOLINT: Necessary semicolon.
|
|
||||||
init_once = false;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
|
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
|
||||||
emscripten_webgl_create_context(nullptr, &attrs);
|
emscripten_webgl_create_context("#canvas", &attrs);
|
||||||
|
|
||||||
// Check for failure
|
// Check for failure
|
||||||
if (context_handle <= 0) {
|
if (context_handle <= 0) {
|
||||||
|
|
|
@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
|
||||||
int actual_ws = image_frame.WidthStep();
|
int actual_ws = image_frame.WidthStep();
|
||||||
int alignment = 0;
|
int alignment = 0;
|
||||||
std::unique_ptr<ImageFrame> temp;
|
std::unique_ptr<ImageFrame> temp;
|
||||||
const uint8* data = image_frame.PixelData();
|
const uint8_t* data = image_frame.PixelData();
|
||||||
|
|
||||||
// Let's see if the pixel data is tightly aligned to one of the alignments
|
// Let's see if the pixel data is tightly aligned to one of the alignments
|
||||||
// supported by OpenGL, preferring 4 if possible since it's the default.
|
// supported by OpenGL, preferring 4 if possible since it's the default.
|
||||||
|
|
|
@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
||||||
GpuBufferFormat format) {
|
GpuBufferFormat format) {
|
||||||
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
|
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
|
||||||
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
|
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
|
||||||
auto y_data = std::make_unique<uint8[]>(y_stride * height);
|
auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
|
||||||
switch (fourcc) {
|
switch (fourcc) {
|
||||||
case libyuv::FOURCC_NV12:
|
case libyuv::FOURCC_NV12:
|
||||||
case libyuv::FOURCC_NV21: {
|
case libyuv::FOURCC_NV21: {
|
||||||
|
@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
||||||
int uv_width = 2 * std::ceil(0.5f * width);
|
int uv_width = 2 * std::ceil(0.5f * width);
|
||||||
int uv_height = std::ceil(0.5f * height);
|
int uv_height = std::ceil(0.5f * height);
|
||||||
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
||||||
auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||||
yuv_image_ = std::make_shared<YUVImage>(
|
yuv_image_ = std::make_shared<YUVImage>(
|
||||||
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
|
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
|
||||||
nullptr, 0, width, height);
|
nullptr, 0, width, height);
|
||||||
|
@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
||||||
int uv_width = std::ceil(0.5f * width);
|
int uv_width = std::ceil(0.5f * width);
|
||||||
int uv_height = std::ceil(0.5f * height);
|
int uv_height = std::ceil(0.5f * height);
|
||||||
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
||||||
auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||||
auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||||
yuv_image_ = std::make_shared<YUVImage>(
|
yuv_image_ = std::make_shared<YUVImage>(
|
||||||
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
|
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
|
||||||
std::move(v_data), uv_stride, width, height);
|
std::move(v_data), uv_stride, width, height);
|
||||||
|
|
|
@ -16,6 +16,7 @@ import csv
|
||||||
import filecmp
|
import filecmp
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip('b/275624089')
|
||||||
class TextClassifierTest(tf.test.TestCase):
|
class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||||
|
|
|
@ -175,11 +175,7 @@ py_test(
|
||||||
data = [":testdata"],
|
data = [":testdata"],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":object_detector_import",
|
||||||
":hyperparameters",
|
|
||||||
":model_spec",
|
|
||||||
":object_detector",
|
|
||||||
":object_detector_options",
|
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
from mediapipe.model_maker.python.vision import object_detector
|
||||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
|
||||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
|
||||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||||
cache_dir = self.create_tempdir()
|
cache_dir = self.create_tempdir()
|
||||||
self.data = dataset.Dataset.from_coco_folder(
|
self.data = object_detector.Dataset.from_coco_folder(
|
||||||
dataset_folder, cache_dir=cache_dir
|
dataset_folder, cache_dir=cache_dir
|
||||||
)
|
)
|
||||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||||
|
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.addCleanup(mock_gettempdir.stop)
|
self.addCleanup(mock_gettempdir.stop)
|
||||||
|
|
||||||
def test_object_detector(self):
|
def test_object_detector(self):
|
||||||
hparams = hyperparameters.HParams(
|
hparams = object_detector.HParams(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
learning_rate=0.9,
|
learning_rate=0.9,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
export_dir=self.create_tempdir(),
|
export_dir=self.create_tempdir(),
|
||||||
)
|
)
|
||||||
options = object_detector_options.ObjectDetectorOptions(
|
options = object_detector.ObjectDetectorOptions(
|
||||||
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
supported_model=object_detector.SupportedModels.MOBILENET_V2,
|
||||||
|
hparams=hparams,
|
||||||
)
|
)
|
||||||
# Test `create``
|
# Test `create``
|
||||||
model = object_detector.ObjectDetector.create(
|
model = object_detector.ObjectDetector.create(
|
||||||
|
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||||
|
|
||||||
# Test `quantization_aware_training`
|
# Test `quantization_aware_training`
|
||||||
qat_hparams = hyperparameters.QATHParams(
|
qat_hparams = object_detector.QATHParams(
|
||||||
learning_rate=0.9,
|
learning_rate=0.9,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
epochs=1,
|
epochs=1,
|
||||||
|
|
|
@ -298,6 +298,7 @@ cc_library(
|
||||||
":tensors_to_objects_calculator_cc_proto",
|
":tensors_to_objects_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:opencv_core",
|
"//mediapipe/framework/port:opencv_core",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
|
|
@ -91,8 +91,8 @@ absl::Status FrameAnnotationToTimedBoxListCalculator::Process(
|
||||||
TimedBoxProto* added_box = output_objects->add_box();
|
TimedBoxProto* added_box = output_objects->add_box();
|
||||||
ComputeBoundingRect(key_points, added_box);
|
ComputeBoundingRect(key_points, added_box);
|
||||||
added_box->set_id(annotation.object_id());
|
added_box->set_id(annotation.object_id());
|
||||||
const int64 time_msec =
|
const int64_t time_msec =
|
||||||
static_cast<int64>(std::round(frame_annotation.timestamp() / 1000));
|
static_cast<int64_t>(std::round(frame_annotation.timestamp() / 1000));
|
||||||
added_box->set_time_msec(time_msec);
|
added_box->set_time_msec(time_msec);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,8 +24,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
void FrameAnnotationTracker::AddDetectionResult(
|
void FrameAnnotationTracker::AddDetectionResult(
|
||||||
const FrameAnnotation& frame_annotation) {
|
const FrameAnnotation& frame_annotation) {
|
||||||
const int64 time_us =
|
const int64_t time_us =
|
||||||
static_cast<int64>(std::round(frame_annotation.timestamp()));
|
static_cast<int64_t>(std::round(frame_annotation.timestamp()));
|
||||||
for (const auto& object_annotation : frame_annotation.annotations()) {
|
for (const auto& object_annotation : frame_annotation.annotations()) {
|
||||||
detected_objects_[time_us + object_annotation.object_id()] =
|
detected_objects_[time_us + object_annotation.object_id()] =
|
||||||
object_annotation;
|
object_annotation;
|
||||||
|
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
|
||||||
absl::flat_hash_set<int>* cancel_object_ids) {
|
absl::flat_hash_set<int>* cancel_object_ids) {
|
||||||
CHECK(cancel_object_ids != nullptr);
|
CHECK(cancel_object_ids != nullptr);
|
||||||
FrameAnnotation frame_annotation;
|
FrameAnnotation frame_annotation;
|
||||||
std::vector<int64> keys_to_be_deleted;
|
std::vector<int64_t> keys_to_be_deleted;
|
||||||
for (const auto& detected_obj : detected_objects_) {
|
for (const auto& detected_obj : detected_objects_) {
|
||||||
const int object_id = detected_obj.second.object_id();
|
const int object_id = detected_obj.second.object_id();
|
||||||
if (cancel_object_ids->contains(object_id)) {
|
if (cancel_object_ids->contains(object_id)) {
|
||||||
|
|
|
@ -76,7 +76,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase {
|
||||||
// In a single MediaPipe session, the IDs are unique.
|
// In a single MediaPipe session, the IDs are unique.
|
||||||
// Also assign timestamp for the FrameAnnotation to be the input packet
|
// Also assign timestamp for the FrameAnnotation to be the input packet
|
||||||
// timestamp.
|
// timestamp.
|
||||||
void AssignObjectIdAndTimestamp(int64 timestamp_us,
|
void AssignObjectIdAndTimestamp(int64_t timestamp_us,
|
||||||
FrameAnnotation* annotation);
|
FrameAnnotation* annotation);
|
||||||
|
|
||||||
int num_classes_ = 0;
|
int num_classes_ = 0;
|
||||||
|
@ -207,7 +207,7 @@ void TfLiteTensorsToObjectsCalculator::Project3DTo2D(
|
||||||
}
|
}
|
||||||
|
|
||||||
void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp(
|
void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp(
|
||||||
int64 timestamp_us, FrameAnnotation* annotation) {
|
int64_t timestamp_us, FrameAnnotation* annotation) {
|
||||||
for (auto& ann : *annotation->mutable_annotations()) {
|
for (auto& ann : *annotation->mutable_annotations()) {
|
||||||
ann.set_object_id(GetNextObjectId());
|
ann.set_object_id(GetNextObjectId());
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -157,7 +157,7 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
||||||
auto options = std::make_unique<AudioClassifierOptions>();
|
auto options = std::make_unique<AudioClassifierOptions>();
|
||||||
|
@ -270,7 +270,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClassifyTest : public tflite_shims::testing::Test {};
|
class ClassifyTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ClassifyTest, Succeeds) {
|
TEST_F(ClassifyTest, Succeeds) {
|
||||||
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
||||||
|
@ -467,7 +467,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClassifyAsyncTest : public tflite_shims::testing::Test {};
|
class ClassifyAsyncTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ClassifyAsyncTest, Succeeds) {
|
TEST_F(ClassifyAsyncTest, Succeeds) {
|
||||||
constexpr int kSampleRateHz = 48000;
|
constexpr int kSampleRateHz = 48000;
|
||||||
|
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -66,7 +66,7 @@ Matrix GetAudioData(absl::string_view filename) {
|
||||||
return matrix_mapping.matrix();
|
return matrix_mapping.matrix();
|
||||||
}
|
}
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||||
auto audio_embedder =
|
auto audio_embedder =
|
||||||
|
@ -124,7 +124,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class EmbedTest : public tflite_shims::testing::Test {};
|
class EmbedTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
|
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
|
||||||
auto options = std::make_unique<AudioEmbedderOptions>();
|
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||||
|
@ -187,7 +187,7 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
|
||||||
MP_EXPECT_OK(audio_embedder->Close());
|
MP_EXPECT_OK(audio_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class EmbedAsyncTest : public tflite_shims::testing::Test {
|
class EmbedAsyncTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
|
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
|
||||||
int sample_rate_hz,
|
int sample_rate_hz,
|
||||||
|
|
|
@ -47,7 +47,7 @@ cc_test_with_tflite(
|
||||||
data = ["//mediapipe/tasks/testdata/audio:test_models"],
|
data = ["//mediapipe/tasks/testdata/audio:test_models"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":audio_tensor_specs",
|
":audio_tensor_specs",
|
||||||
|
|
|
@ -34,7 +34,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -52,7 +52,7 @@ constexpr char kModelWithMetadata[] =
|
||||||
"yamnet_audio_classifier_with_metadata.tflite";
|
"yamnet_audio_classifier_with_metadata.tflite";
|
||||||
constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite";
|
constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite";
|
||||||
|
|
||||||
class AudioTensorSpecsTest : public tflite_shims::testing::Test {};
|
class AudioTensorSpecsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(AudioTensorSpecsTest,
|
TEST_F(AudioTensorSpecsTest,
|
||||||
BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) {
|
BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) {
|
||||||
|
|
|
@ -63,7 +63,7 @@ cc_test(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -232,6 +232,6 @@ cc_test(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,7 +33,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -66,8 +66,7 @@ ClassificationList MakeClassificationList(int class_index) {
|
||||||
class_index));
|
class_index));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ClassificationAggregationCalculatorTest
|
class ClassificationAggregationCalculatorTest : public tflite::testing::Test {
|
||||||
: public tflite_shims::testing::Test {
|
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
bool connect_timestamps = false) {
|
bool connect_timestamps = false) {
|
||||||
|
|
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -52,7 +52,7 @@ constexpr char kTimestampsName[] = "timestamps_in";
|
||||||
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||||
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
||||||
|
|
||||||
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
|
class EmbeddingAggregationCalculatorTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
|
|
@ -66,7 +66,7 @@ using ::mediapipe::tasks::core::ModelResources;
|
||||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
using ::tflite::ProcessUnit;
|
using ::tflite::ProcessUnit;
|
||||||
using ::tflite::TensorMetadata;
|
using ::tflite::TensorMetadata;
|
||||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
|
||||||
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
|
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
|
||||||
|
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
|
@ -49,7 +49,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -101,7 +101,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||||
std::move(external_file));
|
std::move(external_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
class ConfigureTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -417,7 +417,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
)pb")));
|
)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
class PostprocessingTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||||
|
@ -520,7 +520,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||||
auto poller,
|
auto poller,
|
||||||
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
|
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
|
||||||
// Build input tensors.
|
// Build input tensors.
|
||||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
|
||||||
tensor[1] = 18;
|
tensor[1] = 18;
|
||||||
tensor[2] = 16;
|
tensor[2] = 16;
|
||||||
|
|
||||||
|
@ -552,7 +552,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
||||||
// Build input tensors.
|
// Build input tensors.
|
||||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
|
||||||
tensor[1] = 12;
|
tensor[1] = 12;
|
||||||
tensor[2] = 14;
|
tensor[2] = 14;
|
||||||
tensor[3] = 16;
|
tensor[3] = 16;
|
||||||
|
@ -589,7 +589,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||||
auto poller,
|
auto poller,
|
||||||
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
|
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
|
||||||
// Build input tensors.
|
// Build input tensors.
|
||||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
|
||||||
tensor[1] = 12;
|
tensor[1] = 12;
|
||||||
tensor[2] = 14;
|
tensor[2] = 14;
|
||||||
tensor[3] = 16;
|
tensor[3] = 16;
|
||||||
|
@ -677,11 +677,11 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
||||||
/*connect_timestamps=*/true));
|
/*connect_timestamps=*/true));
|
||||||
// Build input tensors.
|
// Build input tensors.
|
||||||
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
|
std::vector<uint8_t> tensor_0(kMobileNetNumClasses, 0);
|
||||||
tensor_0[1] = 12;
|
tensor_0[1] = 12;
|
||||||
tensor_0[2] = 14;
|
tensor_0[2] = 14;
|
||||||
tensor_0[3] = 16;
|
tensor_0[3] = 16;
|
||||||
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
|
std::vector<uint8_t> tensor_1(kMobileNetNumClasses, 0);
|
||||||
tensor_1[5] = 12;
|
tensor_1[5] = 12;
|
||||||
tensor_1[6] = 14;
|
tensor_1[6] = 14;
|
||||||
tensor_1[7] = 16;
|
tensor_1[7] = 16;
|
||||||
|
|
|
@ -39,7 +39,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -86,7 +86,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
||||||
std::move(external_file));
|
std::move(external_file));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
class ConfigureTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -153,7 +153,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
||||||
has_quantized_outputs: false)pb")));
|
has_quantized_outputs: false)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
class PostprocessingTest : public tflite::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
absl::string_view model_name, const proto::EmbedderOptions& options,
|
absl::string_view model_name, const proto::EmbedderOptions& options,
|
||||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -125,7 +125,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
||||||
return TaskRunner::Create(graph.GetConfig());
|
return TaskRunner::Create(graph.GetConfig());
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
class ConfigureTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -78,6 +78,7 @@ cc_library(
|
||||||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
|
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
|
||||||
|
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
|
||||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
|
||||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
|
||||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||||
|
@ -128,9 +129,9 @@ cc_library_with_tflite(
|
||||||
srcs = ["model_resources.cc"],
|
srcs = ["model_resources.cc"],
|
||||||
hdrs = ["model_resources.h"],
|
hdrs = ["model_resources.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:verifier",
|
"@org_tensorflow//tensorflow/lite/tools:verifier",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":external_file_handler",
|
":external_file_handler",
|
||||||
|
@ -159,9 +160,9 @@ cc_test_with_tflite(
|
||||||
],
|
],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":model_resources",
|
":model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":utils",
|
":utils",
|
||||||
|
@ -186,7 +187,7 @@ cc_library_with_tflite(
|
||||||
hdrs = ["model_resources_cache.h"],
|
hdrs = ["model_resources_cache.h"],
|
||||||
tflite_deps = [
|
tflite_deps = [
|
||||||
":model_resources",
|
":model_resources",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":model_asset_bundle_resources",
|
":model_asset_bundle_resources",
|
||||||
|
@ -233,7 +234,7 @@ cc_test_with_tflite(
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
":model_resources_calculator",
|
":model_resources_calculator",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
@ -284,7 +285,7 @@ cc_test_with_tflite(
|
||||||
":task_runner",
|
":task_runner",
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":model_resources_cache",
|
":model_resources_cache",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:pass_through_calculator",
|
"//mediapipe/calculators/core:pass_through_calculator",
|
||||||
|
@ -317,6 +318,9 @@ cc_library(
|
||||||
":model_resources",
|
":model_resources",
|
||||||
":task_runner",
|
":task_runner",
|
||||||
":utils",
|
":utils",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework/port:requires",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
|
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
|
||||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||||
|
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
||||||
AddCustom("KmeansEmbeddingLookup",
|
AddCustom("KmeansEmbeddingLookup",
|
||||||
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
|
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
|
||||||
// For the UniversalSentenceEncoder model.
|
// For the UniversalSentenceEncoder model.
|
||||||
|
AddCustom("TFSentencepieceTokenizeOp",
|
||||||
|
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
|
||||||
AddCustom("RaggedTensorToTensor",
|
AddCustom("RaggedTensorToTensor",
|
||||||
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
|
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,8 +37,8 @@ limitations under the License.
|
||||||
#include "mediapipe/util/tflite/error_reporter.h"
|
#include "mediapipe/util/tflite/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
#include "tensorflow/lite/model_builder.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
|
#include "tensorflow/lite/tools/verifier.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -52,7 +52,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
|
|
||||||
bool ModelResources::Verifier::Verify(const char* data, int length,
|
bool ModelResources::Verifier::Verify(const char* data, int length,
|
||||||
tflite::ErrorReporter* reporter) {
|
tflite::ErrorReporter* reporter) {
|
||||||
return tflite_shims::Verify(data, length, reporter);
|
return tflite::Verify(data, length, reporter);
|
||||||
}
|
}
|
||||||
|
|
||||||
ModelResources::ModelResources(const std::string& tag,
|
ModelResources::ModelResources(const std::string& tag,
|
||||||
|
@ -124,7 +124,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
|
||||||
// and that it uses only operators that are supported by the OpResolver
|
// and that it uses only operators that are supported by the OpResolver
|
||||||
// that was passed to the ModelResources constructor, and then builds
|
// that was passed to the ModelResources constructor, and then builds
|
||||||
// the model from the buffer.
|
// the model from the buffer.
|
||||||
auto model = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
|
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||||
buffer_data, buffer_size, &verifier_, &error_reporter_);
|
buffer_data, buffer_size, &verifier_, &error_reporter_);
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
static constexpr char kInvalidFlatbufferMessage[] =
|
static constexpr char kInvalidFlatbufferMessage[] =
|
||||||
|
@ -151,8 +151,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
|
||||||
}
|
}
|
||||||
|
|
||||||
model_packet_ = MakePacket<ModelPtr>(
|
model_packet_ = MakePacket<ModelPtr>(
|
||||||
model.release(),
|
model.release(), [](tflite::FlatBufferModel* model) { delete model; });
|
||||||
[](tflite_shims::FlatBufferModel* model) { delete model; });
|
|
||||||
ASSIGN_OR_RETURN(auto model_metadata_extractor,
|
ASSIGN_OR_RETURN(auto model_metadata_extractor,
|
||||||
metadata::ModelMetadataExtractor::CreateFromModelBuffer(
|
metadata::ModelMetadataExtractor::CreateFromModelBuffer(
|
||||||
buffer_data, buffer_size));
|
buffer_data, buffer_size));
|
||||||
|
|
|
@ -32,10 +32,10 @@ limitations under the License.
|
||||||
#include "mediapipe/util/tflite/error_reporter.h"
|
#include "mediapipe/util/tflite/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
#include "tensorflow/lite/model_builder.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
|
#include "tensorflow/lite/tools/verifier.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -51,8 +51,8 @@ class ModelResources {
|
||||||
public:
|
public:
|
||||||
// Represents a TfLite model as a FlatBuffer.
|
// Represents a TfLite model as a FlatBuffer.
|
||||||
using ModelPtr =
|
using ModelPtr =
|
||||||
std::unique_ptr<tflite_shims::FlatBufferModel,
|
std::unique_ptr<tflite::FlatBufferModel,
|
||||||
std::function<void(tflite_shims::FlatBufferModel*)>>;
|
std::function<void(tflite::FlatBufferModel*)>>;
|
||||||
|
|
||||||
// Takes the ownership of the provided ExternalFile proto and creates
|
// Takes the ownership of the provided ExternalFile proto and creates
|
||||||
// ModelResources from the proto and an op resolver object. A non-empty tag
|
// ModelResources from the proto and an op resolver object. A non-empty tag
|
||||||
|
@ -61,7 +61,7 @@ class ModelResources {
|
||||||
static absl::StatusOr<std::unique_ptr<ModelResources>> Create(
|
static absl::StatusOr<std::unique_ptr<ModelResources>> Create(
|
||||||
const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
|
const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver =
|
std::unique_ptr<tflite::OpResolver> op_resolver =
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
|
|
||||||
// Takes the ownership of the provided ExternalFile proto and creates
|
// Takes the ownership of the provided ExternalFile proto and creates
|
||||||
// ModelResources from the proto and an op resolver mediapipe packet. A
|
// ModelResources from the proto and an op resolver mediapipe packet. A
|
||||||
|
|
|
@ -30,7 +30,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -124,7 +124,7 @@ void RunGraphWithGraphService(std::unique_ptr<ModelResources> model_resources,
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class ModelResourcesCalculatorTest : public tflite_shims::testing::Test {};
|
class ModelResourcesCalculatorTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) {
|
TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) {
|
||||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
|
|
@ -38,9 +38,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -116,7 +116,7 @@ void CheckModelResourcesPackets(const ModelResources* model_resources) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class ModelResourcesTest : public tflite_shims::testing::Test {};
|
class ModelResourcesTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ModelResourcesTest, CreateFromBinaryContent) {
|
TEST_F(ModelResourcesTest, CreateFromBinaryContent) {
|
||||||
auto model_file = std::make_unique<proto::ExternalFile>();
|
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||||
|
@ -211,7 +211,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsFromFile) {
|
||||||
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
||||||
tflite::MutableOpResolver resolver;
|
tflite::MutableOpResolver resolver;
|
||||||
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||||
::tflite_shims::ops::builtin::Register_ADD());
|
::tflite::ops::builtin::Register_ADD());
|
||||||
resolver.AddCustom(kCustomOpName,
|
resolver.AddCustom(kCustomOpName,
|
||||||
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
||||||
|
|
||||||
|
@ -275,7 +275,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsPacket) {
|
||||||
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
||||||
tflite::MutableOpResolver resolver;
|
tflite::MutableOpResolver resolver;
|
||||||
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||||
::tflite_shims::ops::builtin::Register_ADD());
|
::tflite::ops::builtin::Register_ADD());
|
||||||
resolver.AddCustom(kCustomOpName,
|
resolver.AddCustom(kCustomOpName,
|
||||||
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/port/requires.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
@ -54,6 +58,8 @@ class TaskApiFactory {
|
||||||
std::unique_ptr<tflite::OpResolver> resolver,
|
std::unique_ptr<tflite::OpResolver> resolver,
|
||||||
PacketsCallback packets_callback = nullptr) {
|
PacketsCallback packets_callback = nullptr) {
|
||||||
bool found_task_subgraph = false;
|
bool found_task_subgraph = false;
|
||||||
|
// This for-loop ensures there's only one subgraph besides
|
||||||
|
// FlowLimiterCalculator.
|
||||||
for (const auto& node : graph_config.node()) {
|
for (const auto& node : graph_config.node()) {
|
||||||
if (node.calculator() == "FlowLimiterCalculator") {
|
if (node.calculator() == "FlowLimiterCalculator") {
|
||||||
continue;
|
continue;
|
||||||
|
@ -64,13 +70,7 @@ class TaskApiFactory {
|
||||||
"Task graph config should only contain one task subgraph node.",
|
"Task graph config should only contain one task subgraph node.",
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
} else {
|
} else {
|
||||||
if (!node.options().HasExtension(Options::ext)) {
|
MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node));
|
||||||
return CreateStatusWithPayload(
|
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat(node.calculator(),
|
|
||||||
" is missing the required task options field."),
|
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
|
||||||
}
|
|
||||||
found_task_subgraph = true;
|
found_task_subgraph = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,6 +80,35 @@ class TaskApiFactory {
|
||||||
std::move(packets_callback)));
|
std::move(packets_callback)));
|
||||||
return std::make_unique<T>(std::move(runner));
|
return std::make_unique<T>(std::move(runner));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename Options>
|
||||||
|
static absl::Status CheckHasValidOptions(
|
||||||
|
const CalculatorGraphConfig::Node& node) {
|
||||||
|
if constexpr (mediapipe::Requires<Options>(
|
||||||
|
[](auto&& o) -> decltype(o.ext) {})) {
|
||||||
|
if (node.options().HasExtension(Options::ext)) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#ifndef MEDIAPIPE_PROTO_LITE
|
||||||
|
for (const auto& option : node.node_options()) {
|
||||||
|
if (absl::StrContains(option.type_url(),
|
||||||
|
Options::descriptor()->full_name())) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else // MEDIAPIPE_PROTO_LITE
|
||||||
|
// Skip the check for proto lite, as Options::descriptor() is unavailable.
|
||||||
|
return absl::OkStatus();
|
||||||
|
#endif // MEDIAPIPE_PROTO_LITE
|
||||||
|
}
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrCat(node.calculator(),
|
||||||
|
" is missing the required task options field."),
|
||||||
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -112,7 +112,7 @@ CalculatorGraphConfig GetModelSidePacketsToStreamPacketsGraphConfig(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class TaskRunnerTest : public tflite_shims::testing::Test {};
|
class TaskRunnerTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) {
|
TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) {
|
||||||
CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
|
172
mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
Normal file
172
mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
Normal file
|
@ -0,0 +1,172 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = glob([
|
||||||
|
"testdata/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "config_fbs",
|
||||||
|
srcs = ["config.fbs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
flatbuffer_cc_library(
|
||||||
|
name = "config",
|
||||||
|
srcs = [
|
||||||
|
"config.fbs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
flatbuffer_cc_library(
|
||||||
|
name = "encoder_config",
|
||||||
|
srcs = [
|
||||||
|
"encoder_config.fbs",
|
||||||
|
],
|
||||||
|
includes = [":config_fbs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "utils",
|
||||||
|
hdrs = [
|
||||||
|
"utils.h",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "double_array_trie",
|
||||||
|
hdrs = [
|
||||||
|
"double_array_trie.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":config",
|
||||||
|
":utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "double_array_trie_builder",
|
||||||
|
srcs = [
|
||||||
|
"double_array_trie_builder.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"double_array_trie_builder.h",
|
||||||
|
],
|
||||||
|
deps = ["@darts_clone"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "double_array_trie_test",
|
||||||
|
srcs = [
|
||||||
|
"double_array_trie_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":double_array_trie",
|
||||||
|
":double_array_trie_builder",
|
||||||
|
":encoder_config",
|
||||||
|
":utils",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "sentencepiece_constants",
|
||||||
|
hdrs = ["sentencepiece_constants.h"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "model_converter",
|
||||||
|
srcs = [
|
||||||
|
"model_converter.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"model_converter.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":config",
|
||||||
|
":double_array_trie_builder",
|
||||||
|
":encoder_config",
|
||||||
|
":sentencepiece_constants",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "optimized_encoder",
|
||||||
|
srcs = [
|
||||||
|
"optimized_encoder.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"optimized_encoder.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":double_array_trie",
|
||||||
|
":encoder_config",
|
||||||
|
":utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "sentencepiece_tokenizer_tflite",
|
||||||
|
srcs = ["sentencepiece_tokenizer_tflite.cc"],
|
||||||
|
hdrs = ["sentencepiece_tokenizer_tflite.h"],
|
||||||
|
visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
deps =
|
||||||
|
[
|
||||||
|
":optimized_encoder",
|
||||||
|
"@flatbuffers",
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
"@org_tensorflow//tensorflow/lite:string_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "optimized_encoder_test",
|
||||||
|
srcs = [
|
||||||
|
"optimized_encoder_test.cc",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
":testdata",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":double_array_trie_builder",
|
||||||
|
":encoder_config",
|
||||||
|
":model_converter",
|
||||||
|
":optimized_encoder",
|
||||||
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
|
||||||
|
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||||
|
"@org_tensorflow//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
25
mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
Normal file
25
mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
namespace mediapipe.tflite_operations.sentencepiece;
|
||||||
|
|
||||||
|
table Trie {
|
||||||
|
nodes: [uint32];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
enum EncoderVersion: byte {
|
||||||
|
SENTENCE_PIECE = 0,
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// A trie node specifies a node in the tree, either an intermediate node or
|
||||||
|
// a leaf node.
|
||||||
|
// A leaf node contains the id as an int of the string match. This id is encoded
|
||||||
|
// in the lower 31 bits, thus the number of distinct ids is 2^31.
|
||||||
|
// An intermediate node has an associated label and an offset to its children.
|
||||||
|
// The label is encoded in the least significant byte and must match the input
|
||||||
|
// character during matching.
|
||||||
|
|
||||||
|
// A memory mappable trie, compatible with Darts::DoubleArray.
|
||||||
|
class DoubleArrayTrie {
|
||||||
|
public:
|
||||||
|
struct Match {
|
||||||
|
Match() {}
|
||||||
|
Match(int id, int match_length) : id(id), match_length(match_length) {}
|
||||||
|
int id = -1;
|
||||||
|
int match_length = -1;
|
||||||
|
bool empty() const { return match_length == -1; }
|
||||||
|
bool operator==(const Match& m) const {
|
||||||
|
return m.id == id && m.match_length == match_length;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// nodes and nodes_length specify the array of the nodes of the trie.
|
||||||
|
explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
|
||||||
|
: nodes_(nodes) {}
|
||||||
|
|
||||||
|
// Finds matches that are prefixes of a string.
|
||||||
|
template <typename callback>
|
||||||
|
void IteratePrefixMatches(const utils::string_view& input,
|
||||||
|
callback update_fn) const;
|
||||||
|
|
||||||
|
// Finds the longest prefix match of a string.
|
||||||
|
Match LongestPrefixMatch(const utils::string_view& input) const {
|
||||||
|
Match match;
|
||||||
|
IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
|
||||||
|
return match;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns whether a node as a leaf as a child.
|
||||||
|
bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
|
||||||
|
|
||||||
|
// Returns a value associated with a node. Available when a node is a leaf.
|
||||||
|
int value(uint32_t i) const {
|
||||||
|
return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a label associated with a node.
|
||||||
|
// A leaf node will have the MSB set and thus return an invalid label.
|
||||||
|
int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
|
||||||
|
|
||||||
|
// Returns offset to children.
|
||||||
|
int32_t offset(uint32_t i) const {
|
||||||
|
const uint32_t node = (*nodes_)[i];
|
||||||
|
return (node >> 10) << ((node & 0x200) >> 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
const flatbuffers::Vector<uint32_t>* nodes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename callback>
|
||||||
|
void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
|
||||||
|
callback update_fn) const {
|
||||||
|
if (nodes_->size() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
uint32_t pos = offset(0);
|
||||||
|
for (int i = 0; i < input.length(); ++i) {
|
||||||
|
pos ^= static_cast<unsigned char>(input.at(i));
|
||||||
|
if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) {
|
||||||
|
// No match, exit.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const bool node_has_leaf = has_leaf(pos);
|
||||||
|
pos ^= offset(pos);
|
||||||
|
if (pos < 0 || pos >= nodes_->size()) {
|
||||||
|
// We can get here only if the structure is corrupted.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (node_has_leaf) {
|
||||||
|
update_fn(Match(value(pos), i + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
|
@ -0,0 +1,75 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "include/darts.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) {
|
||||||
|
std::vector<int> ids;
|
||||||
|
ids.reserve(data.size());
|
||||||
|
for (int i = 0; i < data.size(); ++i) {
|
||||||
|
ids.push_back(i);
|
||||||
|
}
|
||||||
|
return BuildTrie(data, ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
|
||||||
|
const std::vector<int>& ids) {
|
||||||
|
// We make strong assumptions about binary structure of trie.
|
||||||
|
struct OneElement {
|
||||||
|
OneElement(const std::string* key_, int index_)
|
||||||
|
: key(key_), index(index_) {}
|
||||||
|
const std::string* key;
|
||||||
|
int index;
|
||||||
|
bool operator<(const OneElement& el) const { return *key < *el.key; }
|
||||||
|
};
|
||||||
|
std::vector<OneElement> elements;
|
||||||
|
elements.reserve(data.size());
|
||||||
|
auto data_iterator = std::begin(data);
|
||||||
|
auto ids_iterator = std::begin(ids);
|
||||||
|
for (; data_iterator != std::end(data) && ids_iterator != std::end(ids);
|
||||||
|
++data_iterator, ++ids_iterator) {
|
||||||
|
elements.emplace_back(&(*data_iterator), *ids_iterator);
|
||||||
|
}
|
||||||
|
// Sort by keys.
|
||||||
|
std::sort(elements.begin(), elements.end());
|
||||||
|
|
||||||
|
// Create vectors to build the trie.
|
||||||
|
std::vector<const char*> strings;
|
||||||
|
std::vector<int32_t> indexes;
|
||||||
|
strings.reserve(data.size());
|
||||||
|
indexes.reserve(data.size());
|
||||||
|
for (const auto& el : elements) {
|
||||||
|
strings.push_back(el.key->c_str());
|
||||||
|
indexes.push_back(el.index);
|
||||||
|
}
|
||||||
|
auto trie = std::make_unique<Darts::DoubleArray>();
|
||||||
|
trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr,
|
||||||
|
&indexes[0]);
|
||||||
|
// We make strong assumptions about internal Darts trie structure:
|
||||||
|
// - it is a vector of 32 bit signed integers
|
||||||
|
// - the "array" is the only one structure that contains all information about
|
||||||
|
// the trie.
|
||||||
|
const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array());
|
||||||
|
return std::vector<uint32_t>(trie_data, trie_data + trie->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,32 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
|
||||||
|
const std::vector<int>& ids);
|
||||||
|
|
||||||
|
// A variant where ids are indexes in data.
|
||||||
|
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
|
@ -0,0 +1,73 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
TEST(DoubleArrayTrieTest, Match) {
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"};
|
||||||
|
const auto trie_vector = builder.CreateVector(BuildTrie(test_strings));
|
||||||
|
TrieBuilder trie_builder(builder);
|
||||||
|
trie_builder.add_nodes(trie_vector);
|
||||||
|
const auto pieces = trie_builder.Finish();
|
||||||
|
EncoderConfigBuilder ecb(builder);
|
||||||
|
ecb.add_pieces(pieces);
|
||||||
|
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||||
|
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||||
|
DoubleArrayTrie dat(config->pieces()->nodes());
|
||||||
|
EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")),
|
||||||
|
DoubleArrayTrie::Match(2, 2));
|
||||||
|
|
||||||
|
std::vector<DoubleArrayTrie::Match> matches;
|
||||||
|
dat.IteratePrefixMatches(
|
||||||
|
utils::string_view("AAXL"),
|
||||||
|
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
|
||||||
|
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1),
|
||||||
|
DoubleArrayTrie::Match(2, 2),
|
||||||
|
DoubleArrayTrie::Match(1, 3)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DoubleArrayTrieTest, ComplexMatch) {
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s",
|
||||||
|
"\xe2\x96\x81Hello"};
|
||||||
|
const std::vector<int> test_ids = {0, 5, 10, 15};
|
||||||
|
const auto trie_vector =
|
||||||
|
builder.CreateVector(BuildTrie(test_strings, test_ids));
|
||||||
|
TrieBuilder trie_builder(builder);
|
||||||
|
trie_builder.add_nodes(trie_vector);
|
||||||
|
const auto pieces = trie_builder.Finish();
|
||||||
|
EncoderConfigBuilder ecb(builder);
|
||||||
|
ecb.add_pieces(pieces);
|
||||||
|
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||||
|
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||||
|
DoubleArrayTrie dat(config->pieces()->nodes());
|
||||||
|
|
||||||
|
std::vector<DoubleArrayTrie::Match> matches;
|
||||||
|
dat.IteratePrefixMatches(
|
||||||
|
utils::string_view("\xe2\x96\x81Hello"),
|
||||||
|
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
|
||||||
|
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8)));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
include "config.fbs";
|
||||||
|
|
||||||
|
namespace mediapipe.tflite_operations.sentencepiece;
|
||||||
|
|
||||||
|
table EncoderConfig {
|
||||||
|
// Version of the encoder.
|
||||||
|
version: EncoderVersion = SENTENCE_PIECE;
|
||||||
|
start_code: int32 = 0;
|
||||||
|
end_code: int32 = 0;
|
||||||
|
|
||||||
|
unknown_code: int32 = -1;
|
||||||
|
// Weight of "unknown code" when encoding. "Penalty" because it usually has a
|
||||||
|
// big negative weight,less than any other sentencepiece.
|
||||||
|
unknown_penalty: float = 0;
|
||||||
|
|
||||||
|
// The offset for encoding, usually used when codes with low codes are reserved
|
||||||
|
// for some special needs.
|
||||||
|
encoding_offset: int32;
|
||||||
|
|
||||||
|
// String pieces for encoding.
|
||||||
|
pieces: Trie;
|
||||||
|
pieces_scores: [float];
|
||||||
|
|
||||||
|
// Normalization related parameters.
|
||||||
|
remove_extra_whitespaces: bool;
|
||||||
|
|
||||||
|
// Add a whitespace prefix before encoding.
|
||||||
|
add_dummy_prefix: bool;
|
||||||
|
|
||||||
|
// Escape whitespaces during encoding so the decoder can restore them exactly as
|
||||||
|
// in the input.
|
||||||
|
escape_whitespaces: bool;
|
||||||
|
|
||||||
|
// Normalization parameters.
|
||||||
|
normalized_prefixes: Trie;
|
||||||
|
normalized_replacements: [byte];
|
||||||
|
}
|
||||||
|
|
||||||
|
root_type EncoderConfig;
|
|
@ -0,0 +1,131 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
|
||||||
|
#include "src/sentencepiece_model.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
|
||||||
|
DecodePrecompiledCharsmap(
|
||||||
|
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
|
||||||
|
// This function "undoes" encoding done by
|
||||||
|
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
|
||||||
|
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
|
||||||
|
const uint32_t trie_size =
|
||||||
|
*reinterpret_cast<const uint32_t*>(precompiled_map);
|
||||||
|
const uint32_t* trie_ptr =
|
||||||
|
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
|
||||||
|
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
|
||||||
|
precompiled_map + sizeof(uint32_t) + trie_size);
|
||||||
|
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
|
||||||
|
sizeof(uint32_t) - trie_size;
|
||||||
|
return std::make_tuple(
|
||||||
|
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
|
||||||
|
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||||
|
const std::string& model_config_str, int encoding_offset) {
|
||||||
|
::sentencepiece::ModelProto model_config;
|
||||||
|
if (!model_config.ParseFromString(model_config_str)) {
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"Invalid configuration, can't parse SentencePiece model config " +
|
||||||
|
model_config.InitializationErrorString());
|
||||||
|
}
|
||||||
|
// Convert sentencepieces.
|
||||||
|
std::vector<std::string> pieces;
|
||||||
|
pieces.reserve(model_config.pieces_size());
|
||||||
|
std::vector<float> scores;
|
||||||
|
scores.reserve(model_config.pieces_size());
|
||||||
|
std::vector<int> ids;
|
||||||
|
ids.reserve(model_config.pieces_size());
|
||||||
|
float min_score = 0.0;
|
||||||
|
int index = 0;
|
||||||
|
for (const auto& piece : model_config.pieces()) {
|
||||||
|
switch (piece.type()) {
|
||||||
|
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
|
||||||
|
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
|
||||||
|
pieces.push_back(piece.piece());
|
||||||
|
ids.push_back(index);
|
||||||
|
if (piece.score() < min_score) {
|
||||||
|
min_score = piece.score();
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
|
||||||
|
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
|
||||||
|
// Ignore unknown and control codes.
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
|
||||||
|
piece.piece());
|
||||||
|
}
|
||||||
|
scores.push_back(piece.score());
|
||||||
|
++index;
|
||||||
|
}
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
|
||||||
|
const auto pieces_score_vector = builder.CreateVector(scores);
|
||||||
|
TrieBuilder pieces_trie_builder(builder);
|
||||||
|
pieces_trie_builder.add_nodes(pieces_trie_vector);
|
||||||
|
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
|
||||||
|
|
||||||
|
// Converting normalization.
|
||||||
|
const auto normalization =
|
||||||
|
DecodePrecompiledCharsmap(model_config.normalizer_spec());
|
||||||
|
const auto normalization_trie = std::get<0>(normalization);
|
||||||
|
const auto normalization_strings = std::get<1>(normalization);
|
||||||
|
const auto normalization_trie_vector =
|
||||||
|
builder.CreateVector(normalization_trie);
|
||||||
|
TrieBuilder normalization_trie_builder(builder);
|
||||||
|
normalization_trie_builder.add_nodes(normalization_trie_vector);
|
||||||
|
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
|
||||||
|
const auto normalization_strings_fbs =
|
||||||
|
builder.CreateVector(normalization_strings);
|
||||||
|
|
||||||
|
EncoderConfigBuilder ecb(builder);
|
||||||
|
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
|
||||||
|
ecb.add_start_code(model_config.trainer_spec().bos_id());
|
||||||
|
ecb.add_end_code(model_config.trainer_spec().eos_id());
|
||||||
|
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
|
||||||
|
ecb.add_unknown_penalty(min_score - kUnkPenalty);
|
||||||
|
ecb.add_encoding_offset(encoding_offset);
|
||||||
|
ecb.add_pieces(pieces_trie_fbs);
|
||||||
|
ecb.add_pieces_scores(pieces_score_vector);
|
||||||
|
ecb.add_remove_extra_whitespaces(
|
||||||
|
model_config.normalizer_spec().remove_extra_whitespaces());
|
||||||
|
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
|
||||||
|
ecb.add_escape_whitespaces(
|
||||||
|
model_config.normalizer_spec().escape_whitespaces());
|
||||||
|
ecb.add_normalized_prefixes(normalization_trie_fbs);
|
||||||
|
ecb.add_normalized_replacements(normalization_strings_fbs);
|
||||||
|
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||||
|
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||||
|
builder.GetSize());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ConvertSentencepieceModel(const std::string& model_string) {
|
||||||
|
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
|
||||||
|
assert(result.status().ok());
|
||||||
|
return result.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,33 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// Converts Sentencepiece configuration to flatbuffer format.
|
||||||
|
// encoding_offset is used by some encoders that combine different encodings.
|
||||||
|
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||||
|
const std::string& model_config_str, int encoding_offset = 0);
|
||||||
|
std::string ConvertSentencepieceModel(const std::string& model_string);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
|
@ -0,0 +1,236 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
const char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||||
|
|
||||||
|
template <typename processing_callback>
|
||||||
|
std::tuple<std::string, std::vector<int>> process_string(
|
||||||
|
const std::string& input, const std::vector<int>& offsets,
|
||||||
|
const processing_callback& pc) {
|
||||||
|
std::string result_string;
|
||||||
|
result_string.reserve(input.size());
|
||||||
|
std::vector<int> result_offsets;
|
||||||
|
result_offsets.reserve(offsets.size());
|
||||||
|
for (int i = 0, j = 0; i < input.size();) {
|
||||||
|
auto result = pc(input.data() + i, input.size() - i);
|
||||||
|
auto consumed = std::get<0>(result);
|
||||||
|
auto new_string = std::get<1>(result);
|
||||||
|
if (consumed == 0) {
|
||||||
|
// Skip the current byte and move forward.
|
||||||
|
result_string.push_back(input[i]);
|
||||||
|
result_offsets.push_back(offsets[j]);
|
||||||
|
i++;
|
||||||
|
j++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
result_string.append(new_string.data(), new_string.length());
|
||||||
|
for (int i = 0; i < new_string.length(); ++i) {
|
||||||
|
result_offsets.push_back(offsets[j]);
|
||||||
|
}
|
||||||
|
j += consumed;
|
||||||
|
i += consumed;
|
||||||
|
}
|
||||||
|
return std::make_tuple(result_string, result_offsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline char is_whitespace(char c) {
|
||||||
|
return c == ' ' || c == '\t' || c == '\r' || c == '\n';
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
|
||||||
|
int len) {
|
||||||
|
if (len == 0 || !is_whitespace(*data)) {
|
||||||
|
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||||
|
}
|
||||||
|
int num_consumed = 1;
|
||||||
|
for (; num_consumed < len && is_whitespace(data[num_consumed]);
|
||||||
|
++num_consumed) {
|
||||||
|
}
|
||||||
|
return num_consumed > 1
|
||||||
|
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
|
||||||
|
: std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<int, utils::string_view> find_replacement(
|
||||||
|
const char* data, int len, const DoubleArrayTrie& dat,
|
||||||
|
const flatbuffers::Vector<int8_t>& replacements) {
|
||||||
|
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
|
||||||
|
if (!max_match.empty()) {
|
||||||
|
// Because flatbuffer byte is signed char which is not the same as char,
|
||||||
|
// there is the reinterpret_cast here.
|
||||||
|
const char* replaced_string_ptr =
|
||||||
|
reinterpret_cast<const char*>(replacements.data() + max_match.id);
|
||||||
|
return std::make_tuple(max_match.match_length,
|
||||||
|
utils::string_view(replaced_string_ptr));
|
||||||
|
}
|
||||||
|
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||||
|
const std::string& in_string, const EncoderConfig& config) {
|
||||||
|
std::vector<int> output_offsets;
|
||||||
|
std::string result = in_string;
|
||||||
|
output_offsets.reserve(in_string.length());
|
||||||
|
for (int i = 0; i < in_string.length(); ++i) {
|
||||||
|
output_offsets.push_back(i);
|
||||||
|
}
|
||||||
|
if (in_string.empty()) {
|
||||||
|
return std::make_tuple(result, output_offsets);
|
||||||
|
}
|
||||||
|
if (config.add_dummy_prefix()) {
|
||||||
|
result.insert(result.begin(), ' ');
|
||||||
|
output_offsets.insert(output_offsets.begin(), 0);
|
||||||
|
}
|
||||||
|
// Greedely replace normalized_prefixes with normalized_replacements
|
||||||
|
if (config.normalized_prefixes() != nullptr &&
|
||||||
|
config.normalized_replacements() != nullptr) {
|
||||||
|
const DoubleArrayTrie normalized_prefixes_matcher(
|
||||||
|
config.normalized_prefixes()->nodes());
|
||||||
|
const auto norm_replace = [&config, &normalized_prefixes_matcher](
|
||||||
|
const char* data, int len) {
|
||||||
|
return find_replacement(data, len, normalized_prefixes_matcher,
|
||||||
|
*config.normalized_replacements());
|
||||||
|
};
|
||||||
|
std::tie(result, output_offsets) =
|
||||||
|
process_string(result, output_offsets, norm_replace);
|
||||||
|
}
|
||||||
|
if (config.remove_extra_whitespaces()) {
|
||||||
|
std::tie(result, output_offsets) =
|
||||||
|
process_string(result, output_offsets, remove_extra_whitespaces);
|
||||||
|
if (!result.empty() && is_whitespace(result.back())) {
|
||||||
|
result.pop_back();
|
||||||
|
output_offsets.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (config.escape_whitespaces()) {
|
||||||
|
const auto replace_whitespaces = [](const char* data, int len) {
|
||||||
|
if (len > 0 && is_whitespace(*data)) {
|
||||||
|
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
|
||||||
|
}
|
||||||
|
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||||
|
};
|
||||||
|
std::tie(result, output_offsets) =
|
||||||
|
process_string(result, output_offsets, replace_whitespaces);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(result, output_offsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
EncoderResult EncodeNormalizedString(const std::string& str,
|
||||||
|
const std::vector<int>& offsets,
|
||||||
|
const EncoderConfig& config, bool add_bos,
|
||||||
|
bool add_eos, bool reverse) {
|
||||||
|
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
|
||||||
|
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
|
||||||
|
const int unknown_code = config.unknown_code();
|
||||||
|
const float unknown_penalty = config.unknown_penalty();
|
||||||
|
struct LatticeElement {
|
||||||
|
float score = 0;
|
||||||
|
int code = -1;
|
||||||
|
int prev_position = -1;
|
||||||
|
LatticeElement(float score_, int code_, int prev_position_)
|
||||||
|
: score(score_), code(code_), prev_position(prev_position_) {}
|
||||||
|
LatticeElement() {}
|
||||||
|
};
|
||||||
|
const int length = str.length();
|
||||||
|
std::vector<LatticeElement> lattice(length + 1);
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
if (i > 0 && lattice[i].prev_position < 0) {
|
||||||
|
// This state is unreachable.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (unknown_code >= 0) {
|
||||||
|
// Put unknown code.
|
||||||
|
const float penalized_score = lattice[i].score + unknown_penalty;
|
||||||
|
const int pos = i + 1;
|
||||||
|
LatticeElement& current_element = lattice[pos];
|
||||||
|
if (current_element.prev_position < 0 ||
|
||||||
|
current_element.score < penalized_score) {
|
||||||
|
current_element = LatticeElement(
|
||||||
|
penalized_score, unknown_code,
|
||||||
|
// If the current state is already reached by unknown code, merge
|
||||||
|
// states.
|
||||||
|
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto lattice_update = [&lattice, i,
|
||||||
|
piece_scores](const DoubleArrayTrie::Match& m) {
|
||||||
|
LatticeElement& target_element = lattice[i + m.match_length];
|
||||||
|
const float score = lattice[i].score + (*piece_scores)[m.id];
|
||||||
|
if (target_element.prev_position < 0 || target_element.score < score) {
|
||||||
|
target_element = LatticeElement(score, m.id, i);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
piece_matcher.IteratePrefixMatches(
|
||||||
|
utils::string_view(str.data() + i, length - i), lattice_update);
|
||||||
|
}
|
||||||
|
|
||||||
|
EncoderResult result;
|
||||||
|
if (add_eos) {
|
||||||
|
result.codes.push_back(config.end_code());
|
||||||
|
result.offsets.push_back(length);
|
||||||
|
}
|
||||||
|
if (lattice[length].prev_position >= 0) {
|
||||||
|
for (int pos = length; pos > 0;) {
|
||||||
|
auto code = lattice[pos].code;
|
||||||
|
if (code != config.unknown_code()) {
|
||||||
|
code += config.encoding_offset();
|
||||||
|
}
|
||||||
|
result.codes.push_back(code);
|
||||||
|
pos = lattice[pos].prev_position;
|
||||||
|
result.offsets.push_back(offsets[pos]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (add_bos) {
|
||||||
|
result.codes.push_back(config.start_code());
|
||||||
|
result.offsets.push_back(0);
|
||||||
|
}
|
||||||
|
if (!reverse) {
|
||||||
|
std::reverse(result.codes.begin(), result.codes.end());
|
||||||
|
std::reverse(result.offsets.begin(), result.offsets.end());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||||
|
bool add_bos, bool add_eos, bool reverse) {
|
||||||
|
// Get the config from the buffer.
|
||||||
|
const EncoderConfig* config = GetEncoderConfig(config_buffer);
|
||||||
|
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
|
||||||
|
EncoderResult result;
|
||||||
|
result.type = EncoderResultType::WRONG_CONFIG;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
std::string normalized_string;
|
||||||
|
std::vector<int> offsets;
|
||||||
|
std::tie(normalized_string, offsets) = NormalizeString(string, *config);
|
||||||
|
return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
|
||||||
|
add_eos, reverse);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,46 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||||
|
|
||||||
|
// Sentencepiece encoder optimized with memmapped model.
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
|
||||||
|
|
||||||
|
struct EncoderResult {
|
||||||
|
EncoderResultType type = EncoderResultType::SUCCESS;
|
||||||
|
std::vector<int> codes;
|
||||||
|
std::vector<int> offsets;
|
||||||
|
};
|
||||||
|
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||||
|
const std::string& in_string, const EncoderConfig& config);
|
||||||
|
|
||||||
|
// Encodes one string and returns ids and offsets. Takes the configuration as a
|
||||||
|
// type-erased buffer.
|
||||||
|
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||||
|
bool add_bos, bool add_eos, bool reverse);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
|
@ -0,0 +1,171 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "absl/flags/flag.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||||
|
#include "src/sentencepiece.pb.h"
|
||||||
|
#include "src/sentencepiece_processor.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
tensorflow::Status TFReadFileToString(const std::string& filepath,
|
||||||
|
std::string* data) {
|
||||||
|
return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
|
||||||
|
data);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status StdReadFileToString(const std::string& filepath,
|
||||||
|
std::string* data) {
|
||||||
|
std::ifstream infile(filepath);
|
||||||
|
if (!infile.is_open()) {
|
||||||
|
return absl::NotFoundError(
|
||||||
|
absl::StrFormat("Error when opening %s", filepath));
|
||||||
|
}
|
||||||
|
std::string contents((std::istreambuf_iterator<char>(infile)),
|
||||||
|
(std::istreambuf_iterator<char>()));
|
||||||
|
data->append(contents);
|
||||||
|
infile.close();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::file::JoinPath;
|
||||||
|
|
||||||
|
static char kConfigFilePath[] =
|
||||||
|
"/mediapipe/tasks/cc/text/custom_ops/"
|
||||||
|
"sentencepiece/testdata/sentencepiece.model";
|
||||||
|
|
||||||
|
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
EncoderConfigBuilder ecb(builder);
|
||||||
|
ecb.add_remove_extra_whitespaces(true);
|
||||||
|
ecb.add_add_dummy_prefix(true);
|
||||||
|
ecb.add_escape_whitespaces(true);
|
||||||
|
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||||
|
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||||
|
{
|
||||||
|
const auto result = NormalizeString("x y", *config);
|
||||||
|
const auto res_string = std::get<0>(result);
|
||||||
|
const auto offsets = std::get<1>(result);
|
||||||
|
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
|
||||||
|
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
const auto result = NormalizeString("\tx y\n", *config);
|
||||||
|
const auto res_string = std::get<0>(result);
|
||||||
|
const auto offsets = std::get<1>(result);
|
||||||
|
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
|
||||||
|
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OptimizedEncoder, NormalizeStringReplacement) {
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
|
||||||
|
const char norm_replacements[] = "A1\0A2\0A3\0A4";
|
||||||
|
const auto trie_vector =
|
||||||
|
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
|
||||||
|
const auto norm_r = builder.CreateVector<int8_t>(
|
||||||
|
reinterpret_cast<const signed char*>(norm_replacements),
|
||||||
|
sizeof(norm_replacements));
|
||||||
|
TrieBuilder trie_builder(builder);
|
||||||
|
trie_builder.add_nodes(trie_vector);
|
||||||
|
const auto norm_p = trie_builder.Finish();
|
||||||
|
EncoderConfigBuilder ecb(builder);
|
||||||
|
ecb.add_remove_extra_whitespaces(false);
|
||||||
|
ecb.add_normalized_prefixes(norm_p);
|
||||||
|
ecb.add_normalized_replacements(norm_r);
|
||||||
|
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||||
|
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||||
|
{
|
||||||
|
const auto result = NormalizeString("ABAABAAABAAAA", *config);
|
||||||
|
const auto res_string = std::get<0>(result);
|
||||||
|
const auto offsets = std::get<1>(result);
|
||||||
|
EXPECT_EQ(res_string, "A1BA2BA3BA4");
|
||||||
|
EXPECT_THAT(offsets,
|
||||||
|
::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
|
||||||
|
"X"};
|
||||||
|
const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
|
||||||
|
const auto trie_vector =
|
||||||
|
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
|
||||||
|
const auto norm_r = builder.CreateVector<int8_t>(
|
||||||
|
reinterpret_cast<const signed char*>(norm_replacements),
|
||||||
|
sizeof(norm_replacements));
|
||||||
|
TrieBuilder trie_builder(builder);
|
||||||
|
trie_builder.add_nodes(trie_vector);
|
||||||
|
const auto norm_p = trie_builder.Finish();
|
||||||
|
EncoderConfigBuilder ecb(builder);
|
||||||
|
ecb.add_remove_extra_whitespaces(true);
|
||||||
|
ecb.add_normalized_prefixes(norm_p);
|
||||||
|
ecb.add_normalized_replacements(norm_r);
|
||||||
|
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||||
|
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||||
|
{
|
||||||
|
const auto result = NormalizeString("XXABAABAAABAAAA", *config);
|
||||||
|
const auto res_string = std::get<0>(result);
|
||||||
|
const auto offsets = std::get<1>(result);
|
||||||
|
EXPECT_EQ(res_string, " A1BA2BA3BA4");
|
||||||
|
EXPECT_THAT(offsets,
|
||||||
|
::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OptimizedEncoder, ConfigConverter) {
|
||||||
|
std::string config;
|
||||||
|
auto status =
|
||||||
|
internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
|
||||||
|
ASSERT_TRUE(status.ok());
|
||||||
|
|
||||||
|
::sentencepiece::SentencePieceProcessor processor;
|
||||||
|
ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
|
||||||
|
const auto converted_model = ConvertSentencepieceModel(config);
|
||||||
|
const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
|
||||||
|
const auto encoded =
|
||||||
|
EncodeString(test_string, converted_model.data(), false, false, false);
|
||||||
|
ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
|
||||||
|
|
||||||
|
::sentencepiece::SentencePieceText reference_encoded;
|
||||||
|
ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
|
||||||
|
EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
|
||||||
|
for (int i = 0; i < encoded.codes.size(); ++i) {
|
||||||
|
EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
|
||||||
|
EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,38 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// The constant is copied from
|
||||||
|
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
|
||||||
|
constexpr float kUnkPenalty = 10.0;
|
||||||
|
|
||||||
|
// These constants are copied from
|
||||||
|
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
|
||||||
|
//
|
||||||
|
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
|
||||||
|
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||||
|
|
||||||
|
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
|
||||||
|
// since this character can be useful both for user and
|
||||||
|
// developer. We can easily figure out that <unk> is emitted.
|
||||||
|
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
|
@ -0,0 +1,129 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
|
||||||
|
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/context.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations {
|
||||||
|
namespace sentencepiece::tokenizer {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::tflite::SetTensorToDynamic;
|
||||||
|
|
||||||
|
constexpr int kSPModelIndex = 0;
|
||||||
|
constexpr int kInputIndex = 1;
|
||||||
|
constexpr int kAddBOSInput = 4;
|
||||||
|
constexpr int kAddEOSInput = 5;
|
||||||
|
constexpr int kReverseInput = 6;
|
||||||
|
|
||||||
|
constexpr int kOutputValuesInd = 0;
|
||||||
|
constexpr int kOutputSplitsInd = 1;
|
||||||
|
|
||||||
|
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
|
||||||
|
TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
|
||||||
|
int index = 0;
|
||||||
|
for (const int size : sizes) {
|
||||||
|
array_size->data[index++] = size;
|
||||||
|
}
|
||||||
|
return array_size;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Initializes text encoder object from serialized parameters.
|
||||||
|
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
|
||||||
|
size_t /*length*/) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
// TODO: Add checks for input and output tensors.
|
||||||
|
TfLiteTensor& output_values =
|
||||||
|
context->tensors[node->outputs->data[kOutputValuesInd]];
|
||||||
|
SetTensorToDynamic(&output_values);
|
||||||
|
|
||||||
|
TfLiteTensor& output_splits =
|
||||||
|
context->tensors[node->outputs->data[kOutputSplitsInd]];
|
||||||
|
SetTensorToDynamic(&output_splits);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
const TfLiteTensor& model_tensor =
|
||||||
|
context->tensors[node->inputs->data[kSPModelIndex]];
|
||||||
|
const auto model_buffer_data = model_tensor.data.data;
|
||||||
|
const TfLiteTensor& input_text =
|
||||||
|
context->tensors[node->inputs->data[kInputIndex]];
|
||||||
|
|
||||||
|
const TfLiteTensor add_bos_tensor =
|
||||||
|
context->tensors[node->inputs->data[kAddBOSInput]];
|
||||||
|
const bool add_bos = add_bos_tensor.data.b[0];
|
||||||
|
const TfLiteTensor add_eos_tensor =
|
||||||
|
context->tensors[node->inputs->data[kAddEOSInput]];
|
||||||
|
const bool add_eos = add_eos_tensor.data.b[0];
|
||||||
|
const TfLiteTensor reverse_tensor =
|
||||||
|
context->tensors[node->inputs->data[kReverseInput]];
|
||||||
|
const bool reverse = reverse_tensor.data.b[0];
|
||||||
|
|
||||||
|
std::vector<int32> encoded;
|
||||||
|
std::vector<int32> splits;
|
||||||
|
const int num_strings = tflite::GetStringCount(&input_text);
|
||||||
|
for (int i = 0; i < num_strings; ++i) {
|
||||||
|
const auto strref = tflite::GetString(&input_text, i);
|
||||||
|
const auto res = EncodeString(std::string(strref.str, strref.len),
|
||||||
|
model_buffer_data, add_bos, add_eos, reverse);
|
||||||
|
TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
|
||||||
|
"Sentencepiece conversion failed");
|
||||||
|
std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
|
||||||
|
splits.emplace_back(encoded.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteTensor& output_values =
|
||||||
|
context->tensors[node->outputs->data[kOutputValuesInd]];
|
||||||
|
TF_LITE_ENSURE_OK(context,
|
||||||
|
context->ResizeTensor(
|
||||||
|
context, &output_values,
|
||||||
|
CreateSizeArray({static_cast<int>(encoded.size())})));
|
||||||
|
int32_t* output_values_flat = output_values.data.i32;
|
||||||
|
std::copy(encoded.begin(), encoded.end(), output_values_flat);
|
||||||
|
TfLiteTensor& output_splits =
|
||||||
|
context->tensors[node->outputs->data[kOutputSplitsInd]];
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context, context->ResizeTensor(
|
||||||
|
context, &output_splits,
|
||||||
|
CreateSizeArray({static_cast<int>(splits.size() + 1)})));
|
||||||
|
int32_t* output_splits_flat = output_splits.data.i32;
|
||||||
|
*output_splits_flat = 0;
|
||||||
|
std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
} // namespace sentencepiece::tokenizer
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
|
||||||
|
static TfLiteRegistration r = {
|
||||||
|
sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
|
||||||
|
sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations {
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
|
||||||
|
|
||||||
|
} // namespace mediapipe::tflite_operations
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
Binary file not shown.
60
mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
Normal file
60
mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
||||||
|
|
||||||
|
#include <ostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace mediapipe::tflite_operations::sentencepiece {
|
||||||
|
|
||||||
|
// AOSP and WASM doesn't support string_view,
|
||||||
|
// we put here a minimal re-implementation.
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
class string_view {
|
||||||
|
public:
|
||||||
|
explicit string_view(const std::string& s)
|
||||||
|
: str_(s.data()), len_(s.length()) {}
|
||||||
|
string_view(const char* str, int len) : str_(str), len_(len) {}
|
||||||
|
// A constructor from c string.
|
||||||
|
explicit string_view(const char* s) : str_(s), len_(strlen(s)) {}
|
||||||
|
|
||||||
|
int length() const { return len_; }
|
||||||
|
const char* data() const { return str_; }
|
||||||
|
bool empty() const { return len_ == 0; }
|
||||||
|
unsigned char at(int i) const { return str_[i]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
const char* str_ = nullptr;
|
||||||
|
const int len_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
|
||||||
|
os << std::string(sv.data(), sv.length());
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
inline bool operator==(const string_view& view1, const string_view& view2) {
|
||||||
|
if (view1.length() != view2.length()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return memcmp(view1.data(), view2.data(), view1.length()) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::language_detector {
|
namespace mediapipe::tasks::text::language_detector {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -75,7 +75,7 @@ absl::Status MatchesLanguageDetectorResult(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class LanguageDetectorTest : public tflite_shims::testing::Test {};
|
class LanguageDetectorTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
|
TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
|
||||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||||
|
|
|
@ -89,7 +89,7 @@ cc_test(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:cord",
|
"@com_google_absl//absl/strings:cord",
|
||||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::text_classifier {
|
namespace mediapipe::tasks::text::text_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -87,7 +87,7 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual,
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class TextClassifierTest : public tflite_shims::testing::Test {};
|
class TextClassifierTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
||||||
auto options = std::make_unique<TextClassifierOptions>();
|
auto options = std::make_unique<TextClassifierOptions>();
|
||||||
|
|
|
@ -91,6 +91,6 @@ cc_test(
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::text_embedder {
|
namespace mediapipe::tasks::text::text_embedder {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
|
||||||
// Embedding model with regex preprocessing.
|
// Embedding model with regex preprocessing.
|
||||||
constexpr char kRegexOneEmbeddingModel[] =
|
constexpr char kRegexOneEmbeddingModel[] =
|
||||||
"regex_one_embedding_with_metadata.tflite";
|
"regex_one_embedding_with_metadata.tflite";
|
||||||
|
constexpr char kUniversalSentenceEncoderModel[] =
|
||||||
|
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||||
|
|
||||||
// Tolerance for embedding vector coordinate values.
|
// Tolerance for embedding vector coordinate values.
|
||||||
constexpr float kEpsilon = 1e-4;
|
constexpr float kEpsilon = 1e-4;
|
||||||
|
@ -49,7 +51,7 @@ using ::mediapipe::file::JoinPath;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
class EmbedderTest : public tflite_shims::testing::Test {};
|
class EmbedderTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
||||||
auto text_embedder =
|
auto text_embedder =
|
||||||
|
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
||||||
MP_ASSERT_OK(text_embedder->Close());
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
|
||||||
|
auto options = std::make_unique<TextEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||||
|
TextEmbedder::Create(std::move(options)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto result0,
|
||||||
|
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||||
|
ASSERT_EQ(result0.embeddings.size(), 1);
|
||||||
|
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
|
||||||
|
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||||
|
ASSERT_EQ(result1.embeddings.size(), 1);
|
||||||
|
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
|
||||||
|
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||||
|
result1.embeddings[0]));
|
||||||
|
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||||
auto options = std::make_unique<TextEmbedderOptions>();
|
auto options = std::make_unique<TextEmbedderOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
|
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||||
MP_ASSERT_OK(text_embedder->Close());
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
|
||||||
|
auto options = std::make_unique<TextEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||||
|
TextEmbedder::Create(std::move(options)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextEmbedderResult result0,
|
||||||
|
text_embedder->Embed("When you go to this restaurant, they hold the "
|
||||||
|
"pancake upside-down before they hand it "
|
||||||
|
"to you. It's a great gimmick."));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextEmbedderResult result1,
|
||||||
|
text_embedder->Embed(
|
||||||
|
"Let's make a plan to steal the declaration of independence."));
|
||||||
|
|
||||||
|
// Check cosine similarity.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||||
|
result1.embeddings[0]));
|
||||||
|
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(text_embedder->Close());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe::tasks::text::text_embedder
|
} // namespace mediapipe::tasks::text::text_embedder
|
||||||
|
|
|
@ -81,6 +81,6 @@ cc_test(
|
||||||
"@com_google_absl//absl/flags:flag",
|
"@com_google_absl//absl/flags:flag",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::text::utils {
|
namespace mediapipe::tasks::text::utils {
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
class TextModelUtilsTest : public tflite_shims::testing::Test {};
|
class TextModelUtilsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
|
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
||||||
|
|
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -105,7 +105,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
||||||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
class FaceBlendshapesTest : public tflite_shims::testing::Test {};
|
class FaceBlendshapesTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(FaceBlendshapesTest, SmokeTest) {
|
TEST_F(FaceBlendshapesTest, SmokeTest) {
|
||||||
// Prepare graph inputs.
|
// Prepare graph inputs.
|
||||||
|
|
|
@ -67,7 +67,7 @@ constexpr char kPortraitExpectedFaceLandmarksName[] =
|
||||||
"portrait_expected_face_landmarks.pbtxt";
|
"portrait_expected_face_landmarks.pbtxt";
|
||||||
constexpr char kPortraitExpectedBlendshapesName[] =
|
constexpr char kPortraitExpectedBlendshapesName[] =
|
||||||
"portrait_expected_blendshapes.pbtxt";
|
"portrait_expected_blendshapes.pbtxt";
|
||||||
constexpr char kPortaitExpectedFaceGeomertyName[] =
|
constexpr char kPortraitExpectedFaceGeometryName[] =
|
||||||
"portrait_expected_face_geometry.pbtxt";
|
"portrait_expected_face_geometry.pbtxt";
|
||||||
|
|
||||||
constexpr float kLandmarksDiffMargin = 0.03;
|
constexpr float kLandmarksDiffMargin = 0.03;
|
||||||
|
@ -100,7 +100,7 @@ struct FaceLandmarkerTestParams {
|
||||||
|
|
||||||
mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() {
|
mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() {
|
||||||
auto face_geometry = GetExpectedProto<face_geometry::proto::FaceGeometry>(
|
auto face_geometry = GetExpectedProto<face_geometry::proto::FaceGeometry>(
|
||||||
kPortaitExpectedFaceGeomertyName);
|
kPortraitExpectedFaceGeometryName);
|
||||||
return face_geometry.pose_transform_matrix();
|
return face_geometry.pose_transform_matrix();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,18 +23,12 @@ cc_library(
|
||||||
srcs = ["face_stylizer_graph.cc"],
|
srcs = ["face_stylizer_graph.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/image:image_cropping_calculator",
|
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
|
|
||||||
"//mediapipe/calculators/image:warp_affine_calculator",
|
|
||||||
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
|
|
||||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/tensor:inference_calculator",
|
"//mediapipe/calculators/tensor:inference_calculator",
|
||||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||||
"//mediapipe/calculators/util:face_to_rect_calculator",
|
"//mediapipe/calculators/util:face_to_rect_calculator",
|
||||||
"//mediapipe/calculators/util:from_image_calculator",
|
|
||||||
"//mediapipe/calculators/util:inverse_matrix_calculator",
|
|
||||||
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
|
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/util:to_image_calculator",
|
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
|
@ -53,7 +47,6 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
|
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
|
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
|
||||||
|
|
|
@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The input image can be of any size with format RGB or RGBA.
|
// The input image can be of any size with format RGB or RGBA.
|
||||||
// When no face is detected on the input image, the method returns a
|
// When no face is detected on the input image, the method returns a
|
||||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||||
// face. To ensure that the output image has reasonable quality, the stylized
|
// face. The stylized output image size is the same as the model output size.
|
||||||
// output image size is the smaller of the model output size and the size of
|
|
||||||
// the 'region_of_interest' specified in 'image_processing_options'.
|
|
||||||
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
|
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
// When no face is detected on the input image, the method returns a
|
// When no face is detected on the input image, the method returns a
|
||||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||||
// face. To ensure that the output image has reasonable quality, the stylized
|
// face. The stylized output image size is the same as the model output size.
|
||||||
// output image size is the smaller of the model output size and the size of
|
|
||||||
// the 'region_of_interest' specified in 'image_processing_options'.
|
|
||||||
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
|
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
|
||||||
mediapipe::Image image, int64_t timestamp_ms,
|
mediapipe::Image image, int64_t timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The "result_callback" provides:
|
// The "result_callback" provides:
|
||||||
// - When no face is detected on the input image, the method returns a
|
// - When no face is detected on the input image, the method returns a
|
||||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||||
// face. To ensure that the output image has reasonable quality, the
|
// face. The stylized output image size is the same as the model output
|
||||||
// stylized output image size is the smaller of the model output size and
|
// size.
|
||||||
// the size of the 'region_of_interest' specified in
|
|
||||||
// 'image_processing_options'.
|
|
||||||
// - The input timestamp in milliseconds.
|
// - The input timestamp in milliseconds.
|
||||||
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
|
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions>
|
std::optional<core::ImageProcessingOptions>
|
||||||
|
|
|
@ -19,8 +19,7 @@ limitations under the License.
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
|
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
|
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
|
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
|
||||||
image_in >> preprocessing.In(kImageTag);
|
image_in >> preprocessing.In(kImageTag);
|
||||||
face_rect >> preprocessing.In(kNormRectTag);
|
face_rect >> preprocessing.In(kNormRectTag);
|
||||||
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
|
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
|
||||||
auto transform_matrix = preprocessing.Out(kMatrixTag);
|
|
||||||
|
|
||||||
// Adds inference subgraph and connects its input stream to the output
|
// Adds inference subgraph and connects its input stream to the output
|
||||||
// tensors produced by the ImageToTensorCalculator.
|
// tensors produced by the ImageToTensorCalculator.
|
||||||
|
@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
|
||||||
model_output_tensors >> tensors_to_image.In(kTensorsTag);
|
model_output_tensors >> tensors_to_image.In(kTensorsTag);
|
||||||
auto tensor_image = tensors_to_image.Out(kImageTag);
|
auto tensor_image = tensors_to_image.Out(kImageTag);
|
||||||
|
|
||||||
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
|
auto& image_converter = graph.AddNode("ImageCloneCalculator");
|
||||||
transform_matrix >> inverse_matrix.In(kMatrixTag);
|
image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
|
||||||
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
|
.set_output_on_gpu(false);
|
||||||
|
tensor_image >> image_converter.In("");
|
||||||
|
|
||||||
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
|
return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
|
||||||
auto& warp_affine_options =
|
|
||||||
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
|
|
||||||
warp_affine_options.set_border_mode(
|
|
||||||
WarpAffineCalculatorOptions::BORDER_ZERO);
|
|
||||||
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
|
|
||||||
tensor_image >> warp_affine.In(kImageTag);
|
|
||||||
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
|
|
||||||
image_size >> warp_affine.In(kOutputSizeTag);
|
|
||||||
auto image_to_crop = warp_affine.Out(kImageTag);
|
|
||||||
|
|
||||||
// The following calculators are for cropping and resizing the output image
|
|
||||||
// based on the roi and the model output size. As the WarpAffineCalculator
|
|
||||||
// rotates the image based on the transform matrix, the rotation info in the
|
|
||||||
// rect proto is stripped to prevent the ImageCroppingCalculator from
|
|
||||||
// performing extra rotation.
|
|
||||||
auto& strip_rotation =
|
|
||||||
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
|
|
||||||
face_rect >> strip_rotation.In(kNormRectTag);
|
|
||||||
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
|
|
||||||
auto& from_image = graph.AddNode("FromImageCalculator");
|
|
||||||
image_to_crop >> from_image.In(kImageTag);
|
|
||||||
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
|
|
||||||
auto& image_cropping_opts =
|
|
||||||
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
|
|
||||||
image_cropping_opts.set_output_max_width(
|
|
||||||
image_to_tensor_options.output_tensor_width());
|
|
||||||
image_cropping_opts.set_output_max_height(
|
|
||||||
image_to_tensor_options.output_tensor_height());
|
|
||||||
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
|
|
||||||
auto& to_image = graph.AddNode("ToImageCalculator");
|
|
||||||
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
|
|
||||||
// graph selects its cpu or gpu path based on the image preprocessing
|
|
||||||
// backend.
|
|
||||||
if (use_gpu) {
|
|
||||||
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
|
|
||||||
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
|
|
||||||
} else {
|
|
||||||
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
|
|
||||||
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
|
|
||||||
}
|
|
||||||
|
|
||||||
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
|
|
||||||
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
|
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -43,7 +43,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -137,7 +137,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
||||||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
class HandLandmarkerTest : public tflite_shims::testing::Test {};
|
class HandLandmarkerTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(HandLandmarkerTest, Succeeds) {
|
TEST_F(HandLandmarkerTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -41,7 +41,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
|
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
|
|
@ -59,7 +59,6 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::utils::AllowIf;
|
using ::mediapipe::tasks::components::utils::AllowIf;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
|
||||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||||
HandLandmarksDetectorGraphOptions;
|
HandLandmarksDetectorGraphOptions;
|
||||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||||
|
|
|
@ -146,7 +146,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner(
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to create a Multi Hand Landmark TaskRunner.
|
// Helper function to create a Multi Hand Landmark TaskRunner.
|
||||||
|
@ -188,7 +188,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner(
|
||||||
|
|
||||||
return TaskRunner::Create(
|
return TaskRunner::Create(
|
||||||
graph.GetConfig(),
|
graph.GetConfig(),
|
||||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||||
}
|
}
|
||||||
|
|
||||||
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
||||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -148,7 +148,7 @@ class MobileNetQuantizedOpResolverMissingOps
|
||||||
const MobileNetQuantizedOpResolverMissingOps& r) = delete;
|
const MobileNetQuantizedOpResolverMissingOps& r) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateTest : public tflite_shims::testing::Test {};
|
class CreateTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ImageClassifierOptions>();
|
auto options = std::make_unique<ImageClassifierOptions>();
|
||||||
|
@ -265,7 +265,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -605,7 +605,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
||||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -707,7 +707,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK(image_classifier->Close());
|
MP_ASSERT_OK(image_classifier->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -30,9 +30,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -103,7 +103,7 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
delete;
|
delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateTest : public tflite_shims::testing::Test {};
|
class CreateTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ImageEmbedderOptions>();
|
auto options = std::make_unique<ImageEmbedderOptions>();
|
||||||
|
@ -181,7 +181,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -410,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -494,7 +494,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK(image_embedder->Close());
|
MP_ASSERT_OK(image_embedder->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -180,7 +180,7 @@ class DeepLabOpResolver : public ::tflite::MutableOpResolver {
|
||||||
DeepLabOpResolver(const DeepLabOpResolver& r) = delete;
|
DeepLabOpResolver(const DeepLabOpResolver& r) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
public:
|
public:
|
||||||
|
@ -268,7 +268,7 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -521,7 +521,7 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -581,7 +581,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK(segmenter->Close());
|
MP_ASSERT_OK(segmenter->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
|
|
@ -64,7 +64,6 @@ using ::mediapipe::CalculatorGraphConfig;
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
|
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
|
||||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
|
||||||
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||||
image_segmenter::proto::ImageSegmenterGraphOptions;
|
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
#include "testing/base/public/gmock.h"
|
#include "testing/base/public/gmock.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -124,7 +124,7 @@ MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
|
||||||
similarity_threshold;
|
similarity_threshold;
|
||||||
}
|
}
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||||
public:
|
public:
|
||||||
|
@ -261,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||||
info) { return info.param.test_name; });
|
info) { return info.param.test_name; });
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
// TODO: fix this unit test after image segmenter handled post
|
// TODO: fix this unit test after image segmenter handled post
|
||||||
// processing correctly with rotated image.
|
// processing correctly with rotated image.
|
||||||
|
|
|
@ -43,9 +43,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/c/common.h"
|
#include "tensorflow/lite/c/common.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
|
||||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/test_util.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -159,7 +159,7 @@ class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver {
|
||||||
MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete;
|
MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
||||||
auto options = std::make_unique<ObjectDetectorOptions>();
|
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||||
|
@ -332,7 +332,7 @@ TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
|
||||||
// TODO: Add NumThreadsTest back after having an
|
// TODO: Add NumThreadsTest back after having an
|
||||||
// "acceleration configuration" field in the ObjectDetectorOptions.
|
// "acceleration configuration" field in the ObjectDetectorOptions.
|
||||||
|
|
||||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
class ImageModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
@ -618,7 +618,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||||
}
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
@ -673,7 +673,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK(object_detector->Close());
|
MP_ASSERT_OK(object_detector->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||||
|
|
||||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||||
|
|
|
@ -97,8 +97,10 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
|
"//mediapipe/util:graph_builder_utils",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
// limit the number of frames in flight.
|
// limit the number of frames in flight.
|
||||||
CalculatorGraphConfig CreateGraphConfig(
|
CalculatorGraphConfig CreateGraphConfig(
|
||||||
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
|
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
|
||||||
bool enable_flow_limiting) {
|
bool enable_flow_limiting, bool output_segmentation_masks) {
|
||||||
api2::builder::Graph graph;
|
api2::builder::Graph graph;
|
||||||
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
|
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
|
||||||
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
|
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
|
||||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
|
||||||
graph.Out(kSegmentationMaskTag);
|
|
||||||
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
|
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
|
||||||
graph.Out(kNormLandmarksTag);
|
graph.Out(kNormLandmarksTag);
|
||||||
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
||||||
|
@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
||||||
graph.Out(kPoseAuxiliaryLandmarksTag);
|
graph.Out(kPoseAuxiliaryLandmarksTag);
|
||||||
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
||||||
|
if (output_segmentation_masks) {
|
||||||
|
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||||
|
graph.Out(kSegmentationMaskTag);
|
||||||
|
}
|
||||||
if (enable_flow_limiting) {
|
if (enable_flow_limiting) {
|
||||||
return tasks::core::AddFlowLimiterCalculator(
|
return tasks::core::AddFlowLimiterCalculator(
|
||||||
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
|
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
|
||||||
|
@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
|
||||||
PoseLandmarkerGraphOptionsProto>(
|
PoseLandmarkerGraphOptionsProto>(
|
||||||
CreateGraphConfig(
|
CreateGraphConfig(
|
||||||
std::move(options_proto),
|
std::move(options_proto),
|
||||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
options->running_mode == core::RunningMode::LIVE_STREAM,
|
||||||
|
options->output_segmentation_masks),
|
||||||
std::move(options->base_options.op_resolver), options->running_mode,
|
std::move(options->base_options.op_resolver), options->running_mode,
|
||||||
std::move(packets_callback))));
|
std::move(packets_callback))));
|
||||||
|
|
||||||
|
|
|
@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
|
||||||
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
|
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
|
||||||
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
||||||
Source<std::vector<Detection>> pose_detections;
|
Source<std::vector<Detection>> pose_detections;
|
||||||
Source<std::vector<Image>> segmentation_masks;
|
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||||
Source<Image> image;
|
Source<Image> image;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
||||||
// input_stream: "IMAGE:image_in"
|
// input_stream: "IMAGE:image_in"
|
||||||
// input_stream: "NORM_RECT:norm_rect"
|
// input_stream: "NORM_RECT:norm_rect"
|
||||||
// output_stream: "NORM_LANDMARKS:pose_landmarks"
|
// output_stream: "NORM_LANDMARKS:pose_landmarks"
|
||||||
// output_stream: "LANDMARKS:world_landmarks"
|
// output_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||||
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
|
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
|
||||||
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
|
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
|
||||||
// output_stream: "POSE_RECTS:pose_rects"
|
// output_stream: "POSE_RECTS:pose_rects"
|
||||||
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
|
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
|
||||||
|
@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
bool output_segmentation_masks =
|
||||||
|
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||||
if (sc->Options<PoseLandmarkerGraphOptions>()
|
if (sc->Options<PoseLandmarkerGraphOptions>()
|
||||||
.base_options()
|
.base_options()
|
||||||
.has_model_asset()) {
|
.has_model_asset()) {
|
||||||
|
@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||||
.IsAvailable()));
|
.IsAvailable()));
|
||||||
}
|
}
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(auto outs,
|
||||||
auto outs,
|
|
||||||
BuildPoseLandmarkerGraph(
|
BuildPoseLandmarkerGraph(
|
||||||
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||||
|
graph, output_segmentation_masks));
|
||||||
outs.landmark_lists >>
|
outs.landmark_lists >>
|
||||||
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
||||||
outs.world_landmark_lists >>
|
outs.world_landmark_lists >>
|
||||||
|
@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
kAuxiliaryLandmarksTag)];
|
kAuxiliaryLandmarksTag)];
|
||||||
outs.pose_rects_next_frame >>
|
outs.pose_rects_next_frame >>
|
||||||
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
|
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
|
||||||
outs.segmentation_masks >>
|
|
||||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
|
||||||
outs.pose_detections >>
|
outs.pose_detections >>
|
||||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||||
outs.image >> graph[Output<Image>(kImageTag)];
|
outs.image >> graph[Output<Image>(kImageTag)];
|
||||||
|
if (outs.segmentation_masks) {
|
||||||
|
*outs.segmentation_masks >>
|
||||||
|
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||||
|
}
|
||||||
|
|
||||||
// TODO remove when support is fixed.
|
// TODO remove when support is fixed.
|
||||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
// As mediapipe GraphBuilder currently doesn't support configuring
|
||||||
|
@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
// graph: the mediapipe graph instance to be updated.
|
// graph: the mediapipe graph instance to be updated.
|
||||||
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
|
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
|
||||||
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
||||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
Source<NormalizedRect> norm_rect_in, Graph& graph,
|
||||||
|
bool output_segmentation_masks) {
|
||||||
const int max_num_poses =
|
const int max_num_poses =
|
||||||
tasks_options.pose_detector_graph_options().num_poses();
|
tasks_options.pose_detector_graph_options().num_poses();
|
||||||
|
|
||||||
|
@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
auto pose_rects_for_next_frame =
|
auto pose_rects_for_next_frame =
|
||||||
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
|
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
|
||||||
.Cast<std::vector<NormalizedRect>>();
|
.Cast<std::vector<NormalizedRect>>();
|
||||||
auto segmentation_masks =
|
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||||
|
if (output_segmentation_masks) {
|
||||||
|
segmentation_masks =
|
||||||
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
||||||
.Cast<std::vector<Image>>();
|
.Cast<std::vector<Image>>();
|
||||||
|
}
|
||||||
|
|
||||||
if (tasks_options.base_options().use_stream_mode()) {
|
if (tasks_options.base_options().use_stream_mode()) {
|
||||||
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
|
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user