Merge branch 'google:master' into face-stylizer-python-add-tests

This commit is contained in:
Kinar R 2023-04-21 06:04:33 +05:30 committed by GitHub
commit a5716c9225
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
222 changed files with 8366 additions and 1384 deletions

View File

@ -239,6 +239,16 @@ http_archive(
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
)
http_archive(
name = "darts_clone",
build_file = "@//third_party:darts_clone.BUILD",
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
urls = [
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
http_archive(
name = "org_tensorflow_text",
sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",

View File

@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
} else if (packet_options.has_string_value()) {
packet.Set<std::string>();
} else if (packet_options.has_uint64_value()) {
packet.Set<uint64>();
packet.Set<uint64_t>();
} else if (packet_options.has_classification_list_value()) {
packet.Set<ClassificationList>();
} else if (packet_options.has_landmark_list_value()) {
@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
} else if (packet_options.has_string_value()) {
packet.Set(MakePacket<std::string>(packet_options.string_value()));
} else if (packet_options.has_uint64_value()) {
packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
} else if (packet_options.has_classification_list_value()) {
packet.Set(MakePacket<ClassificationList>(
packet_options.classification_list_value()));

View File

@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
}
// Use this when ALLOW/DISALLOW input is provided as a side packet.
void RunTimeStep(int64 timestamp, bool stream_payload) {
void RunTimeStep(int64_t timestamp, bool stream_payload) {
runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
}
// Use this when ALLOW/DISALLOW input is provided as an input stream.
void RunTimeStep(int64 timestamp, const std::string& control_tag,
void RunTimeStep(int64_t timestamp, const std::string& control_tag,
bool control) {
runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(true).At(Timestamp(timestamp)));
@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
}
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
}
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
output_stream: "test_output"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
)");
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
)");
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
)");
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
)");
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
output_stream: "test_output"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "ALLOW", false);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "ALLOW", true);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "ALLOW", false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
output_stream: "test_output"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "DISALLOW", false);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "DISALLOW", true);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "DISALLOW", false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", false);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "ALLOW", true);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "ALLOW", true);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "ALLOW", false);
const std::vector<Packet>& output =
@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", true);
constexpr int64 kTimestampValue1 = 43;
constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "DISALLOW", false);
constexpr int64 kTimestampValue2 = 44;
constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "DISALLOW", false);
constexpr int64 kTimestampValue3 = 45;
constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "DISALLOW", true);
const std::vector<Packet>& output =
@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", false);
const std::vector<Packet>& output =
@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
output_stream: "STATE_CHANGE:state_changed"
)");
constexpr int64 kTimestampValue0 = 42;
constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true);
const std::vector<Packet>& output =

View File

@ -35,7 +35,7 @@ class MatrixToVectorCalculatorTest
void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; }
void AppendInput(const std::vector<float>& column_major_data,
int64 timestamp) {
int64_t timestamp) {
ASSERT_EQ(num_input_samples_ * num_input_channels_,
column_major_data.size());
Eigen::Map<const Matrix> data_map(&column_major_data[0],

View File

@ -51,9 +51,9 @@ class SimpleRunner : public CalculatorRunner {
virtual ~SimpleRunner() {}
void SetInput(const std::vector<int64>& timestamp_list) {
void SetInput(const std::vector<int64_t>& timestamp_list) {
MutableInputs()->Index(0).packets.clear();
for (const int64 ts : timestamp_list) {
for (const int64_t ts : timestamp_list) {
MutableInputs()->Index(0).packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts)))
.At(Timestamp(ts)));
@ -72,8 +72,8 @@ class SimpleRunner : public CalculatorRunner {
}
void CheckOutputTimestamps(
const std::vector<int64>& expected_frames,
const std::vector<int64>& expected_timestamps) const {
const std::vector<int64_t>& expected_frames,
const std::vector<int64_t>& expected_timestamps) const {
EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size());
EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size());
int count = 0;
@ -112,7 +112,7 @@ MATCHER_P2(PacketAtTimestamp, payload, timestamp,
*result_listener << "at incorrect timestamp = " << arg.Timestamp().Value();
return false;
}
int64 actual_payload = arg.template Get<int64>();
int64_t actual_payload = arg.template Get<int64_t>();
if (actual_payload != payload) {
*result_listener << "with incorrect payload = " << actual_payload;
return false;
@ -137,18 +137,18 @@ class ReproducibleJitterWithReflectionStrategyForTesting
//
// An EXPECT will fail if sequence is less than the number requested during
// processing.
static std::vector<uint64> random_sequence;
static std::vector<uint64_t> random_sequence;
protected:
virtual uint64 GetNextRandom(uint64 n) {
virtual uint64_t GetNextRandom(uint64_t n) {
EXPECT_LT(sequence_index_, random_sequence.size());
return random_sequence[sequence_index_++] % n;
}
private:
int32 sequence_index_ = 0;
int32_t sequence_index_ = 0;
};
std::vector<uint64>
std::vector<uint64_t>
ReproducibleJitterWithReflectionStrategyForTesting::random_sequence;
// PacketResamplerCalculator child class which injects a specified stream
@ -469,7 +469,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) {
}
)pb"));
for (const int64 ts : {0, 5000, 10010, 15001, 19990}) {
for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) {
runner.MutableInputs()->Tag(kDataTag).packets.push_back(
Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts)));
}

View File

@ -43,8 +43,8 @@ constexpr char kDisallowTag[] = "DISALLOW";
// Returns the timestamp values for a vector of Packets.
// TODO: put this kind of test util in a common place.
std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64> result;
std::vector<int64_t> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64_t> result;
for (const Packet& packet : packets) {
result.push_back(packet.Timestamp().Value());
}
@ -371,7 +371,7 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
for (int main_ts = 0; main_ts < 50; ++main_ts) {
send_packet("in", main_ts);
MP_EXPECT_OK(graph_.WaitUntilIdle());
std::vector<int64> ts_values = TimestampValues(outputs);
std::vector<int64_t> ts_values = TimestampValues(outputs);
EXPECT_EQ(ts_values.size(), main_ts + 1);
for (int j = 0; j < main_ts + 1; ++j) {
EXPECT_EQ(ts_values[j], j);

View File

@ -121,7 +121,7 @@ absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
if (cc->Outputs().HasTag(kTagAtTimestamp)) {
RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
<< "For AT_TIMESTAMP tag, 2 input side packets are required.";
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>();
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>();
} else {
RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
<< "Same number of input side packets and output streams is required.";
@ -178,8 +178,8 @@ absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
.AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
}
} else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
int64 timestamp =
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>();
int64_t timestamp =
cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64_t>();
for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
cc->Outputs()
.Get(output_tag_, i)

View File

@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
REGISTER_CALCULATOR(StringToUintCalculator);
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
REGISTER_CALCULATOR(StringToInt32Calculator);
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
REGISTER_CALCULATOR(StringToUint32Calculator);
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
REGISTER_CALCULATOR(StringToInt64Calculator);
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
REGISTER_CALCULATOR(StringToUint64Calculator);
} // namespace mediapipe

View File

@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
frame_ptr->Height(), frame_ptr->WidthStep(),
const_cast<uint8_t*>(frame_ptr->PixelData()),
[](uint8* data){});
[](uint8_t* data){});
ASSIGN_OR_RETURN(auto result,
runner->Run(image_frame, matrix, size, border_mode));
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));

View File

@ -401,8 +401,8 @@ cc_library_with_tflite(
hdrs = ["inference_calculator.h"],
tflite_deps = [
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
deps = [
":inference_calculator_cc_proto",
@ -506,7 +506,7 @@ cc_library_with_tflite(
name = "tflite_delegate_ptr",
hdrs = ["tflite_delegate_ptr.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
],
)
@ -517,8 +517,8 @@ cc_library_with_tflite(
tflite_deps = [
":tflite_delegate_ptr",
"//mediapipe/util/tflite:tflite_model_loader",
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
],
deps = [
":inference_runner",
@ -546,8 +546,8 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:c_api_types",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
] + select({
"//conditions:default": [],

View File

@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
}
return PacketAdopting<tflite::OpResolver>(
std::make_unique<tflite_shims::ops::builtin::
BuiltinOpResolverWithoutDefaultDelegates>());
std::make_unique<
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
}
} // namespace api2

View File

@ -26,7 +26,7 @@
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
namespace api2 {
@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf {
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
// migration.
static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>::
Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
"OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};

View File

@ -24,7 +24,7 @@
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h"
#include "mediapipe/calculators/tensor/inference_runner.h"
#include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/interpreter.h"
#if defined(MEDIAPIPE_ANDROID)
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#endif // ANDROID

View File

@ -22,9 +22,9 @@
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/core/shims/c/c_api_types.h"
#include "tensorflow/lite/core/shims/cc/interpreter.h"
#include "tensorflow/lite/core/shims/cc/interpreter_builder.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
@ -33,8 +33,8 @@ namespace mediapipe {
namespace {
using Interpreter = ::tflite_shims::Interpreter;
using InterpreterBuilder = ::tflite_shims::InterpreterBuilder;
using Interpreter = ::tflite::Interpreter;
using InterpreterBuilder = ::tflite::InterpreterBuilder;
template <typename T>
void CopyTensorBufferToInterpreter(const Tensor& input_tensor,

View File

@ -23,8 +23,8 @@
#include "mediapipe/calculators/tensor/tflite_delegate_ptr.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/c/c_api_types.h"
namespace mediapipe {

View File

@ -18,7 +18,7 @@
#include <functional>
#include <memory>
#include "tensorflow/lite/core/shims/c/c_api_types.h"
#include "tensorflow/lite/c/c_api_types.h"
namespace mediapipe {

View File

@ -61,12 +61,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
// overload GPU/TPU/...
class SimpleSemaphore {
public:
explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {}
explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {}
SimpleSemaphore(const SimpleSemaphore&) = delete;
SimpleSemaphore(SimpleSemaphore&&) = delete;
// Acquires the semaphore by certain amount.
void Acquire(uint32 amount) {
void Acquire(uint32_t amount) {
mutex_.Lock();
while (count_ < amount) {
cond_.Wait(&mutex_);
@ -76,7 +76,7 @@ class SimpleSemaphore {
}
// Releases the semaphore by certain amount.
void Release(uint32 amount) {
void Release(uint32_t amount) {
mutex_.Lock();
count_ += amount;
cond_.SignalAll();
@ -84,7 +84,7 @@ class SimpleSemaphore {
}
private:
uint32 count_;
uint32_t count_;
absl::Mutex mutex_;
absl::CondVar cond_;
};
@ -488,7 +488,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
// necessary.
absl::Status OutputBatch(CalculatorContext* cc,
std::unique_ptr<InferenceState> inference_state) {
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow());
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
@ -544,7 +544,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
get_session_run_throttle(options_.max_concurrent_session_runs());
session_run_throttle->Acquire(1);
}
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow());
tf::Status tf_status;
{
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
@ -562,7 +562,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
// informative error message.
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow());
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
->IncrementBy(run_end_time - run_start_time);
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
@ -611,7 +611,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
}
// Get end time and report.
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow());
cc->GetCounter(kTotalUsecsCounterSuffix)
->IncrementBy(end_time - start_time);
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
@ -650,7 +650,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
// The static singleton semaphore to throttle concurrent session runs.
static SimpleSemaphore* get_session_run_throttle(
int32 max_concurrent_session_runs) {
int32_t max_concurrent_session_runs) {
static SimpleSemaphore* session_run_throttle =
new SimpleSemaphore(max_concurrent_session_runs);
return session_run_throttle;

View File

@ -197,15 +197,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
// timestamp and the associated feature. This information is used in process
// to output batches of packets in order.
timestamps_.clear();
int64 last_timestamp_seen = Timestamp::PreStream().Value();
int64_t last_timestamp_seen = Timestamp::PreStream().Value();
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
if (absl::StrContains(map_kv.first, "/timestamp")) {
LOG(INFO) << "Found feature timestamps: " << map_kv.first
<< " with size: " << map_kv.second.feature_size();
int64 recent_timestamp = Timestamp::PreStream().Value();
int64_t recent_timestamp = Timestamp::PreStream().Value();
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
int64 next_timestamp =
int64_t next_timestamp =
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
RET_CHECK_GT(next_timestamp, recent_timestamp)
<< "Timestamps must be sequential. If you're seeing this message "
@ -361,8 +361,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
// any particular call to Process(). At the every end, we output the
// poststream packets. If we only have poststream packets,
// last_timestamp_key_ will be empty.
int64 start_timestamp = 0;
int64 end_timestamp = 0;
int64_t start_timestamp = 0;
int64_t end_timestamp = 0;
if (last_timestamp_key_.empty() || process_poststream_) {
process_poststream_ = true;
start_timestamp = Timestamp::PostStream().Value();
@ -481,14 +481,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
// Store a map from the keys for each stream to the timestamps for each
// key. This allows us to identify which packets to output for each stream
// for timestamps within a given time window.
std::map<std::string, std::vector<int64>> timestamps_;
std::map<std::string, std::vector<int64_t>> timestamps_;
// Store the stream with the latest timestamp in the SequenceExample.
std::string last_timestamp_key_;
// Store the index of the current timestamp. Will be less than
// timestamps_[last_timestamp_key_].size().
int current_timestamp_index_;
// Store the very first timestamp, so we output everything on the first frame.
int64 first_timestamp_seen_;
int64_t first_timestamp_seen_;
// List of keypoint names.
std::vector<std::string> keypoint_names_;
// Default keypoint location when missing.

View File

@ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test {
}
}
const int64 time = 1234;
const int64_t time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(time)));
@ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) {
// 2^i can be represented exactly in floating point numbers if 'i' is small.
input->at(i) = static_cast<float>(1 << i);
}
const int64 time = 1234;
const int64_t time = 1234;
runner_->MutableInputs()->Index(0).packets.push_back(
Adopt(input.release()).At(Timestamp(time)));

View File

@ -28,11 +28,8 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
using mediapipe::Adopt;
using mediapipe::CalculatorBase;
using mediapipe::ImageFrame;
using mediapipe::PacketTypeSet;
using mediapipe::autoflip::Border;
constexpr char kDetectedBorders[] = "DETECTED_BORDERS";
constexpr int kMinBorderDistance = 5;

View File

@ -28,16 +28,12 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
using mediapipe::Adopt;
using mediapipe::CalculatorGraphConfig;
using mediapipe::CalculatorRunner;
using mediapipe::ImageFormat;
using mediapipe::ImageFrame;
using mediapipe::Packet;
using mediapipe::PacketTypeSet;
using mediapipe::ParseTextProtoOrDie;
using mediapipe::Timestamp;
using mediapipe::autoflip::Border;
namespace mediapipe {
namespace autoflip {

View File

@ -31,14 +31,11 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
using mediapipe::Adopt;
using mediapipe::CalculatorGraphConfig;
using mediapipe::CalculatorRunner;
using mediapipe::ImageFormat;
using mediapipe::ImageFrame;
using mediapipe::PacketTypeSet;
using mediapipe::ParseTextProtoOrDie;
using mediapipe::Timestamp;
namespace mediapipe {
namespace autoflip {

View File

@ -28,8 +28,6 @@
using mediapipe::Packet;
using mediapipe::PacketTypeSet;
using mediapipe::autoflip::DetectionSet;
using mediapipe::autoflip::SalientRegion;
using mediapipe::autoflip::SignalType;
constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY";
constexpr char kSignalInputsTag[] = "SIGNAL";

View File

@ -19,8 +19,6 @@ namespace mediapipe {
namespace api2 {
namespace test {
using testing::ElementsAre;
// Returns the packet values for a vector of Packets.
template <typename T>
std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) {

View File

@ -310,7 +310,7 @@ class Scheduler {
absl::Mutex state_mutex_;
// Current state of the scheduler.
std::atomic<State> state_ = ATOMIC_VAR_INIT(STATE_NOT_STARTED);
std::atomic<State> state_ = STATE_NOT_STARTED;
// True if all graph input streams are closed.
bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false;

View File

@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
// Record the most recent first kept timestamp on any stream.
for (const auto& stream : input_stream_managers_) {
int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
? target_queue_size_
: trigger_queue_size_ - 1;
int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
? target_queue_size_
: trigger_queue_size_ - 1;
if (stream->QueueSize() > queue_size) {
kept_timestamp_ = std::max(
kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
}
private:
int32 trigger_queue_size_;
int32 target_queue_size_;
int32_t trigger_queue_size_;
int32_t target_queue_size_;
bool fixed_min_size_;
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and
// the corresponding call to FillInputSet has not yet completed.

View File

@ -30,15 +30,15 @@ namespace mediapipe {
namespace {
const int64 kMaxPacketId = 100;
const int64 kSlowCalculatorRate = 10;
const int64_t kMaxPacketId = 100;
const int64_t kSlowCalculatorRate = 10;
// Rate limiter for TestSlowCalculator.
ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit);
int64 g_source_counter ABSL_GUARDED_BY(g_source_mutex);
int64_t g_source_counter ABSL_GUARDED_BY(g_source_mutex);
// Rate limiter for TestSourceCalculator.
int64 g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
int64_t g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
// Flag that indicates that the source is done.
bool g_source_done ABSL_GUARDED_BY(g_source_mutex);
@ -47,7 +47,7 @@ class TestSourceCalculator : public CalculatorBase {
public:
TestSourceCalculator() : current_packet_id_(0) {}
static absl::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int64>();
cc->Outputs().Index(0).Set<int64_t>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
@ -62,7 +62,7 @@ class TestSourceCalculator : public CalculatorBase {
g_source_done = true;
return tool::StatusStop();
}
cc->Outputs().Index(0).Add(new int64(0), Timestamp(current_packet_id_));
cc->Outputs().Index(0).Add(new int64_t(0), Timestamp(current_packet_id_));
++current_packet_id_;
{
absl::MutexLock lock(&g_source_mutex);
@ -78,7 +78,7 @@ class TestSourceCalculator : public CalculatorBase {
return g_source_counter <= kSlowCalculatorRate * g_slow_counter ||
g_source_counter <= 1;
}
int64 current_packet_id_;
int64_t current_packet_id_;
};
REGISTER_CALCULATOR(TestSourceCalculator);
@ -87,8 +87,8 @@ class TestSlowCalculator : public CalculatorBase {
public:
TestSlowCalculator() = default;
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int64>();
cc->Outputs().Index(0).Set<int64>();
cc->Inputs().Index(0).Set<int64_t>();
cc->Outputs().Index(0).Set<int64_t>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
@ -97,7 +97,7 @@ class TestSlowCalculator : public CalculatorBase {
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
cc->Outputs().Index(0).Add(new int64(0),
cc->Outputs().Index(0).Add(new int64_t(0),
cc->Inputs().Index(0).Value().Timestamp());
{
absl::MutexLock lock(&g_source_mutex);
@ -118,8 +118,9 @@ class TestSlowCalculator : public CalculatorBase {
REGISTER_CALCULATOR(TestSlowCalculator);
// Return the values of the timestamps of a vector of Packets.
static std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
std::vector<int64> result;
static std::vector<int64_t> TimestampValues(
const std::vector<Packet>& packets) {
std::vector<int64_t> result;
for (const Packet& p : packets) {
result.push_back(p.Timestamp().Value());
}
@ -174,7 +175,7 @@ TEST_P(FixedSizeInputStreamHandlerTest, DropsPackets) {
// consumed. In this way, the TestSlowCalculator consumes and outputs only
// every tenth packet.
EXPECT_EQ(output_packets.size(), 11);
std::vector<int64> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
std::vector<int64_t> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
EXPECT_THAT(TimestampValues(output_packets),
testing::ContainerEq(expected_ts));
}
@ -344,18 +345,18 @@ TEST_P(FixedSizeInputStreamHandlerTest, LateArrivalDrop) {
if (GetParam()) {
EXPECT_THAT(TimestampValues(output_packets[0]),
testing::ContainerEq(std::vector<int64>{1, 2, 3, 4, 5, 6}));
testing::ContainerEq(std::vector<int64_t>{1, 2, 3, 4, 5, 6}));
EXPECT_THAT(TimestampValues(output_packets[1]),
testing::ContainerEq(std::vector<int64>{3, 4, 5, 6, 7}));
testing::ContainerEq(std::vector<int64_t>{3, 4, 5, 6, 7}));
EXPECT_THAT(TimestampValues(output_packets[2]),
testing::ContainerEq(std::vector<int64>{4, 5, 6, 7}));
testing::ContainerEq(std::vector<int64_t>{4, 5, 6, 7}));
} else {
EXPECT_THAT(TimestampValues(output_packets[0]),
testing::ContainerEq(std::vector<int64>{5, 6}));
testing::ContainerEq(std::vector<int64_t>{5, 6}));
EXPECT_THAT(TimestampValues(output_packets[1]),
testing::ContainerEq(std::vector<int64>{5, 6, 7}));
testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
EXPECT_THAT(TimestampValues(output_packets[2]),
testing::ContainerEq(std::vector<int64>{5, 6, 7}));
testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
}
}

View File

@ -27,10 +27,6 @@ namespace options_field_util {
using ::mediapipe::proto_ns::internal::WireFormatLite;
using FieldType = WireFormatLite::FieldType;
using ::mediapipe::proto_ns::io::ArrayInputStream;
using ::mediapipe::proto_ns::io::CodedInputStream;
using ::mediapipe::proto_ns::io::CodedOutputStream;
using ::mediapipe::proto_ns::io::StringOutputStream;
// Utility functions for OptionsFieldUtil.
namespace {

View File

@ -454,8 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// Number of glFinish calls completed on the GL thread.
// Changes should be guarded by mutex_. However, we use simple atomic
// loads for efficiency on the fast path.
std::atomic<int64_t> gl_finish_count_ = ATOMIC_VAR_INIT(0);
std::atomic<int64_t> gl_finish_count_target_ = ATOMIC_VAR_INIT(0);
std::atomic<int64_t> gl_finish_count_ = 0;
std::atomic<int64_t> gl_finish_count_target_ = 0;
GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr;

View File

@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
// TODO: Investigate this option in more detail, esp. on Safari.
attrs.preserveDrawingBuffer = 0;
// Since the Emscripten canvas target finding function is visible from here,
// we hijack findCanvasEventTarget directly for enforcing old Module.canvas
// behavior if the user desires, falling back to the new DOM element CSS
// selector behavior next if that is specified, and finally just allowing the
// lookup to proceed on a null target.
// TODO: Ensure this works with all options (in particular,
// multithreading options, like the special-case combination of USE_PTHREADS
// and OFFSCREEN_FRAMEBUFFER)
// clang-format off
EM_ASM(
let init_once = true;
if (init_once) {
const cachedFindCanvasEventTarget = findCanvasEventTarget;
if (typeof cachedFindCanvasEventTarget !== 'function') {
if (typeof console !== 'undefined') {
console.error('Expected Emscripten global function '
+ '"findCanvasEventTarget" not found. WebGL context creation '
+ 'may fail.');
}
return;
}
findCanvasEventTarget = function(target) {
if (target == 0) {
if (Module && Module.canvas) {
return Module.canvas;
} else if (Module && Module.canvasCssSelector) {
return cachedFindCanvasEventTarget(Module.canvasCssSelector);
}
if (typeof console !== 'undefined') {
console.warn('Module properties canvas and canvasCssSelector not ' +
'found during WebGL context creation.');
}
}
// We still go through with the find attempt, although for most use
// cases it will not succeed, just in case the user does want to fall-
// back.
return cachedFindCanvasEventTarget(target);
}; // NOLINT: Necessary semicolon.
init_once = false;
}
);
// clang-format on
// Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
// looks for our #canvas target in Module.canvas, where we expect it to be.
// -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
// event target behavior, but it was never supposed to be tapping into our
// canvas anyways. See b/278155946 for more background.
EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
emscripten_webgl_create_context(nullptr, &attrs);
emscripten_webgl_create_context("#canvas", &attrs);
// Check for failure
if (context_handle <= 0) {

View File

@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
int actual_ws = image_frame.WidthStep();
int alignment = 0;
std::unique_ptr<ImageFrame> temp;
const uint8* data = image_frame.PixelData();
const uint8_t* data = image_frame.PixelData();
// Let's see if the pixel data is tightly aligned to one of the alignments
// supported by OpenGL, preferring 4 if possible since it's the default.

View File

@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
GpuBufferFormat format) {
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
auto y_data = std::make_unique<uint8[]>(y_stride * height);
auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
switch (fourcc) {
case libyuv::FOURCC_NV12:
case libyuv::FOURCC_NV21: {
@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
int uv_width = 2 * std::ceil(0.5f * width);
int uv_height = std::ceil(0.5f * height);
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
yuv_image_ = std::make_shared<YUVImage>(
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
nullptr, 0, width, height);
@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
int uv_width = std::ceil(0.5f * width);
int uv_height = std::ceil(0.5f * height);
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
yuv_image_ = std::make_shared<YUVImage>(
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
std::move(v_data), uv_stride, width, height);

View File

@ -16,6 +16,7 @@ import csv
import filecmp
import os
import tempfile
import unittest
from unittest import mock as unittest_mock
import tensorflow as tf
@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils
@unittest.skip('b/275624089')
class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"],
tags = ["requires-net:external"],
deps = [
":dataset",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
":object_detector_import",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.model_maker.python.vision import object_detector
from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder(
self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir
)
# Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self):
hparams = hyperparameters.HParams(
hparams = object_detector.HParams(
epochs=1,
batch_size=2,
learning_rate=0.9,
shuffle=False,
export_dir=self.create_tempdir(),
)
options = object_detector_options.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
options = object_detector.ObjectDetectorOptions(
supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
)
# Test `create``
model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams(
qat_hparams = object_detector.QATHParams(
learning_rate=0.9,
batch_size=2,
epochs=1,

View File

@ -298,6 +298,7 @@ cc_library(
":tensors_to_objects_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/memory",

View File

@ -91,8 +91,8 @@ absl::Status FrameAnnotationToTimedBoxListCalculator::Process(
TimedBoxProto* added_box = output_objects->add_box();
ComputeBoundingRect(key_points, added_box);
added_box->set_id(annotation.object_id());
const int64 time_msec =
static_cast<int64>(std::round(frame_annotation.timestamp() / 1000));
const int64_t time_msec =
static_cast<int64_t>(std::round(frame_annotation.timestamp() / 1000));
added_box->set_time_msec(time_msec);
}

View File

@ -24,8 +24,8 @@ namespace mediapipe {
void FrameAnnotationTracker::AddDetectionResult(
const FrameAnnotation& frame_annotation) {
const int64 time_us =
static_cast<int64>(std::round(frame_annotation.timestamp()));
const int64_t time_us =
static_cast<int64_t>(std::round(frame_annotation.timestamp()));
for (const auto& object_annotation : frame_annotation.annotations()) {
detected_objects_[time_us + object_annotation.object_id()] =
object_annotation;
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
absl::flat_hash_set<int>* cancel_object_ids) {
CHECK(cancel_object_ids != nullptr);
FrameAnnotation frame_annotation;
std::vector<int64> keys_to_be_deleted;
std::vector<int64_t> keys_to_be_deleted;
for (const auto& detected_obj : detected_objects_) {
const int object_id = detected_obj.second.object_id();
if (cancel_object_ids->contains(object_id)) {

View File

@ -76,7 +76,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase {
// In a single MediaPipe session, the IDs are unique.
// Also assign timestamp for the FrameAnnotation to be the input packet
// timestamp.
void AssignObjectIdAndTimestamp(int64 timestamp_us,
void AssignObjectIdAndTimestamp(int64_t timestamp_us,
FrameAnnotation* annotation);
int num_classes_ = 0;
@ -207,7 +207,7 @@ void TfLiteTensorsToObjectsCalculator::Project3DTo2D(
}
void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp(
int64 timestamp_us, FrameAnnotation* annotation) {
int64_t timestamp_us, FrameAnnotation* annotation) {
for (auto& ann : *annotation->mutable_annotations()) {
ann.set_object_id(GetNextObjectId());
}

View File

@ -37,7 +37,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -157,7 +157,7 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) {
}
}
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
class CreateFromOptionsTest : public tflite::testing::Test {};
TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) {
auto options = std::make_unique<AudioClassifierOptions>();
@ -270,7 +270,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) {
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
class ClassifyTest : public tflite_shims::testing::Test {};
class ClassifyTest : public tflite::testing::Test {};
TEST_F(ClassifyTest, Succeeds) {
auto audio_buffer = GetAudioData(k16kTestWavFilename);
@ -467,7 +467,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) {
}
}
class ClassifyAsyncTest : public tflite_shims::testing::Test {};
class ClassifyAsyncTest : public tflite::testing::Test {};
TEST_F(ClassifyAsyncTest, Succeeds) {
constexpr int kSampleRateHz = 48000;

View File

@ -36,7 +36,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -66,7 +66,7 @@ Matrix GetAudioData(absl::string_view filename) {
return matrix_mapping.matrix();
}
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
class CreateFromOptionsTest : public tflite::testing::Test {};
TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
auto audio_embedder =
@ -124,7 +124,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) {
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
class EmbedTest : public tflite_shims::testing::Test {};
class EmbedTest : public tflite::testing::Test {};
TEST_F(EmbedTest, SucceedsWithSilentAudio) {
auto options = std::make_unique<AudioEmbedderOptions>();
@ -187,7 +187,7 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) {
MP_EXPECT_OK(audio_embedder->Close());
}
class EmbedAsyncTest : public tflite_shims::testing::Test {
class EmbedAsyncTest : public tflite::testing::Test {
protected:
void RunAudioEmbedderInStreamMode(std::string audio_file_name,
int sample_rate_hz,

View File

@ -47,7 +47,7 @@ cc_test_with_tflite(
data = ["//mediapipe/tasks/testdata/audio:test_models"],
tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
deps = [
":audio_tensor_specs",

View File

@ -34,7 +34,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -52,7 +52,7 @@ constexpr char kModelWithMetadata[] =
"yamnet_audio_classifier_with_metadata.tflite";
constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite";
class AudioTensorSpecsTest : public tflite_shims::testing::Test {};
class AudioTensorSpecsTest : public tflite::testing::Test {};
TEST_F(AudioTensorSpecsTest,
BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) {

View File

@ -63,7 +63,7 @@ cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
)
@ -232,6 +232,6 @@ cc_test(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
)

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace {
@ -66,8 +66,7 @@ ClassificationList MakeClassificationList(int class_index) {
class_index));
}
class ClassificationAggregationCalculatorTest
: public tflite_shims::testing::Test {
class ClassificationAggregationCalculatorTest : public tflite::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
bool connect_timestamps = false) {

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace {
@ -52,7 +52,7 @@ constexpr char kTimestampsName[] = "timestamps_in";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
class EmbeddingAggregationCalculatorTest : public tflite::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
Graph graph;

View File

@ -66,7 +66,7 @@ using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();

View File

@ -49,7 +49,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/util/label_map.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -101,7 +101,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
std::move(external_file));
}
class ConfigureTest : public tflite_shims::testing::Test {};
class ConfigureTest : public tflite::testing::Test {};
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
MP_ASSERT_OK_AND_ASSIGN(
@ -417,7 +417,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
)pb")));
}
class PostprocessingTest : public tflite_shims::testing::Test {
class PostprocessingTest : public tflite::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const proto::ClassifierOptions& options,
@ -520,7 +520,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
auto poller,
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
tensor[1] = 18;
tensor[2] = 16;
@ -552,7 +552,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
@ -589,7 +589,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
auto poller,
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
std::vector<uint8_t> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
@ -677,11 +677,11 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
/*connect_timestamps=*/true));
// Build input tensors.
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
std::vector<uint8_t> tensor_0(kMobileNetNumClasses, 0);
tensor_0[1] = 12;
tensor_0[2] = 14;
tensor_0[3] = 16;
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
std::vector<uint8_t> tensor_1(kMobileNetNumClasses, 0);
tensor_1[5] = 12;
tensor_1[6] = 14;
tensor_1[7] = 16;

View File

@ -39,7 +39,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -86,7 +86,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
std::move(external_file));
}
class ConfigureTest : public tflite_shims::testing::Test {};
class ConfigureTest : public tflite::testing::Test {};
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
@ -153,7 +153,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
has_quantized_outputs: false)pb")));
}
class PostprocessingTest : public tflite_shims::testing::Test {
class PostprocessingTest : public tflite::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const proto::EmbedderOptions& options,

View File

@ -37,7 +37,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -125,7 +125,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
return TaskRunner::Create(graph.GetConfig());
}
class ConfigureTest : public tflite_shims::testing::Test {};
class ConfigureTest : public tflite::testing::Test {};
TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(

View File

@ -78,6 +78,7 @@ cc_library(
hdrs = ["mediapipe_builtin_op_resolver.h"],
deps = [
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
@ -128,9 +129,9 @@ cc_library_with_tflite(
srcs = ["model_resources.cc"],
hdrs = ["model_resources.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite/core/shims:verifier",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/tools:verifier",
],
deps = [
":external_file_handler",
@ -159,9 +160,9 @@ cc_test_with_tflite(
],
tflite_deps = [
":model_resources",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:test_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
deps = [
":utils",
@ -186,7 +187,7 @@ cc_library_with_tflite(
hdrs = ["model_resources_cache.h"],
tflite_deps = [
":model_resources",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
deps = [
":model_asset_bundle_resources",
@ -233,7 +234,7 @@ cc_test_with_tflite(
":model_resources",
":model_resources_cache",
":model_resources_calculator",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
deps = [
"//mediapipe/framework/port:gtest_main",
@ -284,7 +285,7 @@ cc_test_with_tflite(
":task_runner",
":model_resources",
":model_resources_cache",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
deps = [
"//mediapipe/calculators/core:pass_through_calculator",
@ -317,6 +318,9 @@ cc_library(
":model_resources",
":task_runner",
":utils",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:requires",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
AddCustom("KmeansEmbeddingLookup",
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
// For the UniversalSentenceEncoder model.
AddCustom("TFSentencepieceTokenizeOp",
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
AddCustom("RaggedTensorToTensor",
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
}

View File

@ -37,8 +37,8 @@ limitations under the License.
#include "mediapipe/util/tflite/error_reporter.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/model_builder.h"
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/tools/verifier.h"
namespace mediapipe {
namespace tasks {
@ -52,7 +52,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
bool ModelResources::Verifier::Verify(const char* data, int length,
tflite::ErrorReporter* reporter) {
return tflite_shims::Verify(data, length, reporter);
return tflite::Verify(data, length, reporter);
}
ModelResources::ModelResources(const std::string& tag,
@ -124,7 +124,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
// and that it uses only operators that are supported by the OpResolver
// that was passed to the ModelResources constructor, and then builds
// the model from the buffer.
auto model = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer(
auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
buffer_data, buffer_size, &verifier_, &error_reporter_);
if (model == nullptr) {
static constexpr char kInvalidFlatbufferMessage[] =
@ -151,8 +151,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() {
}
model_packet_ = MakePacket<ModelPtr>(
model.release(),
[](tflite_shims::FlatBufferModel* model) { delete model; });
model.release(), [](tflite::FlatBufferModel* model) { delete model; });
ASSIGN_OR_RETURN(auto model_metadata_extractor,
metadata::ModelMetadataExtractor::CreateFromModelBuffer(
buffer_data, buffer_size));

View File

@ -32,10 +32,10 @@ limitations under the License.
#include "mediapipe/util/tflite/error_reporter.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
#include "tensorflow/lite/core/shims/cc/model.h"
#include "tensorflow/lite/core/shims/cc/model_builder.h"
#include "tensorflow/lite/core/shims/cc/tools/verifier.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/tools/verifier.h"
namespace mediapipe {
namespace tasks {
@ -51,8 +51,8 @@ class ModelResources {
public:
// Represents a TfLite model as a FlatBuffer.
using ModelPtr =
std::unique_ptr<tflite_shims::FlatBufferModel,
std::function<void(tflite_shims::FlatBufferModel*)>>;
std::unique_ptr<tflite::FlatBufferModel,
std::function<void(tflite::FlatBufferModel*)>>;
// Takes the ownership of the provided ExternalFile proto and creates
// ModelResources from the proto and an op resolver object. A non-empty tag
@ -61,7 +61,7 @@ class ModelResources {
static absl::StatusOr<std::unique_ptr<ModelResources>> Create(
const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file,
std::unique_ptr<tflite::OpResolver> op_resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
// Takes the ownership of the provided ExternalFile proto and creates
// ModelResources from the proto and an op resolver mediapipe packet. A

View File

@ -30,7 +30,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -124,7 +124,7 @@ void RunGraphWithGraphService(std::unique_ptr<ModelResources> model_resources,
} // namespace
class ModelResourcesCalculatorTest : public tflite_shims::testing::Test {};
class ModelResourcesCalculatorTest : public tflite::testing::Test {};
TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(

View File

@ -38,9 +38,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/test_util.h"
namespace tflite {
namespace ops {
@ -116,7 +116,7 @@ void CheckModelResourcesPackets(const ModelResources* model_resources) {
} // namespace
class ModelResourcesTest : public tflite_shims::testing::Test {};
class ModelResourcesTest : public tflite::testing::Test {};
TEST_F(ModelResourcesTest, CreateFromBinaryContent) {
auto model_file = std::make_unique<proto::ExternalFile>();
@ -211,7 +211,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsFromFile) {
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
tflite::MutableOpResolver resolver;
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
::tflite_shims::ops::builtin::Register_ADD());
::tflite::ops::builtin::Register_ADD());
resolver.AddCustom(kCustomOpName,
::tflite::ops::custom::Register_MY_CUSTOM_OP());
@ -275,7 +275,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsPacket) {
static constexpr char kCustomOpName[] = "MY_CUSTOM_OP";
tflite::MutableOpResolver resolver;
resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
::tflite_shims::ops::builtin::Register_ADD());
::tflite::ops::builtin::Register_ADD());
resolver.AddCustom(kCustomOpName,
::tflite::ops::custom::Register_MY_CUSTOM_OP());

View File

@ -23,7 +23,11 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/requires.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
@ -54,6 +58,8 @@ class TaskApiFactory {
std::unique_ptr<tflite::OpResolver> resolver,
PacketsCallback packets_callback = nullptr) {
bool found_task_subgraph = false;
// This for-loop ensures there's only one subgraph besides
// FlowLimiterCalculator.
for (const auto& node : graph_config.node()) {
if (node.calculator() == "FlowLimiterCalculator") {
continue;
@ -64,13 +70,7 @@ class TaskApiFactory {
"Task graph config should only contain one task subgraph node.",
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
} else {
if (!node.options().HasExtension(Options::ext)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat(node.calculator(),
" is missing the required task options field."),
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node));
found_task_subgraph = true;
}
}
@ -80,6 +80,35 @@ class TaskApiFactory {
std::move(packets_callback)));
return std::make_unique<T>(std::move(runner));
}
private:
// Verifies that `node` carries the task options required by this task type.
//
// Two mechanisms are supported, chosen at compile time:
//  - If `Options` declares a proto2 extension (`Options::ext` exists, detected
//    via mediapipe::Requires), the options must be attached as an extension
//    of the node's `options` field.
//  - Otherwise the options are expected as proto3 `node_options` (Any), and
//    we match on the type URL containing the options message's full name.
//
// Returns OkStatus when the options are present; otherwise an
// InvalidArgument status tagged with kInvalidTaskGraphConfigError.
template <typename Options>
static absl::Status CheckHasValidOptions(
const CalculatorGraphConfig::Node& node) {
if constexpr (mediapipe::Requires<Options>(
[](auto&& o) -> decltype(o.ext) {})) {
if (node.options().HasExtension(Options::ext)) {
return absl::OkStatus();
}
} else {
#ifndef MEDIAPIPE_PROTO_LITE
for (const auto& option : node.node_options()) {
// Substring match on the Any type URL, e.g.
// "type.googleapis.com/<full_name>".
if (absl::StrContains(option.type_url(),
Options::descriptor()->full_name())) {
return absl::OkStatus();
}
}
#else // MEDIAPIPE_PROTO_LITE
// Skip the check for proto lite, as Options::descriptor() is unavailable.
return absl::OkStatus();
#endif // MEDIAPIPE_PROTO_LITE
}
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat(node.calculator(),
" is missing the required task options field."),
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
};
} // namespace core

View File

@ -32,7 +32,7 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -112,7 +112,7 @@ CalculatorGraphConfig GetModelSidePacketsToStreamPacketsGraphConfig(
} // namespace
class TaskRunnerTest : public tflite_shims::testing::Test {};
class TaskRunnerTest : public tflite::testing::Test {};
TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) {
CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(

View File

@ -0,0 +1,172 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
filegroup(
name = "config_fbs",
srcs = ["config.fbs"],
)
flatbuffer_cc_library(
name = "config",
srcs = [
"config.fbs",
],
)
flatbuffer_cc_library(
name = "encoder_config",
srcs = [
"encoder_config.fbs",
],
includes = [":config_fbs"],
)
cc_library(
name = "utils",
hdrs = [
"utils.h",
],
)
cc_library(
name = "double_array_trie",
hdrs = [
"double_array_trie.h",
],
deps = [
":config",
":utils",
],
)
cc_library(
name = "double_array_trie_builder",
srcs = [
"double_array_trie_builder.cc",
],
hdrs = [
"double_array_trie_builder.h",
],
deps = ["@darts_clone"],
)
cc_test(
name = "double_array_trie_test",
srcs = [
"double_array_trie_test.cc",
],
deps = [
":double_array_trie",
":double_array_trie_builder",
":encoder_config",
":utils",
"//mediapipe/framework/port:gtest_main",
],
)
cc_library(
name = "sentencepiece_constants",
hdrs = ["sentencepiece_constants.h"],
)
cc_library(
name = "model_converter",
srcs = [
"model_converter.cc",
],
hdrs = [
"model_converter.h",
],
deps = [
":config",
":double_array_trie_builder",
":encoder_config",
":sentencepiece_constants",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
],
)
cc_library(
name = "optimized_encoder",
srcs = [
"optimized_encoder.cc",
],
hdrs = [
"optimized_encoder.h",
],
deps = [
":double_array_trie",
":encoder_config",
":utils",
],
)
cc_library(
name = "sentencepiece_tokenizer_tflite",
srcs = ["sentencepiece_tokenizer_tflite.cc"],
hdrs = ["sentencepiece_tokenizer_tflite.h"],
visibility = [
"//visibility:public",
],
deps =
[
":optimized_encoder",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
],
)
cc_test(
name = "optimized_encoder_test",
srcs = [
"optimized_encoder_test.cc",
],
data = [
":testdata",
],
deps = [
":double_array_trie_builder",
":encoder_config",
":model_converter",
":optimized_encoder",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@org_tensorflow//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,25 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
namespace mediapipe.tflite_operations.sentencepiece;

// A double-array trie serialized as a flat vector of 32-bit units,
// compatible with Darts::DoubleArray (see double_array_trie.h for the
// per-unit bit layout used when reading it back).
table Trie {
nodes: [uint32];
}

// Tags which encoder produced/consumes a given EncoderConfig.
enum EncoderVersion: byte {
SENTENCE_PIECE = 0,
}

View File

@ -0,0 +1,111 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
// A trie node specifies a node in the tree, either an intermediate node or
// a leaf node.
// A leaf node contains the id as an int of the string match. This id is encoded
// in the lower 31 bits, thus the number of distinct ids is 2^31.
// An intermediate node has an associated label and an offset to its children.
// The label is encoded in the least significant byte and must match the input
// character during matching.
// A memory mappable trie, compatible with Darts::DoubleArray.
class DoubleArrayTrie {
public:
// A single prefix match: `id` is the value stored at the matched leaf and
// `match_length` is the number of input bytes it covers. A
// default-constructed Match (match_length == -1) means "no match".
struct Match {
Match() {}
Match(int id, int match_length) : id(id), match_length(match_length) {}
int id = -1;
int match_length = -1;
bool empty() const { return match_length == -1; }
bool operator==(const Match& m) const {
return m.id == id && m.match_length == match_length;
}
};

// nodes and nodes_length specify the array of the nodes of the trie.
// The vector is typically memory-mapped from a flatbuffer; this class does
// not take ownership and `nodes` must outlive the trie.
explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
: nodes_(nodes) {}

// Finds matches that are prefixes of a string. Invokes `update_fn` once per
// match, in order of increasing match length.
template <typename callback>
void IteratePrefixMatches(const utils::string_view& input,
callback update_fn) const;

// Finds the longest prefix match of a string. Returns an empty Match when
// nothing matches.
Match LongestPrefixMatch(const utils::string_view& input) const {
Match match;
// The last callback invocation carries the longest match.
IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
return match;
}

private:
// The bit masks below follow the darts-clone DoubleArrayUnit encoding —
// see https://github.com/s-yata/darts-clone.
// Returns whether a node has a leaf as a child (bit 8 of the unit).
bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
// Returns a value associated with a node. Available when a node is a leaf.
int value(uint32_t i) const {
return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
}
// Returns a label associated with a node.
// A leaf node will have the MSB set and thus return an invalid label.
int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
// Returns offset to children.
int32_t offset(uint32_t i) const {
const uint32_t node = (*nodes_)[i];
return (node >> 10) << ((node & 0x200) >> 6);
}
// Not owned; backing storage for the trie units.
const flatbuffers::Vector<uint32_t>* nodes_;
};
// Walks the double-array trie byte-by-byte over `input`, invoking
// `update_fn(Match)` for every trie entry that is a prefix of `input`,
// shortest first. Exits silently on the first byte that cannot be matched
// or on an out-of-range transition (corrupted trie data).
template <typename callback>
void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
                                           callback update_fn) const {
  if (nodes_->size() == 0) {
    return;
  }
  uint32_t pos = offset(0);
  for (int i = 0; i < input.length(); ++i) {
    pos ^= static_cast<unsigned char>(input.at(i));
    // NOTE: `pos` is unsigned, so the former `pos < 0` test here was always
    // false (tautological unsigned comparison) and has been removed; only
    // the upper bound needs checking.
    if (pos >= nodes_->size() || label(pos) != input.at(i)) {
      // No match, exit.
      return;
    }
    const bool node_has_leaf = has_leaf(pos);
    pos ^= offset(pos);
    if (pos >= nodes_->size()) {
      // We can get here only if the structure is corrupted.
      return;
    }
    if (node_has_leaf) {
      update_fn(Match(value(pos), i + 1));
    }
  }
}
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_

View File

@ -0,0 +1,75 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include <algorithm>
#include <memory>
#include "include/darts.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Convenience overload: builds a trie in which each key's id is its own
// position in `data`.
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) {
  std::vector<int> identity_ids;
  identity_ids.reserve(data.size());
  for (size_t pos = 0; pos < data.size(); ++pos) {
    identity_ids.push_back(static_cast<int>(pos));
  }
  return BuildTrie(data, identity_ids);
}
// Builds a Darts double-array trie from `data` (keys) and `ids` (the value
// stored at each key's leaf), returning the trie as a flat vector of 32-bit
// units. Extra elements in the longer of the two inputs are ignored.
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
                                const std::vector<int>& ids) {
  // We make strong assumptions about binary structure of trie.
  // Darts requires keys in lexicographic order with a parallel value array,
  // so pair each key with its id before sorting.
  struct KeyedEntry {
    KeyedEntry(const std::string* key_, int index_)
        : key(key_), index(index_) {}
    const std::string* key;
    int index;
    bool operator<(const KeyedEntry& other) const { return *key < *other.key; }
  };
  std::vector<KeyedEntry> entries;
  entries.reserve(data.size());
  const size_t paired = std::min(data.size(), ids.size());
  for (size_t i = 0; i < paired; ++i) {
    entries.emplace_back(&data[i], ids[i]);
  }
  // Sort by keys.
  std::sort(entries.begin(), entries.end());
  // Flatten into the parallel key/value arrays Darts expects.
  std::vector<const char*> keys;
  std::vector<int32_t> values;
  keys.reserve(data.size());
  values.reserve(data.size());
  for (const auto& entry : entries) {
    keys.push_back(entry.key->c_str());
    values.push_back(entry.index);
  }
  auto trie = std::make_unique<Darts::DoubleArray>();
  trie->build(data.size(), const_cast<char**>(keys.data()), nullptr,
              values.data());
  // We make strong assumptions about internal Darts trie structure:
  // - it is a vector of 32 bit signed integers
  // - the "array" is the only one structure that contains all information
  //   about the trie.
  const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array());
  return std::vector<uint32_t>(trie_data, trie_data + trie->size());
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,32 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_

// <cstdint> added: this header uses uint32_t but previously relied on a
// transitive include to provide it.
#include <cstdint>
#include <string>
#include <vector>

namespace mediapipe::tflite_operations::sentencepiece {

// Builds a Darts-compatible double-array trie from the given keys, storing
// ids[i] as the value reachable from key data[i]. Returns the trie as a flat
// vector of 32-bit units (see double_array_trie.h for the read side).
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
                                const std::vector<int>& ids);

// A variant where ids are indexes in data.
std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data);

}  // namespace mediapipe::tflite_operations::sentencepiece

#endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_

View File

@ -0,0 +1,73 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Builds a trie over four pieces (ids are their positions in test_strings)
// and checks both longest-match and all-prefix-match behavior.
// Note: the nested flatbuffer objects (vector, Trie, EncoderConfig) must be
// finished in this order — a builder cannot start a new nested object while
// another is unfinished.
TEST(DoubleArrayTrieTest, Match) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"};
const auto trie_vector = builder.CreateVector(BuildTrie(test_strings));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto pieces = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_pieces(pieces);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
DoubleArrayTrie dat(config->pieces()->nodes());
// "AA" (id 2) is the longest piece that prefixes "AAL".
EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")),
DoubleArrayTrie::Match(2, 2));
std::vector<DoubleArrayTrie::Match> matches;
dat.IteratePrefixMatches(
utils::string_view("AAXL"),
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
// Prefix matches arrive shortest-first: "A" (id 0), "AA" (id 2), "AAX" (id 1).
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1),
DoubleArrayTrie::Match(2, 2),
DoubleArrayTrie::Match(1, 3)));
}
// Same as Match, but with multi-byte UTF-8 pieces ("\xe2\x96\x81" is the
// SentencePiece meta-space U+2581) and explicit, non-contiguous ids.
TEST(DoubleArrayTrieTest, ComplexMatch) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s",
"\xe2\x96\x81Hello"};
const std::vector<int> test_ids = {0, 5, 10, 15};
const auto trie_vector =
builder.CreateVector(BuildTrie(test_strings, test_ids));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto pieces = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_pieces(pieces);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
DoubleArrayTrie dat(config->pieces()->nodes());
std::vector<DoubleArrayTrie::Match> matches;
dat.IteratePrefixMatches(
utils::string_view("\xe2\x96\x81Hello"),
[&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
// Only "\xe2\x96\x81Hello" (id 15, 8 bytes) matches; match_length counts
// bytes, not code points.
EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8)));
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,52 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
include "config.fbs";
namespace mediapipe.tflite_operations.sentencepiece;
// Serialized configuration for the optimized SentencePiece encoder.
// Produced by model_converter.cc from a sentencepiece ModelProto.
table EncoderConfig {
// Version of the encoder.
version: EncoderVersion = SENTENCE_PIECE;
// BOS / EOS token ids (trainer_spec bos_id / eos_id in the source model).
start_code: int32 = 0;
end_code: int32 = 0;
// Id emitted for unmatchable input (trainer_spec unk_id); -1 disables it.
unknown_code: int32 = -1;
// Weight of "unknown code" when encoding. "Penalty" because it usually has a
// big negative weight, less than any other sentencepiece.
unknown_penalty: float = 0;
// The offset for encoding, usually used when codes with low codes are reserved
// for some special needs.
encoding_offset: int32;
// String pieces for encoding.
pieces: Trie;
// Score per piece, indexed by the piece's id stored in the trie.
pieces_scores: [float];
// Normalization related parameters.
remove_extra_whitespaces: bool;
// Add a whitespace prefix before encoding.
add_dummy_prefix: bool;
// Escape whitespaces during encoding so the decoder can restore them exactly as
// in the input.
escape_whitespaces: bool;
// Normalization parameters.
// Prefixes to replace (trie leaf id = byte offset into
// normalized_replacements).
normalized_prefixes: Trie;
normalized_replacements: [byte];
}

root_type EncoderConfig;

View File

@ -0,0 +1,131 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"

#include <cassert>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
#include "src/sentencepiece_model.pb.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Splits a sentencepiece precompiled charsmap blob into its two parts:
// the normalization trie (as 32-bit units) and the replacement-strings pool.
// Blob layout: [uint32 trie_size][trie bytes][normalized strings].
// NOTE(review): the reinterpret_casts assume the blob is suitably aligned
// for uint32_t reads; std::string::data() only guarantees char alignment —
// consider memcpy if this ever trips on a strict-alignment target.
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
DecodePrecompiledCharsmap(
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
// This function "undoes" encoding done by
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
// First 4 bytes: size of the trie section in bytes.
const uint32_t trie_size =
*reinterpret_cast<const uint32_t*>(precompiled_map);
const uint32_t* trie_ptr =
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
// Everything after the trie is the replacement-strings pool.
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
precompiled_map + sizeof(uint32_t) + trie_size);
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
sizeof(uint32_t) - trie_size;
return std::make_tuple(
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
}
// Converts a serialized sentencepiece ModelProto into the EncoderConfig
// flatbuffer consumed by the optimized encoder. Returns the serialized
// flatbuffer as a string, or InvalidArgument if the proto cannot be parsed
// or contains an unsupported piece type.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset) {
::sentencepiece::ModelProto model_config;
if (!model_config.ParseFromString(model_config_str)) {
return absl::InvalidArgumentError(
"Invalid configuration, can't parse SentencePiece model config " +
model_config.InitializationErrorString());
}
// Convert sentencepieces.
std::vector<std::string> pieces;
pieces.reserve(model_config.pieces_size());
std::vector<float> scores;
scores.reserve(model_config.pieces_size());
std::vector<int> ids;
ids.reserve(model_config.pieces_size());
// Tracks the minimum piece score; the unknown-code penalty is derived from
// it below so that unknown is always less attractive than any real piece.
float min_score = 0.0;
int index = 0;
for (const auto& piece : model_config.pieces()) {
switch (piece.type()) {
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
pieces.push_back(piece.piece());
ids.push_back(index);
if (piece.score() < min_score) {
min_score = piece.score();
}
break;
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
// Ignore unknown and control codes.
break;
default:
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
piece.piece());
}
// Note: scores receives an entry for EVERY piece (including skipped
// unknown/control ones) so that pieces_scores stays indexable by the
// original piece id stored in the trie.
scores.push_back(piece.score());
++index;
}
flatbuffers::FlatBufferBuilder builder(1024);
// Nested objects must each be fully created before the next one starts;
// hence all vectors are created before their enclosing table builders.
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
const auto pieces_score_vector = builder.CreateVector(scores);
TrieBuilder pieces_trie_builder(builder);
pieces_trie_builder.add_nodes(pieces_trie_vector);
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
// Converting normalization.
const auto normalization =
DecodePrecompiledCharsmap(model_config.normalizer_spec());
const auto normalization_trie = std::get<0>(normalization);
const auto normalization_strings = std::get<1>(normalization);
const auto normalization_trie_vector =
builder.CreateVector(normalization_trie);
TrieBuilder normalization_trie_builder(builder);
normalization_trie_builder.add_nodes(normalization_trie_vector);
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
const auto normalization_strings_fbs =
builder.CreateVector(normalization_strings);
EncoderConfigBuilder ecb(builder);
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
ecb.add_start_code(model_config.trainer_spec().bos_id());
ecb.add_end_code(model_config.trainer_spec().eos_id());
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
// Unknown must score strictly below the worst real piece.
ecb.add_unknown_penalty(min_score - kUnkPenalty);
ecb.add_encoding_offset(encoding_offset);
ecb.add_pieces(pieces_trie_fbs);
ecb.add_pieces_scores(pieces_score_vector);
ecb.add_remove_extra_whitespaces(
model_config.normalizer_spec().remove_extra_whitespaces());
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
ecb.add_escape_whitespaces(
model_config.normalizer_spec().escape_whitespaces());
ecb.add_normalized_prefixes(normalization_trie_fbs);
ecb.add_normalized_replacements(normalization_strings_fbs);
FinishEncoderConfigBuffer(builder, ecb.Finish());
// Copy the finished buffer out as a byte string.
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
// Infallible wrapper around ConvertSentencepieceModelToFlatBuffer for
// callers that know the model is valid. In release builds (NDEBUG) the
// assert compiles away; value() on an error StatusOr still terminates, so a
// failure is never silently ignored.
std::string ConvertSentencepieceModel(const std::string& model_string) {
  auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
  assert(result.status().ok());
  // `result` was previously `const`, which forced value() to copy the whole
  // serialized flatbuffer; moving out avoids that copy.
  return std::move(result).value();
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,33 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#include <string>
#include "absl/status/statusor.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Converts Sentencepiece configuration to flatbuffer format.
// encoding_offset is used by some encoders that combine different encodings.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset = 0);
std::string ConvertSentencepieceModel(const std::string& model_string);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_

View File

@ -0,0 +1,236 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <algorithm>
#include <tuple>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace {
const char kSpaceSymbol[] = "\xe2\x96\x81";
// Applies a replacement callback across `input` while maintaining a per-byte
// offset map back into the original string.
//
// `pc(data, len)` must return a tuple {consumed, replacement}: how many input
// bytes it consumed and the string they are replaced with. consumed == 0
// means "no replacement here" and the current byte is copied through
// verbatim. Every byte of `replacement` is attributed to the offset of the
// first consumed input byte.
//
// Returns the transformed string and the offset (into the original input) of
// each output byte.
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
    const std::string& input, const std::vector<int>& offsets,
    const processing_callback& pc) {
  std::string result_string;
  result_string.reserve(input.size());
  std::vector<int> result_offsets;
  result_offsets.reserve(offsets.size());
  for (int i = 0, j = 0; i < input.size();) {
    auto result = pc(input.data() + i, input.size() - i);
    auto consumed = std::get<0>(result);
    auto new_string = std::get<1>(result);
    if (consumed == 0) {
      // Skip the current byte and move forward.
      result_string.push_back(input[i]);
      result_offsets.push_back(offsets[j]);
      i++;
      j++;
      continue;
    }
    result_string.append(new_string.data(), new_string.length());
    // Every replacement byte maps to the offset of the first consumed byte.
    // (This loop previously redeclared `i`, shadowing the outer index —
    // renamed to `k` to remove the shadow.)
    for (int k = 0; k < new_string.length(); ++k) {
      result_offsets.push_back(offsets[j]);
    }
    j += consumed;
    i += consumed;
  }
  return std::make_tuple(result_string, result_offsets);
}
// True for the ASCII whitespace bytes the normalizer handles.
// (Return type fixed from `char` to `bool`: every call site uses the result
// as a boolean condition, and `bool` states that intent.)
inline bool is_whitespace(char c) {
  return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
// process_string callback that collapses a run of two-or-more whitespace
// bytes into a single space. Returns {0, empty} when there is nothing to
// replace at this position (including a lone whitespace byte, which is kept
// as-is).
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
                                                             int len) {
  if (len == 0 || !is_whitespace(*data)) {
    return std::make_tuple(0, utils::string_view(nullptr, 0));
  }
  // Measure the whitespace run starting at data[0].
  int run_length = 1;
  while (run_length < len && is_whitespace(data[run_length])) {
    ++run_length;
  }
  if (run_length > 1) {
    // Replace the whole run with one space.
    return std::make_tuple(run_length, utils::string_view(" ", 1));
  }
  // Single whitespace byte: leave it untouched.
  return std::make_tuple(0, utils::string_view(nullptr, 0));
}
// process_string callback that replaces the longest prefix of data found in
// `dat` with the corresponding entry from the replacement-strings pool. The
// trie leaf id is a byte offset into `replacements`. Returns {0, empty} when
// no prefix matches.
std::tuple<int, utils::string_view> find_replacement(
    const char* data, int len, const DoubleArrayTrie& dat,
    const flatbuffers::Vector<int8_t>& replacements) {
  const auto longest = dat.LongestPrefixMatch(utils::string_view(data, len));
  if (longest.empty()) {
    return std::make_tuple(0, utils::string_view(nullptr, 0));
  }
  // Because flatbuffer byte is signed char which is not the same as char,
  // there is the reinterpret_cast here.
  const char* replaced_string_ptr =
      reinterpret_cast<const char*>(replacements.data() + longest.id);
  return std::make_tuple(longest.match_length,
                         utils::string_view(replaced_string_ptr));
}
} // namespace
// Applies the model's normalization pipeline to `in_string`, in order:
// dummy-prefix insertion, charsmap prefix replacement, extra-whitespace
// collapsing, and whitespace escaping. Returns the normalized string plus a
// per-byte offset map back into the original input. The stage order matters:
// each stage rewrites both the string and the offsets of the previous one.
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config) {
std::vector<int> output_offsets;
std::string result = in_string;
output_offsets.reserve(in_string.length());
// Start with the identity offset map.
for (int i = 0; i < in_string.length(); ++i) {
output_offsets.push_back(i);
}
if (in_string.empty()) {
return std::make_tuple(result, output_offsets);
}
if (config.add_dummy_prefix()) {
// The leading space is attributed to offset 0.
result.insert(result.begin(), ' ');
output_offsets.insert(output_offsets.begin(), 0);
}
// Greedily replace normalized_prefixes with normalized_replacements
if (config.normalized_prefixes() != nullptr &&
config.normalized_replacements() != nullptr) {
const DoubleArrayTrie normalized_prefixes_matcher(
config.normalized_prefixes()->nodes());
const auto norm_replace = [&config, &normalized_prefixes_matcher](
const char* data, int len) {
return find_replacement(data, len, normalized_prefixes_matcher,
*config.normalized_replacements());
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, norm_replace);
}
if (config.remove_extra_whitespaces()) {
std::tie(result, output_offsets) =
process_string(result, output_offsets, remove_extra_whitespaces);
// Also drop a single trailing whitespace byte, if any.
if (!result.empty() && is_whitespace(result.back())) {
result.pop_back();
output_offsets.pop_back();
}
}
if (config.escape_whitespaces()) {
// Replace each whitespace byte with the U+2581 meta-space symbol.
const auto replace_whitespaces = [](const char* data, int len) {
if (len > 0 && is_whitespace(*data)) {
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, replace_whitespaces);
}
return std::make_tuple(result, output_offsets);
}
// Encodes an already-normalized string into piece codes via dynamic
// programming over a position lattice: lattice[p] holds the best-scoring way
// to reach byte position p. Offsets map each emitted code back to a position
// in the pre-normalization input. Codes are produced back-to-front and
// reversed at the end unless `reverse` is requested.
EncoderResult EncodeNormalizedString(const std::string& str,
const std::vector<int>& offsets,
const EncoderConfig& config, bool add_bos,
bool add_eos, bool reverse) {
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
const int unknown_code = config.unknown_code();
const float unknown_penalty = config.unknown_penalty();
// One lattice slot per byte boundary; prev_position == -1 marks
// "unreached". The best path into a slot is the one with the highest score.
struct LatticeElement {
float score = 0;
int code = -1;
int prev_position = -1;
LatticeElement(float score_, int code_, int prev_position_)
: score(score_), code(code_), prev_position(prev_position_) {}
LatticeElement() {}
};
const int length = str.length();
std::vector<LatticeElement> lattice(length + 1);
for (int i = 0; i < length; ++i) {
if (i > 0 && lattice[i].prev_position < 0) {
// This state is unreachable.
continue;
}
if (unknown_code >= 0) {
// Put unknown code.
// A single-byte unknown transition is always available (penalized), so
// encoding never gets stuck when no piece matches.
const float penalized_score = lattice[i].score + unknown_penalty;
const int pos = i + 1;
LatticeElement& current_element = lattice[pos];
if (current_element.prev_position < 0 ||
current_element.score < penalized_score) {
current_element = LatticeElement(
penalized_score, unknown_code,
// If the current state is already reached by unknown code, merge
// states.
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
}
}
// Relax all piece transitions starting at byte i.
auto lattice_update = [&lattice, i,
piece_scores](const DoubleArrayTrie::Match& m) {
LatticeElement& target_element = lattice[i + m.match_length];
const float score = lattice[i].score + (*piece_scores)[m.id];
if (target_element.prev_position < 0 || target_element.score < score) {
target_element = LatticeElement(score, m.id, i);
}
};
piece_matcher.IteratePrefixMatches(
utils::string_view(str.data() + i, length - i), lattice_update);
}
EncoderResult result;
// Codes below are appended in reverse order (EOS first, BOS last) and
// flipped at the end unless `reverse` is set.
if (add_eos) {
result.codes.push_back(config.end_code());
result.offsets.push_back(length);
}
if (lattice[length].prev_position >= 0) {
// Backtrack the best path from the end of the string.
for (int pos = length; pos > 0;) {
auto code = lattice[pos].code;
// encoding_offset is applied to real pieces only, never to unknown.
if (code != config.unknown_code()) {
code += config.encoding_offset();
}
result.codes.push_back(code);
pos = lattice[pos].prev_position;
result.offsets.push_back(offsets[pos]);
}
}
if (add_bos) {
result.codes.push_back(config.start_code());
result.offsets.push_back(0);
}
if (!reverse) {
std::reverse(result.codes.begin(), result.codes.end());
std::reverse(result.offsets.begin(), result.offsets.end());
}
return result;
}
// Encodes `string` into sentencepiece ids and offsets. The configuration is
// taken as a type-erased flatbuffer; returns a WRONG_CONFIG result when the
// buffer's version is not the supported sentencepiece encoder version.
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
                           bool add_bos, bool add_eos, bool reverse) {
  // Get the config from the buffer.
  const EncoderConfig* encoder_config = GetEncoderConfig(config_buffer);
  if (encoder_config->version() !=
      EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
    EncoderResult wrong_config_result;
    wrong_config_result.type = EncoderResultType::WRONG_CONFIG;
    return wrong_config_result;
  }
  const auto normalized = NormalizeString(string, *encoder_config);
  return EncodeNormalizedString(std::get<0>(normalized),
                                std::get<1>(normalized), *encoder_config,
                                add_bos, add_eos, reverse);
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
// Sentencepiece encoder optimized with memmapped model.
#include <string>
#include <tuple>
#include <vector>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Status of an encoding attempt; WRONG_CONFIG signals that the provided
// buffer is not a supported sentencepiece encoder configuration.
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
// Result of encoding one string.
struct EncoderResult {
  EncoderResultType type = EncoderResultType::SUCCESS;
  // Encoded token ids.
  std::vector<int> codes;
  // Byte offsets into the original string, parallel to `codes`.
  std::vector<int> offsets;
};
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config);
// Encodes one string and returns ids and offsets. Takes the configuration as a
// type-erased buffer.
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_

View File

@ -0,0 +1,171 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <fstream>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "src/sentencepiece.pb.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/core/platform/env.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace internal {
// Reads the whole file at `filepath` into `*data` via the TensorFlow
// filesystem API (used as the reference against StdReadFileToString).
tensorflow::Status TFReadFileToString(const std::string& filepath,
                                      std::string* data) {
  tensorflow::Env* const default_env = tensorflow::Env::Default();
  return tensorflow::ReadFileToString(default_env, filepath, data);
}
// Reads the whole file at `filepath` and appends its contents to `*data`
// using only the C++ standard library.
//
// Returns absl::NotFoundError if the file cannot be opened, OkStatus
// otherwise.
absl::Status StdReadFileToString(const std::string& filepath,
                                 std::string* data) {
  // Open in binary mode: the files read here are serialized models, and
  // text-mode newline translation would corrupt them on some platforms.
  std::ifstream infile(filepath, std::ios::binary);
  if (!infile.is_open()) {
    return absl::NotFoundError(
        absl::StrFormat("Error when opening %s", filepath));
  }
  // Append straight from the stream iterators; this avoids the extra
  // intermediate std::string (allocation + copy) the previous version made.
  // The ifstream closes itself on scope exit (RAII).
  data->append(std::istreambuf_iterator<char>(infile),
               std::istreambuf_iterator<char>());
  return absl::OkStatus();
}
} // namespace internal
namespace {
using ::mediapipe::file::JoinPath;
static char kConfigFilePath[] =
"/mediapipe/tasks/cc/text/custom_ops/"
"sentencepiece/testdata/sentencepiece.model";
// Verifies whitespace normalization: a dummy prefix is prepended, extra
// whitespaces are removed, and remaining whitespaces are escaped to U+2581.
// (Renamed from "NormalizeStringWhitestpaces" to fix the typo in the name.)
TEST(OptimizedEncoder, NormalizeStringWhitespaces) {
  flatbuffers::FlatBufferBuilder builder(1024);
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(true);
  ecb.add_add_dummy_prefix(true);
  ecb.add_escape_whitespaces(true);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    const auto result = NormalizeString("x y", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    // "\xe2\x96\x81" is the UTF-8 encoding of U+2581, the escaped-whitespace
    // symbol; offsets map every output byte back to the input position.
    EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
    EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
  }
  {
    // Leading tab and trailing newline are treated as extra whitespace.
    const auto result = NormalizeString("\tx y\n", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
    EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
  }
}
// Verifies greedy prefix replacement during normalization: runs of "A"s are
// matched against the longest trie prefix and replaced with the
// corresponding entry from the NUL-separated replacement blob.
TEST(OptimizedEncoder, NormalizeStringReplacement) {
  flatbuffers::FlatBufferBuilder builder(1024);
  const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
  // Replacements are packed back-to-back, NUL-separated; the offsets
  // {0, 3, 6, 9} passed to BuildTrie index into this blob.
  const char norm_replacements[] = "A1\0A2\0A3\0A4";
  const auto trie_vector =
      builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
  const auto norm_r = builder.CreateVector<int8_t>(
      reinterpret_cast<const signed char*>(norm_replacements),
      sizeof(norm_replacements));
  TrieBuilder trie_builder(builder);
  trie_builder.add_nodes(trie_vector);
  const auto norm_p = trie_builder.Finish();
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(false);
  ecb.add_normalized_prefixes(norm_p);
  ecb.add_normalized_replacements(norm_r);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    const auto result = NormalizeString("ABAABAAABAAAA", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    // Each run of A's collapses to its replacement; offsets map every output
    // byte back to a position in the original input.
    EXPECT_EQ(res_string, "A1BA2BA3BA4");
    EXPECT_THAT(offsets,
                ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
  }
}
// Verifies that whitespace produced by a replacement (prefix "X" maps to a
// space here) is subject to extra-whitespace removal like literal whitespace.
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
  flatbuffers::FlatBufferBuilder builder(1024);
  const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
                                                  "X"};
  // NUL-separated replacement blob; offsets {0, 3, 6, 9, 12} index into it,
  // so "X" is replaced by a single space.
  const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
  const auto trie_vector =
      builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
  const auto norm_r = builder.CreateVector<int8_t>(
      reinterpret_cast<const signed char*>(norm_replacements),
      sizeof(norm_replacements));
  TrieBuilder trie_builder(builder);
  trie_builder.add_nodes(trie_vector);
  const auto norm_p = trie_builder.Finish();
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(true);
  ecb.add_normalized_prefixes(norm_p);
  ecb.add_normalized_replacements(norm_r);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    const auto result = NormalizeString("XXABAABAAABAAAA", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    // The two leading "X"s become spaces; remove_extra_whitespaces collapses
    // them into one.
    EXPECT_EQ(res_string, " A1BA2BA3BA4");
    EXPECT_THAT(offsets,
                ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
  }
}
// End-to-end check: converts a real sentencepiece model with
// ConvertSentencepieceModel and verifies the optimized encoder reproduces
// exactly the ids and begin-offsets of the reference
// ::sentencepiece::SentencePieceProcessor.
TEST(OptimizedEncoder, ConfigConverter) {
  std::string config;
  auto status =
      internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
  ASSERT_TRUE(status.ok());
  ::sentencepiece::SentencePieceProcessor processor;
  ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
  const auto converted_model = ConvertSentencepieceModel(config);
  // NOTE(review): the escapes are doubled, so the test string contains the
  // literal characters "\xF0\x9F\x8D\x95" rather than the emoji bytes —
  // presumably intentional; confirm.
  const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
  const auto encoded =
      EncodeString(test_string, converted_model.data(), false, false, false);
  // Codes and offsets are parallel arrays.
  ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
  ::sentencepiece::SentencePieceText reference_encoded;
  ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
  EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
  for (int i = 0; i < encoded.codes.size(); ++i) {
    EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
    EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
  }
}
} // namespace
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
namespace mediapipe::tflite_operations::sentencepiece {
// The constant is copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
// NOTE(review): presumably the score penalty applied when covering input
// with the unknown token, as in upstream unigram_model.cc — confirm at the
// use site.
constexpr float kUnkPenalty = 10.0;
// These constants are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
//
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_

View File

@ -0,0 +1,129 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe::tflite_operations {
namespace sentencepiece::tokenizer {
namespace {
using ::tflite::SetTensorToDynamic;
constexpr int kSPModelIndex = 0;
constexpr int kInputIndex = 1;
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;
constexpr int kOutputValuesInd = 0;
constexpr int kOutputSplitsInd = 1;
// Builds a freshly-allocated TfLiteIntArray holding the given dimension
// sizes; the caller (ResizeTensor) takes ownership of the returned array.
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
  TfLiteIntArray* result = TfLiteIntArrayCreate(sizes.size());
  int pos = 0;
  for (auto it = sizes.begin(); it != sizes.end(); ++it) {
    result->data[pos] = *it;
    ++pos;
  }
  return result;
}
} // namespace
// Initializes text encoder object from serialized parameters.
// This op keeps no per-instance state — the sentencepiece model is read from
// an input tensor on every Eval call — so there is nothing to construct and
// nullptr is returned.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
                 size_t /*length*/) {
  return nullptr;
}
// Counterpart of Initialize; no state was allocated, so nothing to free.
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
// Marks both outputs (token values and ragged row-splits) as dynamic: their
// sizes depend on the input strings, so they are resized in Eval.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // TODO: Add checks for input and output tensors.
  TfLiteTensor* values_tensor =
      &context->tensors[node->outputs->data[kOutputValuesInd]];
  TfLiteTensor* splits_tensor =
      &context->tensors[node->outputs->data[kOutputSplitsInd]];
  SetTensorToDynamic(values_tensor);
  SetTensorToDynamic(splits_tensor);
  return kTfLiteOk;
}
// Tokenizes every string in the input tensor with the sentencepiece model
// supplied as an input tensor, writing ids to the values output and ragged
// row-splits (leading 0, then one cumulative count per string) to the splits
// output.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor& model_tensor =
      context->tensors[node->inputs->data[kSPModelIndex]];
  const auto model_buffer_data = model_tensor.data.data;
  const TfLiteTensor& input_text =
      context->tensors[node->inputs->data[kInputIndex]];
  // Scalar boolean attributes controlling BOS/EOS insertion and output
  // order. Bind by reference: TfLiteTensor is a sizable struct, so the
  // previous by-value copies were wasted work.
  const TfLiteTensor& add_bos_tensor =
      context->tensors[node->inputs->data[kAddBOSInput]];
  const bool add_bos = add_bos_tensor.data.b[0];
  const TfLiteTensor& add_eos_tensor =
      context->tensors[node->inputs->data[kAddEOSInput]];
  const bool add_eos = add_eos_tensor.data.b[0];
  const TfLiteTensor& reverse_tensor =
      context->tensors[node->inputs->data[kReverseInput]];
  const bool reverse = reverse_tensor.data.b[0];
  // Use the fixed-width standard type to match the int32_t output buffers
  // below (consistent with the codebase migration off TF's int32 alias).
  std::vector<int32_t> encoded;
  std::vector<int32_t> splits;
  const int num_strings = tflite::GetStringCount(&input_text);
  for (int i = 0; i < num_strings; ++i) {
    const auto strref = tflite::GetString(&input_text, i);
    const auto res = EncodeString(std::string(strref.str, strref.len),
                                  model_buffer_data, add_bos, add_eos, reverse);
    TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
                       "Sentencepiece conversion failed");
    std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
    splits.push_back(encoded.size());
  }
  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(
                        context, &output_values,
                        CreateSizeArray({static_cast<int>(encoded.size())})));
  int32_t* output_values_flat = output_values.data.i32;
  std::copy(encoded.begin(), encoded.end(), output_values_flat);
  // Splits output: leading 0 plus one entry per input string.
  TfLiteTensor& output_splits =
      context->tensors[node->outputs->data[kOutputSplitsInd]];
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(
                   context, &output_splits,
                   CreateSizeArray({static_cast<int>(splits.size() + 1)})));
  int32_t* output_splits_flat = output_splits.data.i32;
  *output_splits_flat = 0;
  std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
  return kTfLiteOk;
}
} // namespace sentencepiece::tokenizer
// Returns the TfLite registration for the SENTENCEPIECE_TOKENIZER custom op.
// The registration is a function-local static, so a single instance is
// shared by all interpreters that register the op.
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
  static TfLiteRegistration r = {
      sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
      sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
  return &r;
}
} // namespace mediapipe::tflite_operations

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe::tflite_operations {
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
} // namespace mediapipe::tflite_operations
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_

View File

@ -0,0 +1,60 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
#include <cstring>
#include <ostream>
#include <string>
namespace mediapipe::tflite_operations::sentencepiece {
// AOSP and WASM don't support string_view,
// so we provide a minimal re-implementation here.
namespace utils {
// Minimal non-owning view over a character buffer, standing in for
// std::string_view on toolchains that lack it. The viewed buffer must
// outlive the view.
class string_view {
 public:
  // Views an existing std::string; the string must outlive this view.
  explicit string_view(const std::string& s)
      : str_(s.data()), len_(s.length()) {}
  string_view(const char* str, int len) : str_(str), len_(len) {}
  // A constructor from a NUL-terminated C string. Uses std::strlen (with
  // <cstring> included at the top of this header) instead of relying on a
  // transitive include for unqualified ::strlen.
  explicit string_view(const char* s) : str_(s), len_(std::strlen(s)) {}
  int length() const { return len_; }
  const char* data() const { return str_; }
  bool empty() const { return len_ == 0; }
  // Returns the byte at `i` as unsigned char, avoiding sign-extension
  // surprises with high-bit (e.g. UTF-8 continuation) bytes.
  unsigned char at(int i) const { return str_[i]; }

 private:
  const char* str_ = nullptr;
  const int len_ = 0;
};
// Streams the viewed bytes to `os` by materializing the same std::string the
// previous implementation inserted.
inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
  const std::string materialized(sv.data(), sv.length());
  os << materialized;
  return os;
}
// Two views are equal iff they have the same length and identical bytes;
// memcmp is only reached when the lengths match (short-circuit).
inline bool operator==(const string_view& view1, const string_view& view2) {
  return view1.length() == view2.length() &&
         memcmp(view1.data(), view2.data(), view1.length()) == 0;
}
} // namespace utils
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_

View File

@ -32,7 +32,7 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe::tasks::text::language_detector {
namespace {
@ -75,7 +75,7 @@ absl::Status MatchesLanguageDetectorResult(
} // namespace
class LanguageDetectorTest : public tflite_shims::testing::Test {};
class LanguageDetectorTest : public tflite::testing::Test {};
TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) {
auto options = std::make_unique<LanguageDetectorOptions>();

View File

@ -89,7 +89,7 @@ cc_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
)

View File

@ -36,7 +36,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe::tasks::text::text_classifier {
namespace {
@ -87,7 +87,7 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual,
} // namespace
class TextClassifierTest : public tflite_shims::testing::Test {};
class TextClassifierTest : public tflite::testing::Test {};
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
auto options = std::make_unique<TextClassifierOptions>();

View File

@ -91,6 +91,6 @@ cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
)

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/embedding_result.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe::tasks::text::text_embedder {
namespace {
@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
// Embedding model with regex preprocessing.
constexpr char kRegexOneEmbeddingModel[] =
"regex_one_embedding_with_metadata.tflite";
constexpr char kUniversalSentenceEncoderModel[] =
"universal_sentence_encoder_qa_with_metadata.tflite";
// Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4;
@ -49,7 +51,7 @@ using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
using ::testing::Optional;
class EmbedderTest : public tflite_shims::testing::Test {};
class EmbedderTest : public tflite::testing::Test {};
TEST_F(EmbedderTest, FailsWithMissingModel) {
auto text_embedder =
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result0,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result0.embeddings.size(), 1);
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("When you go to this restaurant, they hold the "
"pancake upside-down before they hand it "
"to you. It's a great gimmick."));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result1,
text_embedder->Embed(
"Let's make a plan to steal the declaration of independence."));
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -81,6 +81,6 @@ cc_test(
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
"@org_tensorflow//tensorflow/lite:test_util",
],
)

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe::tasks::text::utils {
@ -76,7 +76,7 @@ absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile(
} // namespace
class TextModelUtilsTest : public tflite_shims::testing::Test {};
class TextModelUtilsTest : public tflite::testing::Test {};
TEST_F(TextModelUtilsTest, BertClassifierModelTest) {
MP_ASSERT_OK_AND_ASSIGN(auto model_type,

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -105,7 +105,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
}
class FaceBlendshapesTest : public tflite_shims::testing::Test {};
class FaceBlendshapesTest : public tflite::testing::Test {};
TEST_F(FaceBlendshapesTest, SmokeTest) {
// Prepare graph inputs.

View File

@ -67,7 +67,7 @@ constexpr char kPortraitExpectedFaceLandmarksName[] =
"portrait_expected_face_landmarks.pbtxt";
constexpr char kPortraitExpectedBlendshapesName[] =
"portrait_expected_blendshapes.pbtxt";
constexpr char kPortaitExpectedFaceGeomertyName[] =
constexpr char kPortraitExpectedFaceGeometryName[] =
"portrait_expected_face_geometry.pbtxt";
constexpr float kLandmarksDiffMargin = 0.03;
@ -100,7 +100,7 @@ struct FaceLandmarkerTestParams {
mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() {
auto face_geometry = GetExpectedProto<face_geometry::proto::FaceGeometry>(
kPortaitExpectedFaceGeomertyName);
kPortraitExpectedFaceGeometryName);
return face_geometry.pose_transform_matrix();
}

View File

@ -23,18 +23,12 @@ cc_library(
srcs = ["face_stylizer_graph.cc"],
deps = [
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
"//mediapipe/calculators/image:warp_affine_calculator",
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:face_to_rect_calculator",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:inverse_matrix_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
@ -53,7 +47,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",

View File

@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The input image can be of any size with format RGB or RGBA.
// When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the stylized
// output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
// face. The stylized output image size is the same as the model output size.
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing.
// When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the stylized
// output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
// face. The stylized output image size is the same as the model output size.
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The "result_callback" provides:
// - When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the
// stylized output image size is the smaller of the model output size and
// the size of the 'region_of_interest' specified in
// 'image_processing_options'.
// face. The stylized output image size is the same as the model output
// size.
// - The input timestamp in milliseconds.
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions>

View File

@ -19,8 +19,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
image_in >> preprocessing.In(kImageTag);
face_rect >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto transform_matrix = preprocessing.Out(kMatrixTag);
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.
@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
model_output_tensors >> tensors_to_image.In(kTensorsTag);
auto tensor_image = tensors_to_image.Out(kImageTag);
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
transform_matrix >> inverse_matrix.In(kMatrixTag);
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
auto& image_converter = graph.AddNode("ImageCloneCalculator");
image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
.set_output_on_gpu(false);
tensor_image >> image_converter.In("");
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
auto& warp_affine_options =
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
warp_affine_options.set_border_mode(
WarpAffineCalculatorOptions::BORDER_ZERO);
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
tensor_image >> warp_affine.In(kImageTag);
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
image_size >> warp_affine.In(kOutputSizeTag);
auto image_to_crop = warp_affine.Out(kImageTag);
// The following calculators are for cropping and resizing the output image
// based on the roi and the model output size. As the WarpAffineCalculator
// rotates the image based on the transform matrix, the rotation info in the
// rect proto is stripped to prevent the ImageCroppingCalculator from
// performing extra rotation.
auto& strip_rotation =
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
face_rect >> strip_rotation.In(kNormRectTag);
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
auto& from_image = graph.AddNode("FromImageCalculator");
image_to_crop >> from_image.In(kImageTag);
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
auto& image_cropping_opts =
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
image_cropping_opts.set_output_max_width(
image_to_tensor_options.output_tensor_width());
image_cropping_opts.set_output_max_height(
image_to_tensor_options.output_tensor_height());
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
auto& to_image = graph.AddNode("ToImageCalculator");
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
// graph selects its cpu or gpu path based on the image preprocessing
// backend.
if (use_gpu) {
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
} else {
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
}
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
}
};

View File

@ -43,7 +43,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -137,7 +137,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>());
}
class HandLandmarkerTest : public tflite_shims::testing::Test {};
class HandLandmarkerTest : public tflite::testing::Test {};
TEST_F(HandLandmarkerTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(

View File

@ -41,7 +41,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {

View File

@ -59,7 +59,6 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::AllowIf;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarksDetectorGraphOptions;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;

View File

@ -146,7 +146,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner(
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
}
// Helper function to create a Multi Hand Landmark TaskRunner.
@ -188,7 +188,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner(
return TaskRunner::Create(
graph.GetConfig(),
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
}
NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) {

View File

@ -39,9 +39,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -148,7 +148,7 @@ class MobileNetQuantizedOpResolverMissingOps
const MobileNetQuantizedOpResolverMissingOps& r) = delete;
};
class CreateTest : public tflite_shims::testing::Test {};
class CreateTest : public tflite::testing::Test {};
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
auto options = std::make_unique<ImageClassifierOptions>();
@ -265,7 +265,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
class ImageModeTest : public tflite_shims::testing::Test {};
class ImageModeTest : public tflite::testing::Test {};
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
@ -605,7 +605,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
class VideoModeTest : public tflite_shims::testing::Test {};
class VideoModeTest : public tflite::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
@ -707,7 +707,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK(image_classifier->Close());
}
class LiveStreamModeTest : public tflite_shims::testing::Test {};
class LiveStreamModeTest : public tflite::testing::Test {};
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(

View File

@ -30,9 +30,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -103,7 +103,7 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver {
delete;
};
class CreateTest : public tflite_shims::testing::Test {};
class CreateTest : public tflite::testing::Test {};
TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) {
auto options = std::make_unique<ImageEmbedderOptions>();
@ -181,7 +181,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) {
MediaPipeTasksStatus::kInvalidTaskGraphConfigError))));
}
class ImageModeTest : public tflite_shims::testing::Test {};
class ImageModeTest : public tflite::testing::Test {};
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
@ -410,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
class VideoModeTest : public tflite_shims::testing::Test {};
class VideoModeTest : public tflite::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
@ -494,7 +494,7 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK(image_embedder->Close());
}
class LiveStreamModeTest : public tflite_shims::testing::Test {};
class LiveStreamModeTest : public tflite::testing::Test {};
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(

View File

@ -39,9 +39,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
@ -180,7 +180,7 @@ class DeepLabOpResolver : public ::tflite::MutableOpResolver {
DeepLabOpResolver(const DeepLabOpResolver& r) = delete;
};
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
class CreateFromOptionsTest : public tflite::testing::Test {};
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
public:
@ -268,7 +268,7 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) {
}
}
class ImageModeTest : public tflite_shims::testing::Test {};
class ImageModeTest : public tflite::testing::Test {};
TEST_F(ImageModeTest, SucceedsWithCategoryMask) {
MP_ASSERT_OK_AND_ASSIGN(
@ -521,7 +521,7 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
class VideoModeTest : public tflite_shims::testing::Test {};
class VideoModeTest : public tflite::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(
@ -581,7 +581,7 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK(segmenter->Close());
}
class LiveStreamModeTest : public tflite_shims::testing::Test {};
class LiveStreamModeTest : public tflite::testing::Test {};
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(

View File

@ -64,7 +64,6 @@ using ::mediapipe::CalculatorGraphConfig;
using ::mediapipe::Image;
using ::mediapipe::NormalizedRect;
using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult;
using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions;
using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision::
image_segmenter::proto::ImageSegmenterGraphOptions;

View File

@ -39,9 +39,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/test_util.h"
#include "testing/base/public/gmock.h"
namespace mediapipe {
@ -124,7 +124,7 @@ MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold,
similarity_threshold;
}
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
class CreateFromOptionsTest : public tflite::testing::Test {};
class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver {
public:
@ -261,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
info) { return info.param.test_name; });
class ImageModeTest : public tflite_shims::testing::Test {};
class ImageModeTest : public tflite::testing::Test {};
// TODO: fix this unit test after image segmenter handled post
// processing correctly with rotated image.

View File

@ -43,9 +43,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/test_util.h"
namespace tflite {
namespace ops {
@ -159,7 +159,7 @@ class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver {
MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete;
};
class CreateFromOptionsTest : public tflite_shims::testing::Test {};
class CreateFromOptionsTest : public tflite::testing::Test {};
TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) {
auto options = std::make_unique<ObjectDetectorOptions>();
@ -332,7 +332,7 @@ TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) {
// TODO: Add NumThreadsTest back after having an
// "acceleration configuration" field in the ObjectDetectorOptions.
class ImageModeTest : public tflite_shims::testing::Test {};
class ImageModeTest : public tflite::testing::Test {};
TEST_F(ImageModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
@ -618,7 +618,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
class VideoModeTest : public tflite_shims::testing::Test {};
class VideoModeTest : public tflite::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
@ -673,7 +673,7 @@ TEST_F(VideoModeTest, Succeeds) {
MP_ASSERT_OK(object_detector->Close());
}
class LiveStreamModeTest : public tflite_shims::testing::Test {};
class LiveStreamModeTest : public tflite::testing::Test {};
TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(

View File

@ -97,8 +97,10 @@ cc_library(
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(

View File

@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
// limit the number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<PoseLandmarkerGraphOptionsProto> options,
bool enable_flow_limiting) {
bool enable_flow_limiting, bool output_segmentation_masks) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName);
subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
graph.Out(kSegmentationMaskTag);
subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >>
graph.Out(kNormLandmarksTag);
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig(
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
graph.Out(kPoseAuxiliaryLandmarksTag);
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
if (output_segmentation_masks) {
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
graph.Out(kSegmentationMaskTag);
}
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(
graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag);
@ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
PoseLandmarkerGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
options->running_mode == core::RunningMode::LIVE_STREAM,
options->output_segmentation_masks),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback))));

View File

@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists;
Source<std::vector<NormalizedRect>> pose_rects_next_frame;
Source<std::vector<Detection>> pose_detections;
Source<std::vector<Image>> segmentation_masks;
std::optional<Source<std::vector<Image>>> segmentation_masks;
Source<Image> image;
};
@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// input_stream: "IMAGE:image_in"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "NORM_LANDMARKS:pose_landmarks"
// output_stream: "LANDMARKS:world_landmarks"
// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
// output_stream: "WORLD_LANDMARKS:world_landmarks"
// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
// output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
// output_stream: "POSE_RECTS:pose_rects"
// output_stream: "SEGMENTATION_MASK:segmentation_masks"
@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
bool output_segmentation_masks =
HasOutput(sc->OriginalNode(), kSegmentationMaskTag);
if (sc->Options<PoseLandmarkerGraphOptions>()
.base_options()
.has_model_asset()) {
@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(
auto outs,
BuildPoseLandmarkerGraph(
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
ASSIGN_OR_RETURN(auto outs,
BuildPoseLandmarkerGraph(
*sc->MutableOptions<PoseLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)],
graph, output_segmentation_masks));
outs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)];
outs.world_landmark_lists >>
@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
kAuxiliaryLandmarksTag)];
outs.pose_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)];
outs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
outs.pose_detections >>
graph[Output<std::vector<Detection>>(kDetectionsTag)];
outs.image >> graph[Output<Image>(kImageTag)];
if (outs.segmentation_masks) {
*outs.segmentation_masks >>
graph[Output<std::vector<Image>>(kSegmentationMaskTag)];
}
// TODO remove when support is fixed.
// As mediapipe GraphBuilder currently doesn't support configuring
@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph(
PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph,
bool output_segmentation_masks) {
const int max_num_poses =
tasks_options.pose_detector_graph_options().num_poses();
@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph {
auto pose_rects_for_next_frame =
pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag)
.Cast<std::vector<NormalizedRect>>();
auto segmentation_masks =
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
.Cast<std::vector<Image>>();
std::optional<Source<std::vector<Image>>> segmentation_masks;
if (output_segmentation_masks) {
segmentation_masks =
pose_landmarks_detector_graph.Out(kSegmentationMaskTag)
.Cast<std::vector<Image>>();
}
if (tasks_options.base_options().use_stream_mode()) {
auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator");

Some files were not shown because too many files have changed in this diff Show More