Merge branch 'google:master' into face-stylizer-python-add-tests
This commit is contained in:
commit
a5716c9225
10
WORKSPACE
10
WORKSPACE
|
@ -239,6 +239,16 @@ http_archive(
|
|||
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "darts_clone",
|
||||
build_file = "@//third_party:darts_clone.BUILD",
|
||||
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
|
||||
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
|
||||
urls = [
|
||||
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
|
||||
],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "org_tensorflow_text",
|
||||
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
|
||||
|
|
|
@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
|||
} else if (packet_options.has_string_value()) {
|
||||
packet.Set<std::string>();
|
||||
} else if (packet_options.has_uint64_value()) {
|
||||
packet.Set<uint64>();
|
||||
packet.Set<uint64_t>();
|
||||
} else if (packet_options.has_classification_list_value()) {
|
||||
packet.Set<ClassificationList>();
|
||||
} else if (packet_options.has_landmark_list_value()) {
|
||||
|
@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
|||
} else if (packet_options.has_string_value()) {
|
||||
packet.Set(MakePacket<std::string>(packet_options.string_value()));
|
||||
} else if (packet_options.has_uint64_value()) {
|
||||
packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
|
||||
packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
|
||||
} else if (packet_options.has_classification_list_value()) {
|
||||
packet.Set(MakePacket<ClassificationList>(
|
||||
packet_options.classification_list_value()));
|
||||
|
|
|
@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
// Use this when ALLOW/DISALLOW input is provided as a side packet.
|
||||
void RunTimeStep(int64 timestamp, bool stream_payload) {
|
||||
void RunTimeStep(int64_t timestamp, bool stream_payload) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
|
||||
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
||||
}
|
||||
|
||||
// Use this when ALLOW/DISALLOW input is provided as an input stream.
|
||||
void RunTimeStep(int64 timestamp, const std::string& control_tag,
|
||||
void RunTimeStep(int64_t timestamp, const std::string& control_tag,
|
||||
bool control) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(true).At(Timestamp(timestamp)));
|
||||
|
@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
|
|||
}
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
|
|||
}
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
|
|||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
|
|||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
|
|||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
|
|
@ -35,7 +35,7 @@ class MatrixToVectorCalculatorTest
|
|||
void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }
|
||||
|
||||
void AppendInput(const std::vector<float>& column_major_data,
|
||||
int64 timestamp) {
|
||||
int64_t timestamp) {
|
||||
ASSERT_EQ(num_input_samples_ * num_input_channels_,
|
||||
column_major_data.size());
|
||||
Eigen::Map<const Matrix> data_map(&column_major_data[0],
|
||||
|
|
|
@ -51,9 +51,9 @@ class SimpleRunner : public CalculatorRunner {
|
|||
|
||||
virtual ~SimpleRunner() {}
|
||||
|
||||
void SetInput(const std::vector<int64>& timestamp_list) {
|
||||
void SetInput(const std::vector<int64_t>& timestamp_list) {
|
||||
MutableInputs()->Index(0).packets.clear();
|
||||
for (const int64 ts : timestamp_list) {
|
||||
for (const int64_t ts : timestamp_list) {
|
||||
MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(new std::string(absl::StrCat("Frame #", ts)))
|
||||
.At(Timestamp(ts)));
|
||||
|
@ -72,8 +72,8 @@ class SimpleRunner : public CalculatorRunner {
|
|||
}
|
||||
|
||||
void CheckOutputTimestamps(
|
||||
const std::vector<int64>& expected_frames,
|
||||
const std::vector<int64>& expected_timestamps) const {
|
||||
const std::vector<int64_t>& expected_frames,
|
||||
const std::vector<int64_t>& expected_timestamps) const {
|
||||
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
|
||||
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
|
||||
int count = 0;
|
||||
|
@ -112,7 +112,7 @@ MATCHER_P2(PacketAtTimestamp, payload, timestamp,
|
|||
*result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
|
||||
return false;
|
||||
}
|
||||
int64 actual_payload = arg.template Get<int64>();
|
||||
int64_t actual_payload = arg.template Get<int64_t>();
|
||||
if (actual_payload != payload) {
|
||||
*result_listener << "with incorrect payload = " << actual_payload;
|
||||
return false;
|
||||
|
@ -137,18 +137,18 @@ class ReproducibleJitterWithReflectionStrategyForTesting
|
|||
//
|
||||
// An EXPECT will fail if sequence is less than the number requested during
|
||||
// processing.
|
||||
static std::vector<uint64> random_sequence;
|
||||
static std::vector<uint64_t> random_sequence;
|
||||
|
||||
protected:
|
||||
virtual uint64 GetNextRandom(uint64 n) {
|
||||
virtual uint64_t GetNextRandom(uint64_t n) {
|
||||
EXPECT_LT(sequence_index_, random_sequence.size());
|
||||
return random_sequence[sequence_index_++] % n;
|
||||
}
|
||||
|
||||
private:
|
||||
int32 sequence_index_ = 0;
|
||||
int32_t sequence_index_ = 0;
|
||||
};
|
||||
std::vector<uint64>
|
||||
std::vector<uint64_t>
|
||||
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
|
||||
|
||||
// PacketResamplerCalculator child class which injects a specified stream
|
||||
|
@ -469,7 +469,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
|
|||
}
|
||||
)pb"));
|
||||
|
||||
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
|
||||
for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) {
|
||||
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
|
||||
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
|
||||
}
|
||||
|
|
|
@ -43,8 +43,8 @@ constexpr char kDisallowTag[] = "DISALLOW";
|
|||
|
||||
// Returns the timestamp values for a vector of Packets.
|
||||
// TODO: puth this kind of test util in a common place.
|
||||
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
||||
std::vector<int64> result;
|
||||
std::vector<int64_t> TimestampValues(const std::vector<Packet>& packets) {
|
||||
std::vector<int64_t> result;
|
||||
for (const Packet& packet : packets) {
|
||||
result.push_back(packet.Timestamp().Value());
|
||||
}
|
||||
|
@ -371,7 +371,7 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
|
|||
for (int main_ts = 0; main_ts < 50; ++main_ts) {
|
||||
send_packet("in", main_ts);
|
||||
MP_EXPECT_OK(graph_.WaitUntilIdle());
|
||||
std::vector<int64> ts_values = TimestampValues(outputs);
|
||||
std::vector<int64_t> ts_values = TimestampValues(outputs);
|
||||
EXPECT_EQ(ts_values.size(), main_ts + 1);
|
||||
for (int j = 0; j < main_ts + 1; ++j) {
|
||||
EXPECT_EQ(ts_values[j], j);
|
||||
|
|
|
@ -121,7 +121,7 @@ absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
|
|||
if (cc->Outputs().HasTag(kTagAtTimestamp)) {
|
||||
RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
|
||||
<< "For AT_TIMESTAMP tag, 2 input side packets are required.";
|
||||
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>();
|
||||
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>();
|
||||
} else {
|
||||
RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
|
||||
<< "Same number of input side packets and output streams is required.";
|
||||
|
@ -178,8 +178,8 @@ absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
|
|||
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
|
||||
}
|
||||
} else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
|
||||
int64 timestamp =
|
||||
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>();
|
||||
int64_t timestamp =
|
||||
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64_t>();
|
||||
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
|
||||
cc->Outputs()
|
||||
.Get(output_tag_, i)
|
||||
|
|
|
@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
|
|||
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
|
||||
REGISTER_CALCULATOR(StringToUintCalculator);
|
||||
|
||||
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
|
||||
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
|
||||
REGISTER_CALCULATOR(StringToInt32Calculator);
|
||||
|
||||
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
|
||||
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
|
||||
REGISTER_CALCULATOR(StringToUint32Calculator);
|
||||
|
||||
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
|
||||
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
|
||||
REGISTER_CALCULATOR(StringToInt64Calculator);
|
||||
|
||||
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
|
||||
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
|
||||
REGISTER_CALCULATOR(StringToUint64Calculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
|
|||
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
|
||||
frame_ptr->Height(), frame_ptr->WidthStep(),
|
||||
const_cast<uint8_t*>(frame_ptr->PixelData()),
|
||||
[](uint8* data){});
|
||||
[](uint8_t* data){});
|
||||
ASSIGN_OR_RETURN(auto result,
|
||||
runner->Run(image_frame, matrix, size, border_mode));
|
||||
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
|
||||
|
|
|
@ -401,8 +401,8 @@ cc_library_with_tflite(
|
|||
hdrs = ["inference_calculator.h"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
deps = [
|
||||
":inference_calculator_cc_proto",
|
||||
|
@ -506,7 +506,7 @@ cc_library_with_tflite(
|
|||
name = "tflite_delegate_ptr",
|
||||
hdrs = ["tflite_delegate_ptr.h"],
|
||||
tflite_deps = [
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -517,8 +517,8 @@ cc_library_with_tflite(
|
|||
tflite_deps = [
|
||||
":tflite_delegate_ptr",
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
],
|
||||
deps = [
|
||||
":inference_runner",
|
||||
|
@ -546,8 +546,8 @@ cc_library(
|
|||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
] + select({
|
||||
"//conditions:default": [],
|
||||
|
|
|
@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
|
|||
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
|
||||
}
|
||||
return PacketAdopting<tflite::OpResolver>(
|
||||
std::make_unique<tflite_shims::ops::builtin::
|
||||
BuiltinOpResolverWithoutDefaultDelegates>());
|
||||
std::make_unique<
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
|
|||
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
|
||||
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
|
||||
// migration.
|
||||
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
|
||||
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
||||
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
||||
"OP_RESOLVER"};
|
||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
|
||||
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
|
||||
#include "mediapipe/calculators/tensor/inference_runner.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#if defined(MEDIAPIPE_ANDROID)
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#endif // ANDROID
|
||||
|
|
|
@ -22,9 +22,9 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter.h"
|
||||
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
|
||||
|
@ -33,8 +33,8 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
using Interpreter = ::tflite_shims::Interpreter;
|
||||
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
|
||||
using Interpreter = ::tflite::Interpreter;
|
||||
using InterpreterBuilder = ::tflite::InterpreterBuilder;
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
|
||||
|
|
|
@ -23,8 +23,8 @@
|
|||
#include "mediapipe/calculators/tensor/tflite_delegate_ptr.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/core/shims/c/c_api_types.h"
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
|
|
@ -61,12 +61,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
|
|||
// overload GPU/TPU/...
|
||||
class SimpleSemaphore {
|
||||
public:
|
||||
explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {}
|
||||
explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {}
|
||||
SimpleSemaphore(const SimpleSemaphore&) = delete;
|
||||
SimpleSemaphore(SimpleSemaphore&&) = delete;
|
||||
|
||||
// Acquires the semaphore by certain amount.
|
||||
void Acquire(uint32 amount) {
|
||||
void Acquire(uint32_t amount) {
|
||||
mutex_.Lock();
|
||||
while (count_ < amount) {
|
||||
cond_.Wait(&mutex_);
|
||||
|
@ -76,7 +76,7 @@ class SimpleSemaphore {
|
|||
}
|
||||
|
||||
// Releases the semaphore by certain amount.
|
||||
void Release(uint32 amount) {
|
||||
void Release(uint32_t amount) {
|
||||
mutex_.Lock();
|
||||
count_ += amount;
|
||||
cond_.SignalAll();
|
||||
|
@ -84,7 +84,7 @@ class SimpleSemaphore {
|
|||
}
|
||||
|
||||
private:
|
||||
uint32 count_;
|
||||
uint32_t count_;
|
||||
absl::Mutex mutex_;
|
||||
absl::CondVar cond_;
|
||||
};
|
||||
|
@ -488,7 +488,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
// necessary.
|
||||
absl::Status OutputBatch(CalculatorContext* cc,
|
||||
std::unique_ptr<InferenceState> inference_state) {
|
||||
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
||||
|
||||
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
||||
|
@ -544,7 +544,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
get_session_run_throttle(options_.max_concurrent_session_runs());
|
||||
session_run_throttle->Acquire(1);
|
||||
}
|
||||
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
tf::Status tf_status;
|
||||
{
|
||||
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
||||
|
@ -562,7 +562,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
// informative error message.
|
||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||
|
||||
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
||||
->IncrementBy(run_end_time - run_start_time);
|
||||
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
||||
|
@ -611,7 +611,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
}
|
||||
|
||||
// Get end time and report.
|
||||
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||
cc->GetCounter(kTotalUsecsCounterSuffix)
|
||||
->IncrementBy(end_time - start_time);
|
||||
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
||||
|
@ -650,7 +650,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
|
||||
// The static singleton semaphore to throttle concurrent session runs.
|
||||
static SimpleSemaphore* get_session_run_throttle(
|
||||
int32 max_concurrent_session_runs) {
|
||||
int32_t max_concurrent_session_runs) {
|
||||
static SimpleSemaphore* session_run_throttle =
|
||||
new SimpleSemaphore(max_concurrent_session_runs);
|
||||
return session_run_throttle;
|
||||
|
|
|
@ -197,15 +197,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
// timestamp and the associated feature. This information is used in process
|
||||
// to output batches of packets in order.
|
||||
timestamps_.clear();
|
||||
int64 last_timestamp_seen = Timestamp::PreStream().Value();
|
||||
int64_t last_timestamp_seen = Timestamp::PreStream().Value();
|
||||
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
|
||||
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
|
||||
if (absl::StrContains(map_kv.first, "/timestamp")) {
|
||||
LOG(INFO) << "Found feature timestamps: " << map_kv.first
|
||||
<< " with size: " << map_kv.second.feature_size();
|
||||
int64 recent_timestamp = Timestamp::PreStream().Value();
|
||||
int64_t recent_timestamp = Timestamp::PreStream().Value();
|
||||
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
|
||||
int64 next_timestamp =
|
||||
int64_t next_timestamp =
|
||||
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
|
||||
RET_CHECK_GT(next_timestamp, recent_timestamp)
|
||||
<< "Timestamps must be sequential. If you're seeing this message "
|
||||
|
@ -361,8 +361,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
// any particular call to Process(). At the every end, we output the
|
||||
// poststream packets. If we only have poststream packets,
|
||||
// last_timestamp_key_ will be empty.
|
||||
int64 start_timestamp = 0;
|
||||
int64 end_timestamp = 0;
|
||||
int64_t start_timestamp = 0;
|
||||
int64_t end_timestamp = 0;
|
||||
if (last_timestamp_key_.empty() || process_poststream_) {
|
||||
process_poststream_ = true;
|
||||
start_timestamp = Timestamp::PostStream().Value();
|
||||
|
@ -481,14 +481,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
|||
// Store a map from the keys for each stream to the timestamps for each
|
||||
// key. This allows us to identify which packets to output for each stream
|
||||
// for timestamps within a given time window.
|
||||
std::map<std::string, std::vector<int64>> timestamps_;
|
||||
std::map<std::string, std::vector<int64_t>> timestamps_;
|
||||
// Store the stream with the latest timestamp in the SequenceExample.
|
||||
std::string last_timestamp_key_;
|
||||
// Store the index of the current timestamp. Will be less than
|
||||
// timestamps_[last_timestamp_key_].size().
|
||||
int current_timestamp_index_;
|
||||
// Store the very first timestamp, so we output everything on the first frame.
|
||||
int64 first_timestamp_seen_;
|
||||
int64_t first_timestamp_seen_;
|
||||
// List of keypoint names.
|
||||
std::vector<std::string> keypoint_names_;
|
||||
// Default keypoint location when missing.
|
||||
|
|
|
@ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test {
|
|||
}
|
||||
}
|
||||
|
||||
const int64 time = 1234;
|
||||
const int64_t time = 1234;
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input.release()).At(Timestamp(time)));
|
||||
|
||||
|
@ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) {
|
|||
// 2^i can be represented exactly in floating point numbers if 'i' is small.
|
||||
input->at(i) = static_cast<float>(1 << i);
|
||||
}
|
||||
const int64 time = 1234;
|
||||
const int64_t time = 1234;
|
||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(input.release()).At(Timestamp(time)));
|
||||
|
||||
|
|
|
@ -28,11 +28,8 @@
|
|||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
using mediapipe::Adopt;
|
||||
using mediapipe::CalculatorBase;
|
||||
using mediapipe::ImageFrame;
|
||||
using mediapipe::PacketTypeSet;
|
||||
using mediapipe::autoflip::Border;
|
||||
|
||||
constexpr char kDetectedBorders[] = "DETECTED_BORDERS";
|
||||
constexpr int kMinBorderDistance = 5;
|
||||
|
|
|
@ -28,16 +28,12 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
using mediapipe::Adopt;
|
||||
using mediapipe::CalculatorGraphConfig;
|
||||
using mediapipe::CalculatorRunner;
|
||||
using mediapipe::ImageFormat;
|
||||
using mediapipe::ImageFrame;
|
||||
using mediapipe::Packet;
|
||||
using mediapipe::PacketTypeSet;
|
||||
using mediapipe::ParseTextProtoOrDie;
|
||||
using mediapipe::Timestamp;
|
||||
using mediapipe::autoflip::Border;
|
||||
|
||||
namespace mediapipe {
|
||||
namespace autoflip {
|
||||
|
|
|
@ -31,14 +31,11 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
using mediapipe::Adopt;
|
||||
using mediapipe::CalculatorGraphConfig;
|
||||
using mediapipe::CalculatorRunner;
|
||||
using mediapipe::ImageFormat;
|
||||
using mediapipe::ImageFrame;
|
||||
using mediapipe::PacketTypeSet;
|
||||
using mediapipe::ParseTextProtoOrDie;
|
||||
using mediapipe::Timestamp;
|
||||
|
||||
namespace mediapipe {
|
||||
namespace autoflip {
|
||||
|
|
|
@ -28,8 +28,6 @@
|
|||
using mediapipe::Packet;
|
||||
using mediapipe::PacketTypeSet;
|
||||
using mediapipe::autoflip::DetectionSet;
|
||||
using mediapipe::autoflip::SalientRegion;
|
||||
using mediapipe::autoflip::SignalType;
|
||||
|
||||
constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY";
|
||||
constexpr char kSignalInputsTag[] = "SIGNAL";
|
||||
|
|
|
@ -19,8 +19,6 @@ namespace mediapipe {
|
|||
namespace api2 {
|
||||
namespace test {
|
||||
|
||||
using testing::ElementsAre;
|
||||
|
||||
// Returns the packet values for a vector of Packets.
|
||||
template <typename T>
|
||||
std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) {
|
||||
|
|
|
@ -310,7 +310,7 @@ class Scheduler {
|
|||
absl::Mutex state_mutex_;
|
||||
|
||||
// Current state of the scheduler.
|
||||
std::atomic<State> state_ = ATOMIC_VAR_INIT(STATE_NOT_STARTED);
|
||||
std::atomic<State> state_ = STATE_NOT_STARTED;
|
||||
|
||||
// True if all graph input streams are closed.
|
||||
bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false;
|
||||
|
|
|
@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
|
|||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
|
||||
// Record the most recent first kept timestamp on any stream.
|
||||
for (const auto& stream : input_stream_managers_) {
|
||||
int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
|
||||
? target_queue_size_
|
||||
: trigger_queue_size_ - 1;
|
||||
int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
|
||||
? target_queue_size_
|
||||
: trigger_queue_size_ - 1;
|
||||
if (stream->QueueSize() > queue_size) {
|
||||
kept_timestamp_ = std::max(
|
||||
kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
|
||||
|
@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
|
|||
}
|
||||
|
||||
private:
|
||||
int32 trigger_queue_size_;
|
||||
int32 target_queue_size_;
|
||||
int32_t trigger_queue_size_;
|
||||
int32_t target_queue_size_;
|
||||
bool fixed_min_size_;
|
||||
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and
|
||||
// the corresponding call to FillInputSet has not yet completed.
|
||||
|
|
|
@ -30,15 +30,15 @@ namespace mediapipe {
|
|||
|
||||
namespace {
|
||||
|
||||
const int64 kMaxPacketId = 100;
|
||||
const int64 kSlowCalculatorRate = 10;
|
||||
const int64_t kMaxPacketId = 100;
|
||||
const int64_t kSlowCalculatorRate = 10;
|
||||
|
||||
// Rate limiter for TestSlowCalculator.
|
||||
ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit);
|
||||
int64 g_source_counter ABSL_GUARDED_BY(g_source_mutex);
|
||||
int64_t g_source_counter ABSL_GUARDED_BY(g_source_mutex);
|
||||
|
||||
// Rate limiter for TestSourceCalculator.
|
||||
int64 g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
|
||||
int64_t g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
|
||||
|
||||
// Flag that indicates that the source is done.
|
||||
bool g_source_done ABSL_GUARDED_BY(g_source_mutex);
|
||||
|
@ -47,7 +47,7 @@ class TestSourceCalculator : public CalculatorBase {
|
|||
public:
|
||||
TestSourceCalculator() : current_packet_id_(0) {}
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Outputs().Index(0).Set<int64>();
|
||||
cc->Outputs().Index(0).Set<int64_t>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
|
@ -62,7 +62,7 @@ class TestSourceCalculator : public CalculatorBase {
|
|||
g_source_done = true;
|
||||
return tool::StatusStop();
|
||||
}
|
||||
cc->Outputs().Index(0).Add(new int64(0), Timestamp(current_packet_id_));
|
||||
cc->Outputs().Index(0).Add(new int64_t(0), Timestamp(current_packet_id_));
|
||||
++current_packet_id_;
|
||||
{
|
||||
absl::MutexLock lock(&g_source_mutex);
|
||||
|
@ -78,7 +78,7 @@ class TestSourceCalculator : public CalculatorBase {
|
|||
return g_source_counter <= kSlowCalculatorRate * g_slow_counter ||
|
||||
g_source_counter <= 1;
|
||||
}
|
||||
int64 current_packet_id_;
|
||||
int64_t current_packet_id_;
|
||||
};
|
||||
|
||||
REGISTER_CALCULATOR(TestSourceCalculator);
|
||||
|
@ -87,8 +87,8 @@ class TestSlowCalculator : public CalculatorBase {
|
|||
public:
|
||||
TestSlowCalculator() = default;
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int64>();
|
||||
cc->Outputs().Index(0).Set<int64>();
|
||||
cc->Inputs().Index(0).Set<int64_t>();
|
||||
cc->Outputs().Index(0).Set<int64_t>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
|
@ -97,7 +97,7 @@ class TestSlowCalculator : public CalculatorBase {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
cc->Outputs().Index(0).Add(new int64(0),
|
||||
cc->Outputs().Index(0).Add(new int64_t(0),
|
||||
cc->Inputs().Index(0).Value().Timestamp());
|
||||
{
|
||||
absl::MutexLock lock(&g_source_mutex);
|
||||
|
@ -118,8 +118,9 @@ class TestSlowCalculator : public CalculatorBase {
|
|||
REGISTER_CALCULATOR(TestSlowCalculator);
|
||||
|
||||
// Return the values of the timestamps of a vector of Packets.
|
||||
static std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
|
||||
std::vector<int64> result;
|
||||
static std::vector<int64_t> TimestampValues(
|
||||
const std::vector<Packet>& packets) {
|
||||
std::vector<int64_t> result;
|
||||
for (const Packet& p : packets) {
|
||||
result.push_back(p.Timestamp().Value());
|
||||
}
|
||||
|
@ -174,7 +175,7 @@ TEST_P(FixedSizeInputStreamHandlerTest, DropsPackets) {
|
|||
// consumed. In this way, the TestSlowCalculator consumes and outputs only
|
||||
// every tenth packet.
|
||||
EXPECT_EQ(output_packets.size(), 11);
|
||||
std::vector<int64> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
|
||||
std::vector<int64_t> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
|
||||
EXPECT_THAT(TimestampValues(output_packets),
|
||||
testing::ContainerEq(expected_ts));
|
||||
}
|
||||
|
@ -344,18 +345,18 @@ TEST_P(FixedSizeInputStreamHandlerTest, LateArrivalDrop) {
|
|||
|
||||
if (GetParam()) {
|
||||
EXPECT_THAT(TimestampValues(output_packets[0]),
|
||||
testing::ContainerEq(std::vector<int64>{1, 2, 3, 4, 5, 6}));
|
||||
testing::ContainerEq(std::vector<int64_t>{1, 2, 3, 4, 5, 6}));
|
||||
EXPECT_THAT(TimestampValues(output_packets[1]),
|
||||
testing::ContainerEq(std::vector<int64>{3, 4, 5, 6, 7}));
|
||||
testing::ContainerEq(std::vector<int64_t>{3, 4, 5, 6, 7}));
|
||||
EXPECT_THAT(TimestampValues(output_packets[2]),
|
||||
testing::ContainerEq(std::vector<int64>{4, 5, 6, 7}));
|
||||
testing::ContainerEq(std::vector<int64_t>{4, 5, 6, 7}));
|
||||
} else {
|
||||
EXPECT_THAT(TimestampValues(output_packets[0]),
|
||||
testing::ContainerEq(std::vector<int64>{5, 6}));
|
||||
testing::ContainerEq(std::vector<int64_t>{5, 6}));
|
||||
EXPECT_THAT(TimestampValues(output_packets[1]),
|
||||
testing::ContainerEq(std::vector<int64>{5, 6, 7}));
|
||||
testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
|
||||
EXPECT_THAT(TimestampValues(output_packets[2]),
|
||||
testing::ContainerEq(std::vector<int64>{5, 6, 7}));
|
||||
testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,10 +27,6 @@ namespace options_field_util {
|
|||
|
||||
using ::mediapipe::proto_ns::internal::WireFormatLite;
|
||||
using FieldType = WireFormatLite::FieldType;
|
||||
using ::mediapipe::proto_ns::io::ArrayInputStream;
|
||||
using ::mediapipe::proto_ns::io::CodedInputStream;
|
||||
using ::mediapipe::proto_ns::io::CodedOutputStream;
|
||||
using ::mediapipe::proto_ns::io::StringOutputStream;
|
||||
|
||||
// Utility functions for OptionsFieldUtil.
|
||||
namespace {
|
||||
|
|
|
@ -454,8 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
|||
// Number of glFinish calls completed on the GL thread.
|
||||
// Changes should be guarded by mutex_. However, we use simple atomic
|
||||
// loads for efficiency on the fast path.
|
||||
std::atomic<int64_t> gl_finish_count_ = ATOMIC_VAR_INIT(0);
|
||||
std::atomic<int64_t> gl_finish_count_target_ = ATOMIC_VAR_INIT(0);
|
||||
std::atomic<int64_t> gl_finish_count_ = 0;
|
||||
std::atomic<int64_t> gl_finish_count_target_ = 0;
|
||||
|
||||
GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr;
|
||||
|
||||
|
|
|
@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
|
|||
// TODO: Investigate this option in more detail, esp. on Safari.
|
||||
attrs.preserveDrawingBuffer = 0;
|
||||
|
||||
// Since the Emscripten canvas target finding function is visible from here,
|
||||
// we hijack findCanvasEventTarget directly for enforcing old Module.canvas
|
||||
// behavior if the user desires, falling back to the new DOM element CSS
|
||||
// selector behavior next if that is specified, and finally just allowing the
|
||||
// lookup to proceed on a null target.
|
||||
// TODO: Ensure this works with all options (in particular,
|
||||
// multithreading options, like the special-case combination of USE_PTHREADS
|
||||
// and OFFSCREEN_FRAMEBUFFER)
|
||||
// clang-format off
|
||||
EM_ASM(
|
||||
let init_once = true;
|
||||
if (init_once) {
|
||||
const cachedFindCanvasEventTarget = findCanvasEventTarget;
|
||||
|
||||
if (typeof cachedFindCanvasEventTarget !== 'function') {
|
||||
if (typeof console !== 'undefined') {
|
||||
console.error('Expected Emscripten global function '
|
||||
+ '"findCanvasEventTarget" not found. WebGL context creation '
|
||||
+ 'may fail.');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
findCanvasEventTarget = function(target) {
|
||||
if (target == 0) {
|
||||
if (Module && Module.canvas) {
|
||||
return Module.canvas;
|
||||
} else if (Module && Module.canvasCssSelector) {
|
||||
return cachedFindCanvasEventTarget(Module.canvasCssSelector);
|
||||
}
|
||||
if (typeof console !== 'undefined') {
|
||||
console.warn('Module properties canvas and canvasCssSelector not ' +
|
||||
'found during WebGL context creation.');
|
||||
}
|
||||
}
|
||||
// We still go through with the find attempt, although for most use
|
||||
// cases it will not succeed, just in case the user does want to fall-
|
||||
// back.
|
||||
return cachedFindCanvasEventTarget(target);
|
||||
}; // NOLINT: Necessary semicolon.
|
||||
init_once = false;
|
||||
}
|
||||
);
|
||||
// clang-format on
|
||||
|
||||
// Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
|
||||
// looks for our #canvas target in Module.canvas, where we expect it to be.
|
||||
// -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
|
||||
// event target behavior, but it was never supposed to be tapping into our
|
||||
// canvas anyways. See b/278155946 for more background.
|
||||
EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
|
||||
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
|
||||
emscripten_webgl_create_context(nullptr, &attrs);
|
||||
emscripten_webgl_create_context("#canvas", &attrs);
|
||||
|
||||
// Check for failure
|
||||
if (context_handle <= 0) {
|
||||
|
|
|
@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
|
|||
int actual_ws = image_frame.WidthStep();
|
||||
int alignment = 0;
|
||||
std::unique_ptr<ImageFrame> temp;
|
||||
const uint8* data = image_frame.PixelData();
|
||||
const uint8_t* data = image_frame.PixelData();
|
||||
|
||||
// Let's see if the pixel data is tightly aligned to one of the alignments
|
||||
// supported by OpenGL, preferring 4 if possible since it's the default.
|
||||
|
|
|
@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
|||
GpuBufferFormat format) {
|
||||
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
|
||||
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
|
||||
auto y_data = std::make_unique<uint8[]>(y_stride * height);
|
||||
auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
|
||||
switch (fourcc) {
|
||||
case libyuv::FOURCC_NV12:
|
||||
case libyuv::FOURCC_NV21: {
|
||||
|
@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
|||
int uv_width = 2 * std::ceil(0.5f * width);
|
||||
int uv_height = std::ceil(0.5f * height);
|
||||
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
||||
auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
||||
auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||
yuv_image_ = std::make_shared<YUVImage>(
|
||||
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
|
||||
nullptr, 0, width, height);
|
||||
|
@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
|||
int uv_width = std::ceil(0.5f * width);
|
||||
int uv_height = std::ceil(0.5f * height);
|
||||
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
||||
auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
||||
auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
||||
auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||
auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||
yuv_image_ = std::make_shared<YUVImage>(
|
||||
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
|
||||
std::move(v_data), uv_stride, width, height);
|
||||
|
|
|
@ -16,6 +16,7 @@ import csv
|
|||
import filecmp
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
import tensorflow as tf
|
||||
|
@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
|
|||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
@unittest.skip('b/275624089')
|
||||
class TextClassifierTest(tf.test.TestCase):
|
||||
|
||||
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||
|
|
|
@ -175,11 +175,7 @@ py_test(
|
|||
data = [":testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_spec",
|
||||
":object_detector",
|
||||
":object_detector_options",
|
||||
":object_detector_import",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
|
|||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
from mediapipe.model_maker.python.vision import object_detector
|
||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||
|
||||
|
||||
|
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
super().setUp()
|
||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||
cache_dir = self.create_tempdir()
|
||||
self.data = dataset.Dataset.from_coco_folder(
|
||||
self.data = object_detector.Dataset.from_coco_folder(
|
||||
dataset_folder, cache_dir=cache_dir
|
||||
)
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
|
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_object_detector(self):
|
||||
hparams = hyperparameters.HParams(
|
||||
hparams = object_detector.HParams(
|
||||
epochs=1,
|
||||
batch_size=2,
|
||||
learning_rate=0.9,
|
||||
shuffle=False,
|
||||
export_dir=self.create_tempdir(),
|
||||
)
|
||||
options = object_detector_options.ObjectDetectorOptions(
|
||||
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
||||
options = object_detector.ObjectDetectorOptions(
|
||||
supported_model=object_detector.SupportedModels.MOBILENET_V2,
|
||||
hparams=hparams,
|
||||
)
|
||||
# Test `create``
|
||||
model = object_detector.ObjectDetector.create(
|
||||
|
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
|
||||
# Test `quantization_aware_training`
|
||||
qat_hparams = hyperparameters.QATHParams(
|
||||
qat_hparams = object_detector.QATHParams(
|
||||
learning_rate=0.9,
|
||||
batch_size=2,
|
||||
epochs=1,
|
||||
|
|
|
@ -298,6 +298,7 @@ cc_library(
|
|||
":tensors_to_objects_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
|
|
@ -91,8 +91,8 @@ absl::Status FrameAnnotationToTimedBoxListCalculator::Process(
|
|||
TimedBoxProto* added_box = output_objects->add_box();
|
||||
ComputeBoundingRect(key_points, added_box);
|
||||
added_box->set_id(annotation.object_id());
|
||||
const int64 time_msec =
|
||||
static_cast<int64>(std::round(frame_annotation.timestamp() / 1000));
|
||||
const int64_t time_msec =
|
||||
static_cast<int64_t>(std::round(frame_annotation.timestamp() / 1000));
|
||||
added_box->set_time_msec(time_msec);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,8 +24,8 @@ namespace mediapipe {
|
|||
|
||||
void FrameAnnotationTracker::AddDetectionResult(
|
||||
const FrameAnnotation& frame_annotation) {
|
||||
const int64 time_us =
|
||||
static_cast<int64>(std::round(frame_annotation.timestamp()));
|
||||
const int64_t time_us =
|
||||
static_cast<int64_t>(std::round(frame_annotation.timestamp()));
|
||||
for (const auto& object_annotation : frame_annotation.annotations()) {
|
||||
detected_objects_[time_us + object_annotation.object_id()] =
|
||||
object_annotation;
|
||||
|
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
|
|||
absl::flat_hash_set<int>* cancel_object_ids) {
|
||||
CHECK(cancel_object_ids != nullptr);
|
||||
FrameAnnotation frame_annotation;
|
||||
std::vector<int64> keys_to_be_deleted;
|
||||
std::vector<int64_t> keys_to_be_deleted;
|
||||
for (const auto& detected_obj : detected_objects_) {
|
||||
const int object_id = detected_obj.second.object_id();
|
||||
if (cancel_object_ids->contains(object_id)) {
|
||||
|
|
|
@ -76,7 +76,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase {
|
|||
// In a single MediaPipe session, the IDs are unique.
|
||||
// Also assign timestamp for the FrameAnnotation to be the input packet
|
||||
// timestamp.
|
||||
void AssignObjectIdAndTimestamp(int64 timestamp_us,
|
||||
void AssignObjectIdAndTimestamp(int64_t timestamp_us,
|
||||
FrameAnnotation* annotation);
|
||||
|
||||
int num_classes_ = 0;
|
||||
|
@ -207,7 +207,7 @@ void TfLiteTensorsToObjectsCalculator::Project3DTo2D(
|
|||
}
|
||||
|
||||
void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp(
|
||||
int64 timestamp_us, FrameAnnotation* annotation) {
|
||||
int64_t timestamp_us, FrameAnnotation* annotation) {
|
||||
for (auto& ann : *annotation->mutable_annotations()) {
|
||||
ann.set_object_id(GetNextObjectId());
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -157,7 +157,7 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
|
|||
}
|
||||
}
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
|
||||
auto options = std::make_unique<AudioClassifierOptions>();
|
||||
|
@ -270,7 +270,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
|
|||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||
}
|
||||
|
||||
class ClassifyTest : public tflite_shims::testing::Test {};
|
||||
class ClassifyTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ClassifyTest, Succeeds) {
|
||||
auto audio_buffer = GetAudioData(k16kTestWavFilename);
|
||||
|
@ -467,7 +467,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
|
|||
}
|
||||
}
|
||||
|
||||
class ClassifyAsyncTest : public tflite_shims::testing::Test {};
|
||||
class ClassifyAsyncTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ClassifyAsyncTest, Succeeds) {
|
||||
constexpr int kSampleRateHz = 48000;
|
||||
|
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -66,7 +66,7 @@ Matrix GetAudioData(absl::string_view filename) {
|
|||
return matrix_mapping.matrix();
|
||||
}
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
|
||||
auto audio_embedder =
|
||||
|
@ -124,7 +124,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
|
|||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||
}
|
||||
|
||||
class EmbedTest : public tflite_shims::testing::Test {};
|
||||
class EmbedTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
|
||||
auto options = std::make_unique<AudioEmbedderOptions>();
|
||||
|
@ -187,7 +187,7 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
|
|||
MP_EXPECT_OK(audio_embedder->Close());
|
||||
}
|
||||
|
||||
class EmbedAsyncTest : public tflite_shims::testing::Test {
|
||||
class EmbedAsyncTest : public tflite::testing::Test {
|
||||
protected:
|
||||
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
|
||||
int sample_rate_hz,
|
||||
|
|
|
@ -47,7 +47,7 @@ cc_test_with_tflite(
|
|||
data = ["//mediapipe/tasks/testdata/audio:test_models"],
|
||||
tflite_deps = [
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
deps = [
|
||||
":audio_tensor_specs",
|
||||
|
|
|
@ -34,7 +34,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -52,7 +52,7 @@ constexpr char kModelWithMetadata[] =
|
|||
"yamnet_audio_classifier_with_metadata.tflite";
|
||||
constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite";
|
||||
|
||||
class AudioTensorSpecsTest : public tflite_shims::testing::Test {};
|
||||
class AudioTensorSpecsTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(AudioTensorSpecsTest,
|
||||
BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) {
|
||||
|
|
|
@ -63,7 +63,7 @@ cc_test(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -232,6 +232,6 @@ cc_test(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -33,7 +33,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
@ -66,8 +66,7 @@ ClassificationList MakeClassificationList(int class_index) {
|
|||
class_index));
|
||||
}
|
||||
|
||||
class ClassificationAggregationCalculatorTest
|
||||
: public tflite_shims::testing::Test {
|
||||
class ClassificationAggregationCalculatorTest : public tflite::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
bool connect_timestamps = false) {
|
||||
|
|
|
@ -31,7 +31,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
@ -52,7 +52,7 @@ constexpr char kTimestampsName[] = "timestamps_in";
|
|||
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
|
||||
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
|
||||
|
||||
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
|
||||
class EmbeddingAggregationCalculatorTest : public tflite::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
|
||||
Graph graph;
|
||||
|
|
|
@ -66,7 +66,7 @@ using ::mediapipe::tasks::core::ModelResources;
|
|||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::tflite::ProcessUnit;
|
||||
using ::tflite::TensorMetadata;
|
||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
|
||||
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
|
||||
|
||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||
|
|
|
@ -49,7 +49,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -101,7 +101,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
|||
std::move(external_file));
|
||||
}
|
||||
|
||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
||||
class ConfigureTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -417,7 +417,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
|||
)pb")));
|
||||
}
|
||||
|
||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||
class PostprocessingTest : public tflite::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||
|
@ -520,7 +520,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
|||
auto poller,
|
||||
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
|
||||
tensor[1] = 18;
|
||||
tensor[2] = 16;
|
||||
|
||||
|
@ -552,7 +552,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
|
||||
tensor[1] = 12;
|
||||
tensor[2] = 14;
|
||||
tensor[3] = 16;
|
||||
|
@ -589,7 +589,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
|||
auto poller,
|
||||
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
||||
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
|
||||
tensor[1] = 12;
|
||||
tensor[2] = 14;
|
||||
tensor[3] = 16;
|
||||
|
@ -677,11 +677,11 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
|||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
||||
/*connect_timestamps=*/true));
|
||||
// Build input tensors.
|
||||
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
|
||||
std::vector<uint8_t> tensor_0(kMobileNetNumClasses, 0);
|
||||
tensor_0[1] = 12;
|
||||
tensor_0[2] = 14;
|
||||
tensor_0[3] = 16;
|
||||
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
|
||||
std::vector<uint8_t> tensor_1(kMobileNetNumClasses, 0);
|
||||
tensor_1[5] = 12;
|
||||
tensor_1[6] = 14;
|
||||
tensor_1[7] = 16;
|
||||
|
|
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -86,7 +86,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
|
|||
std::move(external_file));
|
||||
}
|
||||
|
||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
||||
class ConfigureTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -153,7 +153,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
|
|||
has_quantized_outputs: false)pb")));
|
||||
}
|
||||
|
||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||
class PostprocessingTest : public tflite::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const proto::EmbedderOptions& options,
|
||||
|
|
|
@ -37,7 +37,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -125,7 +125,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
|||
return TaskRunner::Create(graph.GetConfig());
|
||||
}
|
||||
|
||||
class ConfigureTest : public tflite_shims::testing::Test {};
|
||||
class ConfigureTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
|
|
@ -78,6 +78,7 @@ cc_library(
|
|||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
|
||||
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
|
||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
|
||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
|
@ -128,9 +129,9 @@ cc_library_with_tflite(
|
|||
srcs = ["model_resources.cc"],
|
||||
hdrs = ["model_resources.h"],
|
||||
tflite_deps = [
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:verifier",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/tools:verifier",
|
||||
],
|
||||
deps = [
|
||||
":external_file_handler",
|
||||
|
@ -159,9 +160,9 @@ cc_test_with_tflite(
|
|||
],
|
||||
tflite_deps = [
|
||||
":model_resources",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
deps = [
|
||||
":utils",
|
||||
|
@ -186,7 +187,7 @@ cc_library_with_tflite(
|
|||
hdrs = ["model_resources_cache.h"],
|
||||
tflite_deps = [
|
||||
":model_resources",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
deps = [
|
||||
":model_asset_bundle_resources",
|
||||
|
@ -233,7 +234,7 @@ cc_test_with_tflite(
|
|||
":model_resources",
|
||||
":model_resources_cache",
|
||||
":model_resources_calculator",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
|
@ -284,7 +285,7 @@ cc_test_with_tflite(
|
|||
":task_runner",
|
||||
":model_resources",
|
||||
":model_resources_cache",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
|
@ -317,6 +318,9 @@ cc_library(
|
|||
":model_resources",
|
||||
":task_runner",
|
||||
":utils",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework/port:requires",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||
|
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
|||
AddCustom("KmeansEmbeddingLookup",
|
||||
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
|
||||
// For the UniversalSentenceEncoder model.
|
||||
AddCustom("TFSentencepieceTokenizeOp",
|
||||
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
|
||||
AddCustom("RaggedTensorToTensor",
|
||||
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
|
||||
}
|
||||
|
|
|
@ -37,8 +37,8 @@ limitations under the License.
|
|||
#include "mediapipe/util/tflite/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
||||
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/tools/verifier.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -52,7 +52,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
|||
|
||||
bool ModelResources::Verifier::Verify(const char* data, int length,
|
||||
tflite::ErrorReporter* reporter) {
|
||||
return tflite_shims::Verify(data, length, reporter);
|
||||
return tflite::Verify(data, length, reporter);
|
||||
}
|
||||
|
||||
ModelResources::ModelResources(const std::string& tag,
|
||||
|
@ -124,7 +124,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
|
|||
// and that it uses only operators that are supported by the OpResolver
|
||||
// that was passed to the ModelResources constructor, and then builds
|
||||
// the model from the buffer.
|
||||
auto model = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
|
||||
buffer_data, buffer_size, &verifier_, &error_reporter_);
|
||||
if (model == nullptr) {
|
||||
static constexpr char kInvalidFlatbufferMessage[] =
|
||||
|
@ -151,8 +151,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
|
|||
}
|
||||
|
||||
model_packet_ = MakePacket<ModelPtr>(
|
||||
model.release(),
|
||||
[](tflite_shims::FlatBufferModel* model) { delete model; });
|
||||
model.release(), [](tflite::FlatBufferModel* model) { delete model; });
|
||||
ASSIGN_OR_RETURN(auto model_metadata_extractor,
|
||||
metadata::ModelMetadataExtractor::CreateFromModelBuffer(
|
||||
buffer_data, buffer_size));
|
||||
|
|
|
@ -32,10 +32,10 @@ limitations under the License.
|
|||
#include "mediapipe/util/tflite/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
|
||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
||||
#include "tensorflow/lite/core/shims/cc/model_builder.h"
|
||||
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/tools/verifier.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -51,8 +51,8 @@ class ModelResources {
|
|||
public:
|
||||
// Represents a TfLite model as a FlatBuffer.
|
||||
using ModelPtr =
|
||||
std::unique_ptr<tflite_shims::FlatBufferModel,
|
||||
std::function<void(tflite_shims::FlatBufferModel*)>>;
|
||||
std::unique_ptr<tflite::FlatBufferModel,
|
||||
std::function<void(tflite::FlatBufferModel*)>>;
|
||||
|
||||
// Takes the ownership of the provided ExternalFile proto and creates
|
||||
// ModelResources from the proto and an op resolver object. A non-empty tag
|
||||
|
@ -61,7 +61,7 @@ class ModelResources {
|
|||
static absl::StatusOr<std::unique_ptr<ModelResources>> Create(
|
||||
const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver =
|
||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||
|
||||
// Takes the ownership of the provided ExternalFile proto and creates
|
||||
// ModelResources from the proto and an op resolver mediapipe packet. A
|
||||
|
|
|
@ -30,7 +30,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -124,7 +124,7 @@ void RunGraphWithGraphService(std::unique_ptr<ModelResources> model_resources,
|
|||
|
||||
} // namespace
|
||||
|
||||
class ModelResourcesCalculatorTest : public tflite_shims::testing::Test {};
|
||||
class ModelResourcesCalculatorTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
|
|
|
@ -38,9 +38,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
|
@ -116,7 +116,7 @@ void CheckModelResourcesPackets(const ModelResources* model_resources) {
|
|||
|
||||
} // namespace
|
||||
|
||||
class ModelResourcesTest : public tflite_shims::testing::Test {};
|
||||
class ModelResourcesTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ModelResourcesTest, CreateFromBinaryContent) {
|
||||
auto model_file = std::make_unique<proto::ExternalFile>();
|
||||
|
@ -211,7 +211,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsFromFile) {
|
|||
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
||||
tflite::MutableOpResolver resolver;
|
||||
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||
::tflite_shims::ops::builtin::Register_ADD());
|
||||
::tflite::ops::builtin::Register_ADD());
|
||||
resolver.AddCustom(kCustomOpName,
|
||||
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
||||
|
||||
|
@ -275,7 +275,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsPacket) {
|
|||
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
|
||||
tflite::MutableOpResolver resolver;
|
||||
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
|
||||
::tflite_shims::ops::builtin::Register_ADD());
|
||||
::tflite::ops::builtin::Register_ADD());
|
||||
resolver.AddCustom(kCustomOpName,
|
||||
::tflite::ops::custom::Register_MY_CUSTOM_OP());
|
||||
|
||||
|
|
|
@ -23,7 +23,11 @@ limitations under the License.
|
|||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/port/requires.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
@ -54,6 +58,8 @@ class TaskApiFactory {
|
|||
std::unique_ptr<tflite::OpResolver> resolver,
|
||||
PacketsCallback packets_callback = nullptr) {
|
||||
bool found_task_subgraph = false;
|
||||
// This for-loop ensures there's only one subgraph besides
|
||||
// FlowLimiterCalculator.
|
||||
for (const auto& node : graph_config.node()) {
|
||||
if (node.calculator() == "FlowLimiterCalculator") {
|
||||
continue;
|
||||
|
@ -64,13 +70,7 @@ class TaskApiFactory {
|
|||
"Task graph config should only contain one task subgraph node.",
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
} else {
|
||||
if (!node.options().HasExtension(Options::ext)) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat(node.calculator(),
|
||||
" is missing the required task options field."),
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node));
|
||||
found_task_subgraph = true;
|
||||
}
|
||||
}
|
||||
|
@ -80,6 +80,35 @@ class TaskApiFactory {
|
|||
std::move(packets_callback)));
|
||||
return std::make_unique<T>(std::move(runner));
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename Options>
|
||||
static absl::Status CheckHasValidOptions(
|
||||
const CalculatorGraphConfig::Node& node) {
|
||||
if constexpr (mediapipe::Requires<Options>(
|
||||
[](auto&& o) -> decltype(o.ext) {})) {
|
||||
if (node.options().HasExtension(Options::ext)) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} else {
|
||||
#ifndef MEDIAPIPE_PROTO_LITE
|
||||
for (const auto& option : node.node_options()) {
|
||||
if (absl::StrContains(option.type_url(),
|
||||
Options::descriptor()->full_name())) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
}
|
||||
#else // MEDIAPIPE_PROTO_LITE
|
||||
// Skip the check for proto lite, as Options::descriptor() is unavailable.
|
||||
return absl::OkStatus();
|
||||
#endif // MEDIAPIPE_PROTO_LITE
|
||||
}
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrCat(node.calculator(),
|
||||
" is missing the required task options field."),
|
||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace core
|
||||
|
|
|
@ -32,7 +32,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -112,7 +112,7 @@ CalculatorGraphConfig GetModelSidePacketsToStreamPacketsGraphConfig(
|
|||
|
||||
} // namespace
|
||||
|
||||
class TaskRunnerTest : public tflite_shims::testing::Test {};
|
||||
class TaskRunnerTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) {
|
||||
CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
|
172
mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
Normal file
172
mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
Normal file
|
@ -0,0 +1,172 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "config_fbs",
|
||||
srcs = ["config.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "config",
|
||||
srcs = [
|
||||
"config.fbs",
|
||||
],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "encoder_config",
|
||||
srcs = [
|
||||
"encoder_config.fbs",
|
||||
],
|
||||
includes = [":config_fbs"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "utils",
|
||||
hdrs = [
|
||||
"utils.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "double_array_trie",
|
||||
hdrs = [
|
||||
"double_array_trie.h",
|
||||
],
|
||||
deps = [
|
||||
":config",
|
||||
":utils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "double_array_trie_builder",
|
||||
srcs = [
|
||||
"double_array_trie_builder.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"double_array_trie_builder.h",
|
||||
],
|
||||
deps = ["@darts_clone"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "double_array_trie_test",
|
||||
srcs = [
|
||||
"double_array_trie_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie",
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":utils",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sentencepiece_constants",
|
||||
hdrs = ["sentencepiece_constants.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_converter",
|
||||
srcs = [
|
||||
"model_converter.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"model_converter.h",
|
||||
],
|
||||
deps = [
|
||||
":config",
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":sentencepiece_constants",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "optimized_encoder",
|
||||
srcs = [
|
||||
"optimized_encoder.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"optimized_encoder.h",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie",
|
||||
":encoder_config",
|
||||
":utils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sentencepiece_tokenizer_tflite",
|
||||
srcs = ["sentencepiece_tokenizer_tflite.cc"],
|
||||
hdrs = ["sentencepiece_tokenizer_tflite.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps =
|
||||
[
|
||||
":optimized_encoder",
|
||||
"@flatbuffers",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite:string_util",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
|
||||
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "optimized_encoder_test",
|
||||
srcs = [
|
||||
"optimized_encoder_test.cc",
|
||||
],
|
||||
data = [
|
||||
":testdata",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":model_converter",
|
||||
":optimized_encoder",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@org_tensorflow//tensorflow/core:lib",
|
||||
],
|
||||
)
|
25
mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
Normal file
25
mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
Normal file
|
@ -0,0 +1,25 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
namespace mediapipe.tflite_operations.sentencepiece;
|
||||
|
||||
table Trie {
|
||||
nodes: [uint32];
|
||||
}
|
||||
|
||||
|
||||
enum EncoderVersion: byte {
|
||||
SENTENCE_PIECE = 0,
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// A trie node specifies a node in the tree, either an intermediate node or
|
||||
// a leaf node.
|
||||
// A leaf node contains the id as an int of the string match. This id is encoded
|
||||
// in the lower 31 bits, thus the number of distinct ids is 2^31.
|
||||
// An intermediate node has an associated label and an offset to its children.
|
||||
// The label is encoded in the least significant byte and must match the input
|
||||
// character during matching.
|
||||
|
||||
// A memory mappable trie, compatible with Darts::DoubleArray.
|
||||
class DoubleArrayTrie {
|
||||
public:
|
||||
struct Match {
|
||||
Match() {}
|
||||
Match(int id, int match_length) : id(id), match_length(match_length) {}
|
||||
int id = -1;
|
||||
int match_length = -1;
|
||||
bool empty() const { return match_length == -1; }
|
||||
bool operator==(const Match& m) const {
|
||||
return m.id == id && m.match_length == match_length;
|
||||
}
|
||||
};
|
||||
|
||||
// nodes and nodes_length specify the array of the nodes of the trie.
|
||||
explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
|
||||
: nodes_(nodes) {}
|
||||
|
||||
// Finds matches that are prefixes of a string.
|
||||
template <typename callback>
|
||||
void IteratePrefixMatches(const utils::string_view& input,
|
||||
callback update_fn) const;
|
||||
|
||||
// Finds the longest prefix match of a string.
|
||||
Match LongestPrefixMatch(const utils::string_view& input) const {
|
||||
Match match;
|
||||
IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
|
||||
return match;
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns whether a node as a leaf as a child.
|
||||
bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
|
||||
|
||||
// Returns a value associated with a node. Available when a node is a leaf.
|
||||
int value(uint32_t i) const {
|
||||
return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
|
||||
}
|
||||
|
||||
// Returns a label associated with a node.
|
||||
// A leaf node will have the MSB set and thus return an invalid label.
|
||||
int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
|
||||
|
||||
// Returns offset to children.
|
||||
int32_t offset(uint32_t i) const {
|
||||
const uint32_t node = (*nodes_)[i];
|
||||
return (node >> 10) << ((node & 0x200) >> 6);
|
||||
}
|
||||
|
||||
const flatbuffers::Vector<uint32_t>* nodes_;
|
||||
};
|
||||
|
||||
template <typename callback>
|
||||
void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
|
||||
callback update_fn) const {
|
||||
if (nodes_->size() == 0) {
|
||||
return;
|
||||
}
|
||||
uint32_t pos = offset(0);
|
||||
for (int i = 0; i < input.length(); ++i) {
|
||||
pos ^= static_cast<unsigned char>(input.at(i));
|
||||
if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) {
|
||||
// No match, exit.
|
||||
return;
|
||||
}
|
||||
const bool node_has_leaf = has_leaf(pos);
|
||||
pos ^= offset(pos);
|
||||
if (pos < 0 || pos >= nodes_->size()) {
|
||||
// We can get here only if the structure is corrupted.
|
||||
return;
|
||||
}
|
||||
if (node_has_leaf) {
|
||||
update_fn(Match(value(pos), i + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
|
|
@ -0,0 +1,75 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "include/darts.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) {
|
||||
std::vector<int> ids;
|
||||
ids.reserve(data.size());
|
||||
for (int i = 0; i < data.size(); ++i) {
|
||||
ids.push_back(i);
|
||||
}
|
||||
return BuildTrie(data, ids);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
|
||||
const std::vector<int>& ids) {
|
||||
// We make strong assumptions about binary structure of trie.
|
||||
struct OneElement {
|
||||
OneElement(const std::string* key_, int index_)
|
||||
: key(key_), index(index_) {}
|
||||
const std::string* key;
|
||||
int index;
|
||||
bool operator<(const OneElement& el) const { return *key < *el.key; }
|
||||
};
|
||||
std::vector<OneElement> elements;
|
||||
elements.reserve(data.size());
|
||||
auto data_iterator = std::begin(data);
|
||||
auto ids_iterator = std::begin(ids);
|
||||
for (; data_iterator != std::end(data) && ids_iterator != std::end(ids);
|
||||
++data_iterator, ++ids_iterator) {
|
||||
elements.emplace_back(&(*data_iterator), *ids_iterator);
|
||||
}
|
||||
// Sort by keys.
|
||||
std::sort(elements.begin(), elements.end());
|
||||
|
||||
// Create vectors to build the trie.
|
||||
std::vector<const char*> strings;
|
||||
std::vector<int32_t> indexes;
|
||||
strings.reserve(data.size());
|
||||
indexes.reserve(data.size());
|
||||
for (const auto& el : elements) {
|
||||
strings.push_back(el.key->c_str());
|
||||
indexes.push_back(el.index);
|
||||
}
|
||||
auto trie = std::make_unique<Darts::DoubleArray>();
|
||||
trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr,
|
||||
&indexes[0]);
|
||||
// We make strong assumptions about internal Darts trie structure:
|
||||
// - it is a vector of 32 bit signed integers
|
||||
// - the "array" is the only one structure that contains all information about
|
||||
// the trie.
|
||||
const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array());
|
||||
return std::vector<uint32_t>(trie_data, trie_data + trie->size());
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,32 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
|
||||
const std::vector<int>& ids);
|
||||
|
||||
// A variant where ids are indexes in data.
|
||||
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
|
|
@ -0,0 +1,73 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
TEST(DoubleArrayTrieTest, Match) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"};
|
||||
const auto trie_vector = builder.CreateVector(BuildTrie(test_strings));
|
||||
TrieBuilder trie_builder(builder);
|
||||
trie_builder.add_nodes(trie_vector);
|
||||
const auto pieces = trie_builder.Finish();
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_pieces(pieces);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
DoubleArrayTrie dat(config->pieces()->nodes());
|
||||
EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")),
|
||||
DoubleArrayTrie::Match(2, 2));
|
||||
|
||||
std::vector<DoubleArrayTrie::Match> matches;
|
||||
dat.IteratePrefixMatches(
|
||||
utils::string_view("AAXL"),
|
||||
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
|
||||
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1),
|
||||
DoubleArrayTrie::Match(2, 2),
|
||||
DoubleArrayTrie::Match(1, 3)));
|
||||
}
|
||||
|
||||
TEST(DoubleArrayTrieTest, ComplexMatch) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s",
|
||||
"\xe2\x96\x81Hello"};
|
||||
const std::vector<int> test_ids = {0, 5, 10, 15};
|
||||
const auto trie_vector =
|
||||
builder.CreateVector(BuildTrie(test_strings, test_ids));
|
||||
TrieBuilder trie_builder(builder);
|
||||
trie_builder.add_nodes(trie_vector);
|
||||
const auto pieces = trie_builder.Finish();
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_pieces(pieces);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
DoubleArrayTrie dat(config->pieces()->nodes());
|
||||
|
||||
std::vector<DoubleArrayTrie::Match> matches;
|
||||
dat.IteratePrefixMatches(
|
||||
utils::string_view("\xe2\x96\x81Hello"),
|
||||
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
|
||||
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8)));
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
include "config.fbs";
|
||||
|
||||
namespace mediapipe.tflite_operations.sentencepiece;
|
||||
|
||||
table EncoderConfig {
|
||||
// Version of the encoder.
|
||||
version: EncoderVersion = SENTENCE_PIECE;
|
||||
start_code: int32 = 0;
|
||||
end_code: int32 = 0;
|
||||
|
||||
unknown_code: int32 = -1;
|
||||
// Weight of "unknown code" when encoding. "Penalty" because it usually has a
|
||||
// big negative weight,less than any other sentencepiece.
|
||||
unknown_penalty: float = 0;
|
||||
|
||||
// The offset for encoding, usually used when codes with low codes are reserved
|
||||
// for some special needs.
|
||||
encoding_offset: int32;
|
||||
|
||||
// String pieces for encoding.
|
||||
pieces: Trie;
|
||||
pieces_scores: [float];
|
||||
|
||||
// Normalization related parameters.
|
||||
remove_extra_whitespaces: bool;
|
||||
|
||||
// Add a whitespace prefix before encoding.
|
||||
add_dummy_prefix: bool;
|
||||
|
||||
// Escape whitespaces during encoding so the decoder can restore them exactly as
|
||||
// in the input.
|
||||
escape_whitespaces: bool;
|
||||
|
||||
// Normalization parameters.
|
||||
normalized_prefixes: Trie;
|
||||
normalized_replacements: [byte];
|
||||
}
|
||||
|
||||
root_type EncoderConfig;
|
|
@ -0,0 +1,131 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
|
||||
#include "src/sentencepiece_model.pb.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
|
||||
DecodePrecompiledCharsmap(
|
||||
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
|
||||
// This function "undoes" encoding done by
|
||||
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
|
||||
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
|
||||
const uint32_t trie_size =
|
||||
*reinterpret_cast<const uint32_t*>(precompiled_map);
|
||||
const uint32_t* trie_ptr =
|
||||
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
|
||||
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
|
||||
precompiled_map + sizeof(uint32_t) + trie_size);
|
||||
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
|
||||
sizeof(uint32_t) - trie_size;
|
||||
return std::make_tuple(
|
||||
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
|
||||
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||
const std::string& model_config_str, int encoding_offset) {
|
||||
::sentencepiece::ModelProto model_config;
|
||||
if (!model_config.ParseFromString(model_config_str)) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Invalid configuration, can't parse SentencePiece model config " +
|
||||
model_config.InitializationErrorString());
|
||||
}
|
||||
// Convert sentencepieces.
|
||||
std::vector<std::string> pieces;
|
||||
pieces.reserve(model_config.pieces_size());
|
||||
std::vector<float> scores;
|
||||
scores.reserve(model_config.pieces_size());
|
||||
std::vector<int> ids;
|
||||
ids.reserve(model_config.pieces_size());
|
||||
float min_score = 0.0;
|
||||
int index = 0;
|
||||
for (const auto& piece : model_config.pieces()) {
|
||||
switch (piece.type()) {
|
||||
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
|
||||
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
|
||||
pieces.push_back(piece.piece());
|
||||
ids.push_back(index);
|
||||
if (piece.score() < min_score) {
|
||||
min_score = piece.score();
|
||||
}
|
||||
break;
|
||||
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
|
||||
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
|
||||
// Ignore unknown and control codes.
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
|
||||
piece.piece());
|
||||
}
|
||||
scores.push_back(piece.score());
|
||||
++index;
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
|
||||
const auto pieces_score_vector = builder.CreateVector(scores);
|
||||
TrieBuilder pieces_trie_builder(builder);
|
||||
pieces_trie_builder.add_nodes(pieces_trie_vector);
|
||||
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
|
||||
|
||||
// Converting normalization.
|
||||
const auto normalization =
|
||||
DecodePrecompiledCharsmap(model_config.normalizer_spec());
|
||||
const auto normalization_trie = std::get<0>(normalization);
|
||||
const auto normalization_strings = std::get<1>(normalization);
|
||||
const auto normalization_trie_vector =
|
||||
builder.CreateVector(normalization_trie);
|
||||
TrieBuilder normalization_trie_builder(builder);
|
||||
normalization_trie_builder.add_nodes(normalization_trie_vector);
|
||||
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
|
||||
const auto normalization_strings_fbs =
|
||||
builder.CreateVector(normalization_strings);
|
||||
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
|
||||
ecb.add_start_code(model_config.trainer_spec().bos_id());
|
||||
ecb.add_end_code(model_config.trainer_spec().eos_id());
|
||||
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
|
||||
ecb.add_unknown_penalty(min_score - kUnkPenalty);
|
||||
ecb.add_encoding_offset(encoding_offset);
|
||||
ecb.add_pieces(pieces_trie_fbs);
|
||||
ecb.add_pieces_scores(pieces_score_vector);
|
||||
ecb.add_remove_extra_whitespaces(
|
||||
model_config.normalizer_spec().remove_extra_whitespaces());
|
||||
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
|
||||
ecb.add_escape_whitespaces(
|
||||
model_config.normalizer_spec().escape_whitespaces());
|
||||
ecb.add_normalized_prefixes(normalization_trie_fbs);
|
||||
ecb.add_normalized_replacements(normalization_strings_fbs);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||
builder.GetSize());
|
||||
}
|
||||
|
||||
std::string ConvertSentencepieceModel(const std::string& model_string) {
|
||||
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
|
||||
assert(result.status().ok());
|
||||
return result.value();
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,33 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// Converts Sentencepiece configuration to flatbuffer format.
|
||||
// encoding_offset is used by some encoders that combine different encodings.
|
||||
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||
const std::string& model_config_str, int encoding_offset = 0);
|
||||
std::string ConvertSentencepieceModel(const std::string& model_string);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
|
@ -0,0 +1,236 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
namespace {
|
||||
|
||||
const char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
template <typename processing_callback>
|
||||
std::tuple<std::string, std::vector<int>> process_string(
|
||||
const std::string& input, const std::vector<int>& offsets,
|
||||
const processing_callback& pc) {
|
||||
std::string result_string;
|
||||
result_string.reserve(input.size());
|
||||
std::vector<int> result_offsets;
|
||||
result_offsets.reserve(offsets.size());
|
||||
for (int i = 0, j = 0; i < input.size();) {
|
||||
auto result = pc(input.data() + i, input.size() - i);
|
||||
auto consumed = std::get<0>(result);
|
||||
auto new_string = std::get<1>(result);
|
||||
if (consumed == 0) {
|
||||
// Skip the current byte and move forward.
|
||||
result_string.push_back(input[i]);
|
||||
result_offsets.push_back(offsets[j]);
|
||||
i++;
|
||||
j++;
|
||||
continue;
|
||||
}
|
||||
result_string.append(new_string.data(), new_string.length());
|
||||
for (int i = 0; i < new_string.length(); ++i) {
|
||||
result_offsets.push_back(offsets[j]);
|
||||
}
|
||||
j += consumed;
|
||||
i += consumed;
|
||||
}
|
||||
return std::make_tuple(result_string, result_offsets);
|
||||
}
|
||||
|
||||
inline char is_whitespace(char c) {
|
||||
return c == ' ' || c == '\t' || c == '\r' || c == '\n';
|
||||
}
|
||||
|
||||
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
|
||||
int len) {
|
||||
if (len == 0 || !is_whitespace(*data)) {
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
int num_consumed = 1;
|
||||
for (; num_consumed < len && is_whitespace(data[num_consumed]);
|
||||
++num_consumed) {
|
||||
}
|
||||
return num_consumed > 1
|
||||
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
|
||||
: std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
|
||||
std::tuple<int, utils::string_view> find_replacement(
|
||||
const char* data, int len, const DoubleArrayTrie& dat,
|
||||
const flatbuffers::Vector<int8_t>& replacements) {
|
||||
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
|
||||
if (!max_match.empty()) {
|
||||
// Because flatbuffer byte is signed char which is not the same as char,
|
||||
// there is the reinterpret_cast here.
|
||||
const char* replaced_string_ptr =
|
||||
reinterpret_cast<const char*>(replacements.data() + max_match.id);
|
||||
return std::make_tuple(max_match.match_length,
|
||||
utils::string_view(replaced_string_ptr));
|
||||
}
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||
const std::string& in_string, const EncoderConfig& config) {
|
||||
std::vector<int> output_offsets;
|
||||
std::string result = in_string;
|
||||
output_offsets.reserve(in_string.length());
|
||||
for (int i = 0; i < in_string.length(); ++i) {
|
||||
output_offsets.push_back(i);
|
||||
}
|
||||
if (in_string.empty()) {
|
||||
return std::make_tuple(result, output_offsets);
|
||||
}
|
||||
if (config.add_dummy_prefix()) {
|
||||
result.insert(result.begin(), ' ');
|
||||
output_offsets.insert(output_offsets.begin(), 0);
|
||||
}
|
||||
// Greedely replace normalized_prefixes with normalized_replacements
|
||||
if (config.normalized_prefixes() != nullptr &&
|
||||
config.normalized_replacements() != nullptr) {
|
||||
const DoubleArrayTrie normalized_prefixes_matcher(
|
||||
config.normalized_prefixes()->nodes());
|
||||
const auto norm_replace = [&config, &normalized_prefixes_matcher](
|
||||
const char* data, int len) {
|
||||
return find_replacement(data, len, normalized_prefixes_matcher,
|
||||
*config.normalized_replacements());
|
||||
};
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, norm_replace);
|
||||
}
|
||||
if (config.remove_extra_whitespaces()) {
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, remove_extra_whitespaces);
|
||||
if (!result.empty() && is_whitespace(result.back())) {
|
||||
result.pop_back();
|
||||
output_offsets.pop_back();
|
||||
}
|
||||
}
|
||||
if (config.escape_whitespaces()) {
|
||||
const auto replace_whitespaces = [](const char* data, int len) {
|
||||
if (len > 0 && is_whitespace(*data)) {
|
||||
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
|
||||
}
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
};
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, replace_whitespaces);
|
||||
}
|
||||
|
||||
return std::make_tuple(result, output_offsets);
|
||||
}
|
||||
|
||||
EncoderResult EncodeNormalizedString(const std::string& str,
|
||||
const std::vector<int>& offsets,
|
||||
const EncoderConfig& config, bool add_bos,
|
||||
bool add_eos, bool reverse) {
|
||||
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
|
||||
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
|
||||
const int unknown_code = config.unknown_code();
|
||||
const float unknown_penalty = config.unknown_penalty();
|
||||
struct LatticeElement {
|
||||
float score = 0;
|
||||
int code = -1;
|
||||
int prev_position = -1;
|
||||
LatticeElement(float score_, int code_, int prev_position_)
|
||||
: score(score_), code(code_), prev_position(prev_position_) {}
|
||||
LatticeElement() {}
|
||||
};
|
||||
const int length = str.length();
|
||||
std::vector<LatticeElement> lattice(length + 1);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
if (i > 0 && lattice[i].prev_position < 0) {
|
||||
// This state is unreachable.
|
||||
continue;
|
||||
}
|
||||
if (unknown_code >= 0) {
|
||||
// Put unknown code.
|
||||
const float penalized_score = lattice[i].score + unknown_penalty;
|
||||
const int pos = i + 1;
|
||||
LatticeElement& current_element = lattice[pos];
|
||||
if (current_element.prev_position < 0 ||
|
||||
current_element.score < penalized_score) {
|
||||
current_element = LatticeElement(
|
||||
penalized_score, unknown_code,
|
||||
// If the current state is already reached by unknown code, merge
|
||||
// states.
|
||||
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
|
||||
}
|
||||
}
|
||||
auto lattice_update = [&lattice, i,
|
||||
piece_scores](const DoubleArrayTrie::Match& m) {
|
||||
LatticeElement& target_element = lattice[i + m.match_length];
|
||||
const float score = lattice[i].score + (*piece_scores)[m.id];
|
||||
if (target_element.prev_position < 0 || target_element.score < score) {
|
||||
target_element = LatticeElement(score, m.id, i);
|
||||
}
|
||||
};
|
||||
piece_matcher.IteratePrefixMatches(
|
||||
utils::string_view(str.data() + i, length - i), lattice_update);
|
||||
}
|
||||
|
||||
EncoderResult result;
|
||||
if (add_eos) {
|
||||
result.codes.push_back(config.end_code());
|
||||
result.offsets.push_back(length);
|
||||
}
|
||||
if (lattice[length].prev_position >= 0) {
|
||||
for (int pos = length; pos > 0;) {
|
||||
auto code = lattice[pos].code;
|
||||
if (code != config.unknown_code()) {
|
||||
code += config.encoding_offset();
|
||||
}
|
||||
result.codes.push_back(code);
|
||||
pos = lattice[pos].prev_position;
|
||||
result.offsets.push_back(offsets[pos]);
|
||||
}
|
||||
}
|
||||
if (add_bos) {
|
||||
result.codes.push_back(config.start_code());
|
||||
result.offsets.push_back(0);
|
||||
}
|
||||
if (!reverse) {
|
||||
std::reverse(result.codes.begin(), result.codes.end());
|
||||
std::reverse(result.offsets.begin(), result.offsets.end());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||
bool add_bos, bool add_eos, bool reverse) {
|
||||
// Get the config from the buffer.
|
||||
const EncoderConfig* config = GetEncoderConfig(config_buffer);
|
||||
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
|
||||
EncoderResult result;
|
||||
result.type = EncoderResultType::WRONG_CONFIG;
|
||||
return result;
|
||||
}
|
||||
std::string normalized_string;
|
||||
std::vector<int> offsets;
|
||||
std::tie(normalized_string, offsets) = NormalizeString(string, *config);
|
||||
return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
|
||||
add_eos, reverse);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,46 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||
|
||||
// Sentencepiece encoder optimized with memmapped model.
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
|
||||
|
||||
struct EncoderResult {
|
||||
EncoderResultType type = EncoderResultType::SUCCESS;
|
||||
std::vector<int> codes;
|
||||
std::vector<int> offsets;
|
||||
};
|
||||
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||
const std::string& in_string, const EncoderConfig& config);
|
||||
|
||||
// Encodes one string and returns ids and offsets. Takes the configuration as a
|
||||
// type-erased buffer.
|
||||
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||
bool add_bos, bool add_eos, bool reverse);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
|
@ -0,0 +1,171 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||
#include "src/sentencepiece.pb.h"
|
||||
#include "src/sentencepiece_processor.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
namespace internal {
|
||||
|
||||
tensorflow::Status TFReadFileToString(const std::string& filepath,
|
||||
std::string* data) {
|
||||
return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
|
||||
data);
|
||||
}
|
||||
|
||||
absl::Status StdReadFileToString(const std::string& filepath,
|
||||
std::string* data) {
|
||||
std::ifstream infile(filepath);
|
||||
if (!infile.is_open()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrFormat("Error when opening %s", filepath));
|
||||
}
|
||||
std::string contents((std::istreambuf_iterator<char>(infile)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
data->append(contents);
|
||||
infile.close();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
|
||||
static char kConfigFilePath[] =
|
||||
"/mediapipe/tasks/cc/text/custom_ops/"
|
||||
"sentencepiece/testdata/sentencepiece.model";
|
||||
|
||||
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_remove_extra_whitespaces(true);
|
||||
ecb.add_add_dummy_prefix(true);
|
||||
ecb.add_escape_whitespaces(true);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
{
|
||||
const auto result = NormalizeString("x y", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
|
||||
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
|
||||
}
|
||||
{
|
||||
const auto result = NormalizeString("\tx y\n", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
|
||||
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimizedEncoder, NormalizeStringReplacement) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
|
||||
const char norm_replacements[] = "A1\0A2\0A3\0A4";
|
||||
const auto trie_vector =
|
||||
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
|
||||
const auto norm_r = builder.CreateVector<int8_t>(
|
||||
reinterpret_cast<const signed char*>(norm_replacements),
|
||||
sizeof(norm_replacements));
|
||||
TrieBuilder trie_builder(builder);
|
||||
trie_builder.add_nodes(trie_vector);
|
||||
const auto norm_p = trie_builder.Finish();
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_remove_extra_whitespaces(false);
|
||||
ecb.add_normalized_prefixes(norm_p);
|
||||
ecb.add_normalized_replacements(norm_r);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
{
|
||||
const auto result = NormalizeString("ABAABAAABAAAA", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, "A1BA2BA3BA4");
|
||||
EXPECT_THAT(offsets,
|
||||
::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
|
||||
"X"};
|
||||
const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
|
||||
const auto trie_vector =
|
||||
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
|
||||
const auto norm_r = builder.CreateVector<int8_t>(
|
||||
reinterpret_cast<const signed char*>(norm_replacements),
|
||||
sizeof(norm_replacements));
|
||||
TrieBuilder trie_builder(builder);
|
||||
trie_builder.add_nodes(trie_vector);
|
||||
const auto norm_p = trie_builder.Finish();
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_remove_extra_whitespaces(true);
|
||||
ecb.add_normalized_prefixes(norm_p);
|
||||
ecb.add_normalized_replacements(norm_r);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
{
|
||||
const auto result = NormalizeString("XXABAABAAABAAAA", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, " A1BA2BA3BA4");
|
||||
EXPECT_THAT(offsets,
|
||||
::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimizedEncoder, ConfigConverter) {
|
||||
std::string config;
|
||||
auto status =
|
||||
internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
::sentencepiece::SentencePieceProcessor processor;
|
||||
ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
|
||||
const auto converted_model = ConvertSentencepieceModel(config);
|
||||
const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
|
||||
const auto encoded =
|
||||
EncodeString(test_string, converted_model.data(), false, false, false);
|
||||
ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
|
||||
|
||||
::sentencepiece::SentencePieceText reference_encoded;
|
||||
ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
|
||||
EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
|
||||
for (int i = 0; i < encoded.codes.size(); ++i) {
|
||||
EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
|
||||
EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// The constant is copied from
|
||||
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
|
||||
constexpr float kUnkPenalty = 10.0;
|
||||
|
||||
// These constants are copied from
|
||||
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
|
||||
//
|
||||
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
|
||||
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
|
||||
// since this character can be useful both for user and
|
||||
// developer. We can easily figure out that <unk> is emitted.
|
||||
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
|
@ -0,0 +1,129 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
|
||||
|
||||
#include "flatbuffers/flexbuffers.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace mediapipe::tflite_operations {
|
||||
namespace sentencepiece::tokenizer {
|
||||
namespace {
|
||||
|
||||
using ::tflite::SetTensorToDynamic;
|
||||
|
||||
constexpr int kSPModelIndex = 0;
|
||||
constexpr int kInputIndex = 1;
|
||||
constexpr int kAddBOSInput = 4;
|
||||
constexpr int kAddEOSInput = 5;
|
||||
constexpr int kReverseInput = 6;
|
||||
|
||||
constexpr int kOutputValuesInd = 0;
|
||||
constexpr int kOutputSplitsInd = 1;
|
||||
|
||||
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
|
||||
TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
|
||||
int index = 0;
|
||||
for (const int size : sizes) {
|
||||
array_size->data[index++] = size;
|
||||
}
|
||||
return array_size;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Initializes text encoder object from serialized parameters.
|
||||
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
|
||||
size_t /*length*/) {
|
||||
return nullptr;
|
||||
}
|
||||
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// TODO: Add checks for input and output tensors.
|
||||
TfLiteTensor& output_values =
|
||||
context->tensors[node->outputs->data[kOutputValuesInd]];
|
||||
SetTensorToDynamic(&output_values);
|
||||
|
||||
TfLiteTensor& output_splits =
|
||||
context->tensors[node->outputs->data[kOutputSplitsInd]];
|
||||
SetTensorToDynamic(&output_splits);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor& model_tensor =
|
||||
context->tensors[node->inputs->data[kSPModelIndex]];
|
||||
const auto model_buffer_data = model_tensor.data.data;
|
||||
const TfLiteTensor& input_text =
|
||||
context->tensors[node->inputs->data[kInputIndex]];
|
||||
|
||||
const TfLiteTensor add_bos_tensor =
|
||||
context->tensors[node->inputs->data[kAddBOSInput]];
|
||||
const bool add_bos = add_bos_tensor.data.b[0];
|
||||
const TfLiteTensor add_eos_tensor =
|
||||
context->tensors[node->inputs->data[kAddEOSInput]];
|
||||
const bool add_eos = add_eos_tensor.data.b[0];
|
||||
const TfLiteTensor reverse_tensor =
|
||||
context->tensors[node->inputs->data[kReverseInput]];
|
||||
const bool reverse = reverse_tensor.data.b[0];
|
||||
|
||||
std::vector<int32> encoded;
|
||||
std::vector<int32> splits;
|
||||
const int num_strings = tflite::GetStringCount(&input_text);
|
||||
for (int i = 0; i < num_strings; ++i) {
|
||||
const auto strref = tflite::GetString(&input_text, i);
|
||||
const auto res = EncodeString(std::string(strref.str, strref.len),
|
||||
model_buffer_data, add_bos, add_eos, reverse);
|
||||
TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
|
||||
"Sentencepiece conversion failed");
|
||||
std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
|
||||
splits.emplace_back(encoded.size());
|
||||
}
|
||||
|
||||
TfLiteTensor& output_values =
|
||||
context->tensors[node->outputs->data[kOutputValuesInd]];
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(
|
||||
context, &output_values,
|
||||
CreateSizeArray({static_cast<int>(encoded.size())})));
|
||||
int32_t* output_values_flat = output_values.data.i32;
|
||||
std::copy(encoded.begin(), encoded.end(), output_values_flat);
|
||||
TfLiteTensor& output_splits =
|
||||
context->tensors[node->outputs->data[kOutputSplitsInd]];
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(
|
||||
context, &output_splits,
|
||||
CreateSizeArray({static_cast<int>(splits.size() + 1)})));
|
||||
int32_t* output_splits_flat = output_splits.data.i32;
|
||||
*output_splits_flat = 0;
|
||||
std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace sentencepiece::tokenizer
|
||||
|
||||
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
|
||||
static TfLiteRegistration r = {
|
||||
sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
|
||||
sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations
|
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe::tflite_operations {
|
||||
|
||||
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
|
||||
|
||||
} // namespace mediapipe::tflite_operations
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
Binary file not shown.
60
mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
Normal file
60
mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
Normal file
|
@ -0,0 +1,60 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// AOSP and WASM doesn't support string_view,
|
||||
// we put here a minimal re-implementation.
|
||||
namespace utils {
|
||||
|
||||
class string_view {
|
||||
public:
|
||||
explicit string_view(const std::string& s)
|
||||
: str_(s.data()), len_(s.length()) {}
|
||||
string_view(const char* str, int len) : str_(str), len_(len) {}
|
||||
// A constructor from c string.
|
||||
explicit string_view(const char* s) : str_(s), len_(strlen(s)) {}
|
||||
|
||||
int length() const { return len_; }
|
||||
const char* data() const { return str_; }
|
||||
bool empty() const { return len_ == 0; }
|
||||
unsigned char at(int i) const { return str_[i]; }
|
||||
|
||||
private:
|
||||
const char* str_ = nullptr;
|
||||
const int len_ = 0;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
|
||||
os << std::string(sv.data(), sv.length());
|
||||
return os;
|
||||
}
|
||||
inline bool operator==(const string_view& view1, const string_view& view2) {
|
||||
if (view1.length() != view2.length()) {
|
||||
return false;
|
||||
}
|
||||
return memcmp(view1.data(), view2.data(), view1.length()) == 0;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
|
|
@ -32,7 +32,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::language_detector {
|
||||
namespace {
|
||||
|
@ -75,7 +75,7 @@ absl::Status MatchesLanguageDetectorResult(
|
|||
|
||||
} // namespace
|
||||
|
||||
class LanguageDetectorTest : public tflite_shims::testing::Test {};
|
||||
class LanguageDetectorTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
|
||||
auto options = std::make_unique<LanguageDetectorOptions>();
|
||||
|
|
|
@ -89,7 +89,7 @@ cc_test(
|
|||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_classifier {
|
||||
namespace {
|
||||
|
@ -87,7 +87,7 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual,
|
|||
|
||||
} // namespace
|
||||
|
||||
class TextClassifierTest : public tflite_shims::testing::Test {};
|
||||
class TextClassifierTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
|
|
|
@ -91,6 +91,6 @@ cc_test(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::text_embedder {
|
||||
namespace {
|
||||
|
@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
|
|||
// Embedding model with regex preprocessing.
|
||||
constexpr char kRegexOneEmbeddingModel[] =
|
||||
"regex_one_embedding_with_metadata.tflite";
|
||||
constexpr char kUniversalSentenceEncoderModel[] =
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||
|
||||
// Tolerance for embedding vector coordinate values.
|
||||
constexpr float kEpsilon = 1e-4;
|
||||
|
@ -49,7 +51,7 @@ using ::mediapipe::file::JoinPath;
|
|||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
class EmbedderTest : public tflite_shims::testing::Test {};
|
||||
class EmbedderTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(EmbedderTest, FailsWithMissingModel) {
|
||||
auto text_embedder =
|
||||
|
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
|||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result0,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
ASSERT_EQ(result0.embeddings.size(), 1);
|
||||
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
|
||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||
ASSERT_EQ(result1.embeddings.size(), 1);
|
||||
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
|
||||
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
|
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
|||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result0,
|
||||
text_embedder->Embed("When you go to this restaurant, they hold the "
|
||||
"pancake upside-down before they hand it "
|
||||
"to you. It's a great gimmick."));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result1,
|
||||
text_embedder->Embed(
|
||||
"Let's make a plan to steal the declaration of independence."));
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
||||
|
|
|
@ -81,6 +81,6 @@ cc_test(
|
|||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
"@org_tensorflow//tensorflow/lite:test_util",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe::tasks::text::utils {
|
||||
|
||||
|
@ -76,7 +76,7 @@ absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
|
|||
|
||||
} // namespace
|
||||
|
||||
class TextModelUtilsTest : public tflite_shims::testing::Test {};
|
||||
class TextModelUtilsTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto model_type,
|
||||
|
|
|
@ -29,7 +29,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -105,7 +105,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
|||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
class FaceBlendshapesTest : public tflite_shims::testing::Test {};
|
||||
class FaceBlendshapesTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(FaceBlendshapesTest, SmokeTest) {
|
||||
// Prepare graph inputs.
|
||||
|
|
|
@ -67,7 +67,7 @@ constexpr char kPortraitExpectedFaceLandmarksName[] =
|
|||
"portrait_expected_face_landmarks.pbtxt";
|
||||
constexpr char kPortraitExpectedBlendshapesName[] =
|
||||
"portrait_expected_blendshapes.pbtxt";
|
||||
constexpr char kPortaitExpectedFaceGeomertyName[] =
|
||||
constexpr char kPortraitExpectedFaceGeometryName[] =
|
||||
"portrait_expected_face_geometry.pbtxt";
|
||||
|
||||
constexpr float kLandmarksDiffMargin = 0.03;
|
||||
|
@ -100,7 +100,7 @@ struct FaceLandmarkerTestParams {
|
|||
|
||||
mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() {
|
||||
auto face_geometry = GetExpectedProto<face_geometry::proto::FaceGeometry>(
|
||||
kPortaitExpectedFaceGeomertyName);
|
||||
kPortraitExpectedFaceGeometryName);
|
||||
return face_geometry.pose_transform_matrix();
|
||||
}
|
||||
|
||||
|
|
|
@ -23,18 +23,12 @@ cc_library(
|
|||
srcs = ["face_stylizer_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:warp_affine_calculator",
|
||||
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:face_to_rect_calculator",
|
||||
"//mediapipe/calculators/util:from_image_calculator",
|
||||
"//mediapipe/calculators/util:inverse_matrix_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:to_image_calculator",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
|
@ -53,7 +47,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
|
||||
|
|
|
@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The input image can be of any size with format RGB or RGBA.
|
||||
// When no face is detected on the input image, the method returns a
|
||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||
// face. To ensure that the output image has reasonable quality, the stylized
|
||||
// output image size is the smaller of the model output size and the size of
|
||||
// the 'region_of_interest' specified in 'image_processing_options'.
|
||||
// face. The stylized output image size is the same as the model output size.
|
||||
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
|
@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// must be monotonically increasing.
|
||||
// When no face is detected on the input image, the method returns a
|
||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||
// face. To ensure that the output image has reasonable quality, the stylized
|
||||
// output image size is the smaller of the model output size and the size of
|
||||
// the 'region_of_interest' specified in 'image_processing_options'.
|
||||
// face. The stylized output image size is the same as the model output size.
|
||||
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
|
||||
mediapipe::Image image, int64_t timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
|
@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The "result_callback" provides:
|
||||
// - When no face is detected on the input image, the method returns a
|
||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||
// face. To ensure that the output image has reasonable quality, the
|
||||
// stylized output image size is the smaller of the model output size and
|
||||
// the size of the 'region_of_interest' specified in
|
||||
// 'image_processing_options'.
|
||||
// face. The stylized output image size is the same as the model output
|
||||
// size.
|
||||
// - The input timestamp in milliseconds.
|
||||
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions>
|
||||
|
|
|
@ -19,8 +19,7 @@ limitations under the License.
|
|||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
|
|||
image_in >> preprocessing.In(kImageTag);
|
||||
face_rect >> preprocessing.In(kNormRectTag);
|
||||
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
|
||||
auto transform_matrix = preprocessing.Out(kMatrixTag);
|
||||
|
||||
// Adds inference subgraph and connects its input stream to the output
|
||||
// tensors produced by the ImageToTensorCalculator.
|
||||
|
@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
|
|||
model_output_tensors >> tensors_to_image.In(kTensorsTag);
|
||||
auto tensor_image = tensors_to_image.Out(kImageTag);
|
||||
|
||||
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
|
||||
transform_matrix >> inverse_matrix.In(kMatrixTag);
|
||||
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
|
||||
auto& image_converter = graph.AddNode("ImageCloneCalculator");
|
||||
image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
|
||||
.set_output_on_gpu(false);
|
||||
tensor_image >> image_converter.In("");
|
||||
|
||||
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
|
||||
auto& warp_affine_options =
|
||||
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
|
||||
warp_affine_options.set_border_mode(
|
||||
WarpAffineCalculatorOptions::BORDER_ZERO);
|
||||
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
|
||||
tensor_image >> warp_affine.In(kImageTag);
|
||||
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
|
||||
image_size >> warp_affine.In(kOutputSizeTag);
|
||||
auto image_to_crop = warp_affine.Out(kImageTag);
|
||||
|
||||
// The following calculators are for cropping and resizing the output image
|
||||
// based on the roi and the model output size. As the WarpAffineCalculator
|
||||
// rotates the image based on the transform matrix, the rotation info in the
|
||||
// rect proto is stripped to prevent the ImageCroppingCalculator from
|
||||
// performing extra rotation.
|
||||
auto& strip_rotation =
|
||||
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
|
||||
face_rect >> strip_rotation.In(kNormRectTag);
|
||||
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
|
||||
auto& from_image = graph.AddNode("FromImageCalculator");
|
||||
image_to_crop >> from_image.In(kImageTag);
|
||||
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
|
||||
auto& image_cropping_opts =
|
||||
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
|
||||
image_cropping_opts.set_output_max_width(
|
||||
image_to_tensor_options.output_tensor_width());
|
||||
image_cropping_opts.set_output_max_height(
|
||||
image_to_tensor_options.output_tensor_height());
|
||||
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
|
||||
auto& to_image = graph.AddNode("ToImageCalculator");
|
||||
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
|
||||
// graph selects its cpu or gpu path based on the image preprocessing
|
||||
// backend.
|
||||
if (use_gpu) {
|
||||
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
|
||||
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
|
||||
} else {
|
||||
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
|
||||
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
|
||||
}
|
||||
|
||||
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
|
||||
return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
|
||||
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
|
||||
}
|
||||
};
|
||||
|
|
|
@ -43,7 +43,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -137,7 +137,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
|
|||
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
class HandLandmarkerTest : public tflite_shims::testing::Test {};
|
||||
class HandLandmarkerTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(HandLandmarkerTest, Succeeds) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
|
|
@ -41,7 +41,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
|
|
@ -59,7 +59,6 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::utils::AllowIf;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::mediapipe::tasks::vision::hand_landmarker::proto::
|
||||
HandLandmarksDetectorGraphOptions;
|
||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||
|
|
|
@ -146,7 +146,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner(
|
|||
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(),
|
||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||
}
|
||||
|
||||
// Helper function to create a Multi Hand Landmark TaskRunner.
|
||||
|
@ -188,7 +188,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner(
|
|||
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(),
|
||||
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
|
||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
|
||||
}
|
||||
|
||||
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {
|
||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -148,7 +148,7 @@ class MobileNetQuantizedOpResolverMissingOps
|
|||
const MobileNetQuantizedOpResolverMissingOps& r) = delete;
|
||||
};
|
||||
|
||||
class CreateTest : public tflite_shims::testing::Test {};
|
||||
class CreateTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
||||
auto options = std::make_unique<ImageClassifierOptions>();
|
||||
|
@ -265,7 +265,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
|
|||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||
}
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
class ImageModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -605,7 +605,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
|
|||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||
}
|
||||
|
||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
||||
class VideoModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -707,7 +707,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
|||
MP_ASSERT_OK(image_classifier->Close());
|
||||
}
|
||||
|
||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
||||
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
|
|
@ -30,9 +30,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -103,7 +103,7 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
|
|||
delete;
|
||||
};
|
||||
|
||||
class CreateTest : public tflite_shims::testing::Test {};
|
||||
class CreateTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
|
||||
auto options = std::make_unique<ImageEmbedderOptions>();
|
||||
|
@ -181,7 +181,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
|
|||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
|
||||
}
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
class ImageModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -410,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
|||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||
}
|
||||
|
||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
||||
class VideoModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -494,7 +494,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
|||
MP_ASSERT_OK(image_embedder->Close());
|
||||
}
|
||||
|
||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
||||
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
@ -180,7 +180,7 @@ class DeepLabOpResolver : public ::tflite::MutableOpResolver {
|
|||
DeepLabOpResolver(const DeepLabOpResolver& r) = delete;
|
||||
};
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||
|
||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||
public:
|
||||
|
@ -268,7 +268,7 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
|
|||
}
|
||||
}
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
class ImageModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -521,7 +521,7 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
|
|||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||
}
|
||||
|
||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
||||
class VideoModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -581,7 +581,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
|||
MP_ASSERT_OK(segmenter->Close());
|
||||
}
|
||||
|
||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
||||
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||
|
|
|
@ -64,7 +64,6 @@ using ::mediapipe::CalculatorGraphConfig;
|
|||
using ::mediapipe::Image;
|
||||
using ::mediapipe::NormalizedRect;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
|
||||
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
|
||||
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
|
||||
image_segmenter::proto::ImageSegmenterGraphOptions;
|
||||
|
||||
|
|
|
@ -39,9 +39,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
#include "testing/base/public/gmock.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -124,7 +124,7 @@ MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
|
|||
similarity_threshold;
|
||||
}
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||
|
||||
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
|
||||
public:
|
||||
|
@ -261,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||
info) { return info.param.test_name; });
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
class ImageModeTest : public tflite::testing::Test {};
|
||||
|
||||
// TODO: fix this unit test after image segmenter handled post
|
||||
// processing correctly with rotated image.
|
||||
|
|
|
@ -43,9 +43,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/test_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
|
@ -159,7 +159,7 @@ class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver {
|
|||
MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete;
|
||||
};
|
||||
|
||||
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
|
||||
class CreateFromOptionsTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
|
||||
auto options = std::make_unique<ObjectDetectorOptions>();
|
||||
|
@ -332,7 +332,7 @@ TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
|
|||
// TODO: Add NumThreadsTest back after having an
|
||||
// "acceleration configuration" field in the ObjectDetectorOptions.
|
||||
|
||||
class ImageModeTest : public tflite_shims::testing::Test {};
|
||||
class ImageModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||
|
@ -618,7 +618,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
|||
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||
}
|
||||
|
||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
||||
class VideoModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||
|
@ -673,7 +673,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
|||
MP_ASSERT_OK(object_detector->Close());
|
||||
}
|
||||
|
||||
class LiveStreamModeTest : public tflite_shims::testing::Test {};
|
||||
class LiveStreamModeTest : public tflite::testing::Test {};
|
||||
|
||||
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
|
||||
|
|
|
@ -97,8 +97,10 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||
"//mediapipe/util:graph_builder_utils",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
|
|||
// limit the number of frames in flight.
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
|
||||
bool enable_flow_limiting) {
|
||||
bool enable_flow_limiting, bool output_segmentation_masks) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
|
||||
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||
graph.Out(kSegmentationMaskTag);
|
||||
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
|
||||
graph.Out(kNormLandmarksTag);
|
||||
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
||||
|
@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
||||
graph.Out(kPoseAuxiliaryLandmarksTag);
|
||||
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
||||
if (output_segmentation_masks) {
|
||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||
graph.Out(kSegmentationMaskTag);
|
||||
}
|
||||
if (enable_flow_limiting) {
|
||||
return tasks::core::AddFlowLimiterCalculator(
|
||||
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
|
||||
|
@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
|
|||
PoseLandmarkerGraphOptionsProto>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM,
|
||||
options->output_segmentation_masks),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback))));
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
|
|||
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
|
||||
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
|
||||
Source<std::vector<Detection>> pose_detections;
|
||||
Source<std::vector<Image>> segmentation_masks;
|
||||
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
|
@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
// input_stream: "IMAGE:image_in"
|
||||
// input_stream: "NORM_RECT:norm_rect"
|
||||
// output_stream: "NORM_LANDMARKS:pose_landmarks"
|
||||
// output_stream: "LANDMARKS:world_landmarks"
|
||||
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
|
||||
// output_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
|
||||
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
|
||||
// output_stream: "POSE_RECTS:pose_rects"
|
||||
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
|
||||
|
@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
bool output_segmentation_masks =
|
||||
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
|
||||
if (sc->Options<PoseLandmarkerGraphOptions>()
|
||||
.base_options()
|
||||
.has_model_asset()) {
|
||||
|
@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||
.IsAvailable()));
|
||||
}
|
||||
ASSIGN_OR_RETURN(
|
||||
auto outs,
|
||||
BuildPoseLandmarkerGraph(
|
||||
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
ASSIGN_OR_RETURN(auto outs,
|
||||
BuildPoseLandmarkerGraph(
|
||||
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
|
||||
graph, output_segmentation_masks));
|
||||
outs.landmark_lists >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
|
||||
outs.world_landmark_lists >>
|
||||
|
@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
kAuxiliaryLandmarksTag)];
|
||||
outs.pose_rects_next_frame >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
|
||||
outs.segmentation_masks >>
|
||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||
outs.pose_detections >>
|
||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||
outs.image >> graph[Output<Image>(kImageTag)];
|
||||
if (outs.segmentation_masks) {
|
||||
*outs.segmentation_masks >>
|
||||
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
|
||||
}
|
||||
|
||||
// TODO remove when support is fixed.
|
||||
// As mediapipe GraphBuilder currently doesn't support configuring
|
||||
|
@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
// graph: the mediapipe graph instance to be updated.
|
||||
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
|
||||
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph,
|
||||
bool output_segmentation_masks) {
|
||||
const int max_num_poses =
|
||||
tasks_options.pose_detector_graph_options().num_poses();
|
||||
|
||||
|
@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
|
|||
auto pose_rects_for_next_frame =
|
||||
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
|
||||
.Cast<std::vector<NormalizedRect>>();
|
||||
auto segmentation_masks =
|
||||
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
||||
.Cast<std::vector<Image>>();
|
||||
std::optional<Source<std::vector<Image>>> segmentation_masks;
|
||||
if (output_segmentation_masks) {
|
||||
segmentation_masks =
|
||||
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
|
||||
.Cast<std::vector<Image>>();
|
||||
}
|
||||
|
||||
if (tasks_options.base_options().use_stream_mode()) {
|
||||
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user