Internal change
PiperOrigin-RevId: 525660743
This commit is contained in:
parent
44aa607e06
commit
331692577e
|
@ -61,12 +61,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";
|
||||||
// overload GPU/TPU/...
|
// overload GPU/TPU/...
|
||||||
class SimpleSemaphore {
|
class SimpleSemaphore {
|
||||||
public:
|
public:
|
||||||
explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {}
|
explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {}
|
||||||
SimpleSemaphore(const SimpleSemaphore&) = delete;
|
SimpleSemaphore(const SimpleSemaphore&) = delete;
|
||||||
SimpleSemaphore(SimpleSemaphore&&) = delete;
|
SimpleSemaphore(SimpleSemaphore&&) = delete;
|
||||||
|
|
||||||
// Acquires the semaphore by certain amount.
|
// Acquires the semaphore by certain amount.
|
||||||
void Acquire(uint32 amount) {
|
void Acquire(uint32_t amount) {
|
||||||
mutex_.Lock();
|
mutex_.Lock();
|
||||||
while (count_ < amount) {
|
while (count_ < amount) {
|
||||||
cond_.Wait(&mutex_);
|
cond_.Wait(&mutex_);
|
||||||
|
@ -76,7 +76,7 @@ class SimpleSemaphore {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Releases the semaphore by certain amount.
|
// Releases the semaphore by certain amount.
|
||||||
void Release(uint32 amount) {
|
void Release(uint32_t amount) {
|
||||||
mutex_.Lock();
|
mutex_.Lock();
|
||||||
count_ += amount;
|
count_ += amount;
|
||||||
cond_.SignalAll();
|
cond_.SignalAll();
|
||||||
|
@ -84,7 +84,7 @@ class SimpleSemaphore {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint32 count_;
|
uint32_t count_;
|
||||||
absl::Mutex mutex_;
|
absl::Mutex mutex_;
|
||||||
absl::CondVar cond_;
|
absl::CondVar cond_;
|
||||||
};
|
};
|
||||||
|
@ -488,7 +488,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
// necessary.
|
// necessary.
|
||||||
absl::Status OutputBatch(CalculatorContext* cc,
|
absl::Status OutputBatch(CalculatorContext* cc,
|
||||||
std::unique_ptr<InferenceState> inference_state) {
|
std::unique_ptr<InferenceState> inference_state) {
|
||||||
const int64 start_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;
|
||||||
|
|
||||||
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
|
||||||
|
@ -544,7 +544,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
get_session_run_throttle(options_.max_concurrent_session_runs());
|
get_session_run_throttle(options_.max_concurrent_session_runs());
|
||||||
session_run_throttle->Acquire(1);
|
session_run_throttle->Acquire(1);
|
||||||
}
|
}
|
||||||
const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
tf::Status tf_status;
|
tf::Status tf_status;
|
||||||
{
|
{
|
||||||
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
|
||||||
|
@ -562,7 +562,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
// informative error message.
|
// informative error message.
|
||||||
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();
|
||||||
|
|
||||||
const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
|
||||||
->IncrementBy(run_end_time - run_start_time);
|
->IncrementBy(run_end_time - run_start_time);
|
||||||
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();
|
||||||
|
@ -611,7 +611,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get end time and report.
|
// Get end time and report.
|
||||||
const int64 end_time = absl::ToUnixMicros(clock_->TimeNow());
|
const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow());
|
||||||
cc->GetCounter(kTotalUsecsCounterSuffix)
|
cc->GetCounter(kTotalUsecsCounterSuffix)
|
||||||
->IncrementBy(end_time - start_time);
|
->IncrementBy(end_time - start_time);
|
||||||
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
|
||||||
|
@ -650,7 +650,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
||||||
|
|
||||||
// The static singleton semaphore to throttle concurrent session runs.
|
// The static singleton semaphore to throttle concurrent session runs.
|
||||||
static SimpleSemaphore* get_session_run_throttle(
|
static SimpleSemaphore* get_session_run_throttle(
|
||||||
int32 max_concurrent_session_runs) {
|
int32_t max_concurrent_session_runs) {
|
||||||
static SimpleSemaphore* session_run_throttle =
|
static SimpleSemaphore* session_run_throttle =
|
||||||
new SimpleSemaphore(max_concurrent_session_runs);
|
new SimpleSemaphore(max_concurrent_session_runs);
|
||||||
return session_run_throttle;
|
return session_run_throttle;
|
||||||
|
|
|
@ -197,15 +197,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// timestamp and the associated feature. This information is used in process
|
// timestamp and the associated feature. This information is used in process
|
||||||
// to output batches of packets in order.
|
// to output batches of packets in order.
|
||||||
timestamps_.clear();
|
timestamps_.clear();
|
||||||
int64 last_timestamp_seen = Timestamp::PreStream().Value();
|
int64_t last_timestamp_seen = Timestamp::PreStream().Value();
|
||||||
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
|
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
|
||||||
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
|
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
|
||||||
if (absl::StrContains(map_kv.first, "/timestamp")) {
|
if (absl::StrContains(map_kv.first, "/timestamp")) {
|
||||||
LOG(INFO) << "Found feature timestamps: " << map_kv.first
|
LOG(INFO) << "Found feature timestamps: " << map_kv.first
|
||||||
<< " with size: " << map_kv.second.feature_size();
|
<< " with size: " << map_kv.second.feature_size();
|
||||||
int64 recent_timestamp = Timestamp::PreStream().Value();
|
int64_t recent_timestamp = Timestamp::PreStream().Value();
|
||||||
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
|
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
|
||||||
int64 next_timestamp =
|
int64_t next_timestamp =
|
||||||
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
|
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
|
||||||
RET_CHECK_GT(next_timestamp, recent_timestamp)
|
RET_CHECK_GT(next_timestamp, recent_timestamp)
|
||||||
<< "Timestamps must be sequential. If you're seeing this message "
|
<< "Timestamps must be sequential. If you're seeing this message "
|
||||||
|
@ -361,8 +361,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// any particular call to Process(). At the every end, we output the
|
// any particular call to Process(). At the every end, we output the
|
||||||
// poststream packets. If we only have poststream packets,
|
// poststream packets. If we only have poststream packets,
|
||||||
// last_timestamp_key_ will be empty.
|
// last_timestamp_key_ will be empty.
|
||||||
int64 start_timestamp = 0;
|
int64_t start_timestamp = 0;
|
||||||
int64 end_timestamp = 0;
|
int64_t end_timestamp = 0;
|
||||||
if (last_timestamp_key_.empty() || process_poststream_) {
|
if (last_timestamp_key_.empty() || process_poststream_) {
|
||||||
process_poststream_ = true;
|
process_poststream_ = true;
|
||||||
start_timestamp = Timestamp::PostStream().Value();
|
start_timestamp = Timestamp::PostStream().Value();
|
||||||
|
@ -481,14 +481,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
|
||||||
// Store a map from the keys for each stream to the timestamps for each
|
// Store a map from the keys for each stream to the timestamps for each
|
||||||
// key. This allows us to identify which packets to output for each stream
|
// key. This allows us to identify which packets to output for each stream
|
||||||
// for timestamps within a given time window.
|
// for timestamps within a given time window.
|
||||||
std::map<std::string, std::vector<int64>> timestamps_;
|
std::map<std::string, std::vector<int64_t>> timestamps_;
|
||||||
// Store the stream with the latest timestamp in the SequenceExample.
|
// Store the stream with the latest timestamp in the SequenceExample.
|
||||||
std::string last_timestamp_key_;
|
std::string last_timestamp_key_;
|
||||||
// Store the index of the current timestamp. Will be less than
|
// Store the index of the current timestamp. Will be less than
|
||||||
// timestamps_[last_timestamp_key_].size().
|
// timestamps_[last_timestamp_key_].size().
|
||||||
int current_timestamp_index_;
|
int current_timestamp_index_;
|
||||||
// Store the very first timestamp, so we output everything on the first frame.
|
// Store the very first timestamp, so we output everything on the first frame.
|
||||||
int64 first_timestamp_seen_;
|
int64_t first_timestamp_seen_;
|
||||||
// List of keypoint names.
|
// List of keypoint names.
|
||||||
std::vector<std::string> keypoint_names_;
|
std::vector<std::string> keypoint_names_;
|
||||||
// Default keypoint location when missing.
|
// Default keypoint location when missing.
|
||||||
|
|
|
@ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64 time = 1234;
|
const int64_t time = 1234;
|
||||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
Adopt(input.release()).At(Timestamp(time)));
|
Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) {
|
||||||
// 2^i can be represented exactly in floating point numbers if 'i' is small.
|
// 2^i can be represented exactly in floating point numbers if 'i' is small.
|
||||||
input->at(i) = static_cast<float>(1 << i);
|
input->at(i) = static_cast<float>(1 << i);
|
||||||
}
|
}
|
||||||
const int64 time = 1234;
|
const int64_t time = 1234;
|
||||||
runner_->MutableInputs()->Index(0).packets.push_back(
|
runner_->MutableInputs()->Index(0).packets.push_back(
|
||||||
Adopt(input.release()).At(Timestamp(time)));
|
Adopt(input.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user