Merge remote-tracking branch 'origin/master' into nguyencse/facemeshioslib
This commit is contained in:
commit
5093b7a2a9
16
WORKSPACE
16
WORKSPACE
|
@ -45,12 +45,13 @@ http_archive(
|
||||||
)
|
)
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "rules_foreign_cc",
|
name = "rules_foreign_cc",
|
||||||
strip_prefix = "rules_foreign_cc-0.1.0",
|
sha256 = "2a4d07cd64b0719b39a7c12218a3e507672b82a97b98c6a89d38565894cf7c51",
|
||||||
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip",
|
strip_prefix = "rules_foreign_cc-0.9.0",
|
||||||
|
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.9.0.tar.gz",
|
||||||
)
|
)
|
||||||
|
|
||||||
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
|
load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies")
|
||||||
|
|
||||||
rules_foreign_cc_dependencies()
|
rules_foreign_cc_dependencies()
|
||||||
|
|
||||||
|
@ -492,9 +493,10 @@ http_archive(
|
||||||
)
|
)
|
||||||
|
|
||||||
# TensorFlow repo should always go after the other external dependencies.
|
# TensorFlow repo should always go after the other external dependencies.
|
||||||
# TF on 2023-05-26.
|
# TF on 2023-06-13.
|
||||||
_TENSORFLOW_GIT_COMMIT = "67d5c561981edc45daf3f9d73ddd1a77963733ca"
|
_TENSORFLOW_GIT_COMMIT = "491681a5620e41bf079a582ac39c585cc86878b9"
|
||||||
_TENSORFLOW_SHA256 = "0c8326285e9cb695313e194b97d388eea70bf8bf5b13e8f0962ca8eed5179ece"
|
# curl -L https://github.com/tensorflow/tensorflow/archive/<TENSORFLOW_GIT_COMMIT>.tar.gz | shasum -a 256
|
||||||
|
_TENSORFLOW_SHA256 = "9f76389af7a2835e68413322c1eaabfadc912f02a76d71dc16be507f9ca3d3ac"
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "org_tensorflow",
|
name = "org_tensorflow",
|
||||||
urls = [
|
urls = [
|
||||||
|
|
|
@ -219,12 +219,10 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":time_series_framer_calculator_cc_proto",
|
":time_series_framer_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:timestamp",
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||||
"//mediapipe/framework/port:integral_types",
|
|
||||||
"//mediapipe/framework/port:logging",
|
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
|
||||||
"//mediapipe/util:time_series_util",
|
"//mediapipe/util:time_series_util",
|
||||||
"@com_google_audio_tools//audio/dsp:window_functions",
|
"@com_google_audio_tools//audio/dsp:window_functions",
|
||||||
"@eigen_archive//:eigen3",
|
"@eigen_archive//:eigen3",
|
||||||
|
@ -319,6 +317,20 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "time_series_framer_calculator_benchmark",
|
||||||
|
srcs = ["time_series_framer_calculator_benchmark.cc"],
|
||||||
|
deps = [
|
||||||
|
":time_series_framer_calculator",
|
||||||
|
":time_series_framer_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/formats:matrix",
|
||||||
|
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||||
|
"@com_google_benchmark//:benchmark",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "time_series_framer_calculator_test",
|
name = "time_series_framer_calculator_test",
|
||||||
srcs = ["time_series_framer_calculator_test.cc"],
|
srcs = ["time_series_framer_calculator_test.cc"],
|
||||||
|
|
|
@ -15,9 +15,7 @@
|
||||||
// Defines TimeSeriesFramerCalculator.
|
// Defines TimeSeriesFramerCalculator.
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
#include <deque>
|
#include <vector>
|
||||||
#include <memory>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "Eigen/Core"
|
#include "Eigen/Core"
|
||||||
#include "audio/dsp/window_functions.h"
|
#include "audio/dsp/window_functions.h"
|
||||||
|
@ -25,9 +23,8 @@
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||||
#include "mediapipe/framework/port/integral_types.h"
|
|
||||||
#include "mediapipe/framework/port/logging.h"
|
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/util/time_series_util.h"
|
#include "mediapipe/util/time_series_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -88,11 +85,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
|
||||||
absl::Status Close(CalculatorContext* cc) override;
|
absl::Status Close(CalculatorContext* cc) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Adds input data to the internal buffer.
|
|
||||||
void EnqueueInput(CalculatorContext* cc);
|
|
||||||
// Constructs and emits framed output packets.
|
|
||||||
void FrameOutput(CalculatorContext* cc);
|
|
||||||
|
|
||||||
Timestamp CurrentOutputTimestamp() {
|
Timestamp CurrentOutputTimestamp() {
|
||||||
if (use_local_timestamp_) {
|
if (use_local_timestamp_) {
|
||||||
return current_timestamp_;
|
return current_timestamp_;
|
||||||
|
@ -106,14 +98,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
|
||||||
Timestamp::kTimestampUnitsPerSecond);
|
Timestamp::kTimestampUnitsPerSecond);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the timestamp of a sample on a base, which is usually the time
|
|
||||||
// stamp of a packet.
|
|
||||||
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
|
|
||||||
int64_t number_of_samples) {
|
|
||||||
return timestamp_base + round(number_of_samples / sample_rate_ *
|
|
||||||
Timestamp::kTimestampUnitsPerSecond);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number of input samples to advance after the current output frame is
|
// The number of input samples to advance after the current output frame is
|
||||||
// emitted.
|
// emitted.
|
||||||
int next_frame_step_samples() const {
|
int next_frame_step_samples() const {
|
||||||
|
@ -142,61 +126,174 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
|
||||||
Timestamp initial_input_timestamp_;
|
Timestamp initial_input_timestamp_;
|
||||||
// The current timestamp is updated along with the incoming packets.
|
// The current timestamp is updated along with the incoming packets.
|
||||||
Timestamp current_timestamp_;
|
Timestamp current_timestamp_;
|
||||||
int num_channels_;
|
|
||||||
|
|
||||||
// Each entry in this deque consists of a single sample, i.e. a
|
// Samples are buffered in a vector of sample blocks.
|
||||||
// single column vector, and its timestamp.
|
class SampleBlockBuffer {
|
||||||
std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
|
public:
|
||||||
|
// Initializes the buffer.
|
||||||
|
void Init(double sample_rate, int num_channels) {
|
||||||
|
ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate;
|
||||||
|
num_channels_ = num_channels;
|
||||||
|
num_samples_ = 0;
|
||||||
|
first_block_offset_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of channels, equal to the number of rows in each Matrix.
|
||||||
|
int num_channels() const { return num_channels_; }
|
||||||
|
// Total number of available samples over all blocks.
|
||||||
|
int num_samples() const { return num_samples_; }
|
||||||
|
|
||||||
|
// Pushes a new block of samples on the back of the buffer with `timestamp`
|
||||||
|
// being the input timestamp of the packet containing the Matrix.
|
||||||
|
void Push(const Matrix& samples, Timestamp timestamp);
|
||||||
|
// Copies `count` samples from the front of the buffer. If there are fewer
|
||||||
|
// samples than this, the result is zero padded to have `count` samples.
|
||||||
|
// The timestamp of the last copied sample is written to *last_timestamp.
|
||||||
|
// This output is used below to update `current_timestamp_`, which is only
|
||||||
|
// used when `use_local_timestamp` is true.
|
||||||
|
Matrix CopySamples(int count, Timestamp* last_timestamp) const;
|
||||||
|
// Drops `count` samples from the front of the buffer. If `count` exceeds
|
||||||
|
// `num_samples()`, the buffer is emptied. Returns how many samples were
|
||||||
|
// dropped.
|
||||||
|
int DropSamples(int count);
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct Block {
|
||||||
|
// Matrix of num_channels rows by num_samples columns, a block of possibly
|
||||||
|
// multiple samples.
|
||||||
|
Matrix samples;
|
||||||
|
// Timestamp of the first sample in the Block. This comes from the input
|
||||||
|
// packet's timestamp that contains this Matrix.
|
||||||
|
Timestamp timestamp;
|
||||||
|
|
||||||
|
Block() : timestamp(Timestamp::Unstarted()) {}
|
||||||
|
Block(const Matrix& samples, Timestamp timestamp)
|
||||||
|
: samples(samples), timestamp(timestamp) {}
|
||||||
|
int num_samples() const { return samples.cols(); }
|
||||||
|
};
|
||||||
|
std::vector<Block> blocks_;
|
||||||
|
// Number of timestamp units per sample. Used to compute timestamps as
|
||||||
|
// nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n).
|
||||||
|
double ts_units_per_sample_;
|
||||||
|
// Number of rows in each Matrix.
|
||||||
|
int num_channels_;
|
||||||
|
// The total number of samples over all blocks, equal to
|
||||||
|
// (sum_i blocks_[i].num_samples()) - first_block_offset_.
|
||||||
|
int num_samples_;
|
||||||
|
// The number of samples in the first block that have been discarded. This
|
||||||
|
// way we can cheaply represent "partially discarding" a block.
|
||||||
|
int first_block_offset_;
|
||||||
|
} sample_buffer_;
|
||||||
|
|
||||||
bool use_window_;
|
bool use_window_;
|
||||||
Matrix window_;
|
Eigen::RowVectorXf window_;
|
||||||
|
|
||||||
bool use_local_timestamp_;
|
bool use_local_timestamp_;
|
||||||
};
|
};
|
||||||
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
|
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
|
||||||
|
|
||||||
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
|
void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples,
|
||||||
const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
|
Timestamp timestamp) {
|
||||||
|
num_samples_ += samples.cols();
|
||||||
for (int i = 0; i < input_frame.cols(); ++i) {
|
blocks_.emplace_back(samples, timestamp);
|
||||||
sample_buffer_.emplace_back(std::make_pair(
|
|
||||||
input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
|
Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples(
|
||||||
while (sample_buffer_.size() >=
|
int count, Timestamp* last_timestamp) const {
|
||||||
|
Matrix copied(num_channels_, count);
|
||||||
|
|
||||||
|
if (!blocks_.empty()) {
|
||||||
|
int num_copied = 0;
|
||||||
|
// First block has an offset for samples that have been discarded.
|
||||||
|
int offset = first_block_offset_;
|
||||||
|
int n;
|
||||||
|
Timestamp last_block_ts;
|
||||||
|
int last_sample_index;
|
||||||
|
|
||||||
|
for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) {
|
||||||
|
n = std::min(it->num_samples() - offset, count);
|
||||||
|
// Copy `n` samples from the next block.
|
||||||
|
copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n);
|
||||||
|
count -= n;
|
||||||
|
num_copied += n;
|
||||||
|
last_block_ts = it->timestamp;
|
||||||
|
last_sample_index = offset + n - 1;
|
||||||
|
offset = 0; // No samples have been discarded in subsequent blocks.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the timestamp of the last copied sample.
|
||||||
|
*last_timestamp =
|
||||||
|
last_block_ts + std::round(ts_units_per_sample_ * last_sample_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count > 0) {
|
||||||
|
copied.rightCols(count).setZero(); // Zero pad if needed.
|
||||||
|
}
|
||||||
|
|
||||||
|
return copied;
|
||||||
|
}
|
||||||
|
|
||||||
|
int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) {
|
||||||
|
if (blocks_.empty()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto block_it = blocks_.begin();
|
||||||
|
if (first_block_offset_ + count < block_it->num_samples()) {
|
||||||
|
// `count` is less than the remaining samples in the first block.
|
||||||
|
first_block_offset_ += count;
|
||||||
|
num_samples_ -= count;
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_samples_dropped = block_it->num_samples() - first_block_offset_;
|
||||||
|
count -= num_samples_dropped;
|
||||||
|
first_block_offset_ = 0;
|
||||||
|
|
||||||
|
for (++block_it; block_it != blocks_.end(); ++block_it) {
|
||||||
|
if (block_it->num_samples() > count) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
num_samples_dropped += block_it->num_samples();
|
||||||
|
count -= block_it->num_samples();
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks_.erase(blocks_.begin(), block_it); // Drop whole blocks.
|
||||||
|
if (!blocks_.empty()) {
|
||||||
|
first_block_offset_ = count; // Drop part of the next block.
|
||||||
|
num_samples_dropped += count;
|
||||||
|
}
|
||||||
|
|
||||||
|
num_samples_ -= num_samples_dropped;
|
||||||
|
return num_samples_dropped;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
|
||||||
|
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
|
||||||
|
initial_input_timestamp_ = cc->InputTimestamp();
|
||||||
|
current_timestamp_ = initial_input_timestamp_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add input data to the internal buffer.
|
||||||
|
sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(),
|
||||||
|
cc->InputTimestamp());
|
||||||
|
|
||||||
|
// Construct and emit framed output packets.
|
||||||
|
while (sample_buffer_.num_samples() >=
|
||||||
frame_duration_samples_ + samples_still_to_drop_) {
|
frame_duration_samples_ + samples_still_to_drop_) {
|
||||||
while (samples_still_to_drop_ > 0) {
|
sample_buffer_.DropSamples(samples_still_to_drop_);
|
||||||
sample_buffer_.pop_front();
|
Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
|
||||||
--samples_still_to_drop_;
|
¤t_timestamp_);
|
||||||
}
|
|
||||||
const int frame_step_samples = next_frame_step_samples();
|
const int frame_step_samples = next_frame_step_samples();
|
||||||
std::unique_ptr<Matrix> output_frame(
|
samples_still_to_drop_ = frame_step_samples;
|
||||||
new Matrix(num_channels_, frame_duration_samples_));
|
|
||||||
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
|
|
||||||
++i) {
|
|
||||||
output_frame->col(i) = sample_buffer_.front().first;
|
|
||||||
current_timestamp_ = sample_buffer_.front().second;
|
|
||||||
sample_buffer_.pop_front();
|
|
||||||
}
|
|
||||||
const int frame_overlap_samples =
|
|
||||||
frame_duration_samples_ - frame_step_samples;
|
|
||||||
if (frame_overlap_samples > 0) {
|
|
||||||
for (int i = 0; i < frame_overlap_samples; ++i) {
|
|
||||||
output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
|
|
||||||
current_timestamp_ = sample_buffer_[i].second;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
samples_still_to_drop_ = -frame_overlap_samples;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_window_) {
|
if (use_window_) {
|
||||||
*output_frame = (output_frame->array() * window_.array()).matrix();
|
// Apply the window to each row of output_frame.
|
||||||
|
output_frame.array().rowwise() *= window_.array();
|
||||||
}
|
}
|
||||||
|
|
||||||
cc->Outputs().Index(0).Add(output_frame.release(),
|
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
|
||||||
CurrentOutputTimestamp());
|
.At(CurrentOutputTimestamp()));
|
||||||
++cumulative_output_frames_;
|
++cumulative_output_frames_;
|
||||||
cumulative_completed_samples_ += frame_step_samples;
|
cumulative_completed_samples_ += frame_step_samples;
|
||||||
}
|
}
|
||||||
|
@ -206,35 +303,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
|
||||||
// fact to enable packet queueing optimizations.
|
// fact to enable packet queueing optimizations.
|
||||||
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
|
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
|
|
||||||
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
|
|
||||||
initial_input_timestamp_ = cc->InputTimestamp();
|
|
||||||
current_timestamp_ = initial_input_timestamp_;
|
|
||||||
}
|
|
||||||
|
|
||||||
EnqueueInput(cc);
|
|
||||||
FrameOutput(cc);
|
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
|
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
|
||||||
while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
|
sample_buffer_.DropSamples(samples_still_to_drop_);
|
||||||
sample_buffer_.pop_front();
|
|
||||||
--samples_still_to_drop_;
|
|
||||||
}
|
|
||||||
if (!sample_buffer_.empty() && pad_final_packet_) {
|
|
||||||
std::unique_ptr<Matrix> output_frame(new Matrix);
|
|
||||||
output_frame->setZero(num_channels_, frame_duration_samples_);
|
|
||||||
for (int i = 0; i < sample_buffer_.size(); ++i) {
|
|
||||||
output_frame->col(i) = sample_buffer_[i].first;
|
|
||||||
current_timestamp_ = sample_buffer_[i].second;
|
|
||||||
}
|
|
||||||
|
|
||||||
cc->Outputs().Index(0).Add(output_frame.release(),
|
if (sample_buffer_.num_samples() > 0 && pad_final_packet_) {
|
||||||
CurrentOutputTimestamp());
|
Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
|
||||||
|
¤t_timestamp_);
|
||||||
|
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
|
||||||
|
.At(CurrentOutputTimestamp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
@ -258,7 +338,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
|
||||||
cc->Inputs().Index(0).Header(), &input_header));
|
cc->Inputs().Index(0).Header(), &input_header));
|
||||||
|
|
||||||
sample_rate_ = input_header.sample_rate();
|
sample_rate_ = input_header.sample_rate();
|
||||||
num_channels_ = input_header.num_channels();
|
sample_buffer_.Init(sample_rate_, input_header.num_channels());
|
||||||
frame_duration_samples_ = time_series_util::SecondsToSamples(
|
frame_duration_samples_ = time_series_util::SecondsToSamples(
|
||||||
framer_options.frame_duration_seconds(), sample_rate_);
|
framer_options.frame_duration_seconds(), sample_rate_);
|
||||||
RET_CHECK_GT(frame_duration_samples_, 0)
|
RET_CHECK_GT(frame_duration_samples_, 0)
|
||||||
|
@ -312,9 +392,8 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (use_window_) {
|
if (use_window_) {
|
||||||
window_ = Matrix::Ones(num_channels_, 1) *
|
window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(),
|
||||||
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
|
frame_duration_samples_)
|
||||||
frame_duration_samples_)
|
|
||||||
.cast<float>();
|
.cast<float>();
|
||||||
}
|
}
|
||||||
use_local_timestamp_ = framer_options.use_local_timestamp();
|
use_local_timestamp_ = framer_options.use_local_timestamp();
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Benchmark for TimeSeriesFramerCalculator.
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "benchmark/benchmark.h"
|
||||||
|
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
|
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
|
||||||
|
using ::mediapipe::Matrix;
|
||||||
|
|
||||||
|
void BM_TimeSeriesFramerCalculator(benchmark::State& state) {
|
||||||
|
constexpr float kSampleRate = 32000.0;
|
||||||
|
constexpr int kNumChannels = 2;
|
||||||
|
constexpr int kFrameDurationSeconds = 5.0;
|
||||||
|
std::mt19937 rng(0 /*seed*/);
|
||||||
|
// Input around a half second's worth of samples at a time.
|
||||||
|
std::uniform_int_distribution<int> input_size_dist(15000, 17000);
|
||||||
|
// Generate a pool of random blocks of samples up front.
|
||||||
|
std::vector<Matrix> sample_pool;
|
||||||
|
sample_pool.reserve(20);
|
||||||
|
for (int i = 0; i < 20; ++i) {
|
||||||
|
sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng)));
|
||||||
|
}
|
||||||
|
std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1);
|
||||||
|
|
||||||
|
mediapipe::CalculatorGraphConfig config;
|
||||||
|
config.add_input_stream("input");
|
||||||
|
config.add_output_stream("output");
|
||||||
|
auto* node = config.add_node();
|
||||||
|
node->set_calculator("TimeSeriesFramerCalculator");
|
||||||
|
node->add_input_stream("input");
|
||||||
|
node->add_output_stream("output");
|
||||||
|
mediapipe::TimeSeriesFramerCalculatorOptions* options =
|
||||||
|
node->mutable_options()->MutableExtension(
|
||||||
|
mediapipe::TimeSeriesFramerCalculatorOptions::ext);
|
||||||
|
options->set_frame_duration_seconds(kFrameDurationSeconds);
|
||||||
|
|
||||||
|
for (auto _ : state) {
|
||||||
|
state.PauseTiming(); // Pause benchmark timing.
|
||||||
|
|
||||||
|
// Prepare input packets of random blocks of samples.
|
||||||
|
std::vector<mediapipe::Packet> input_packets;
|
||||||
|
input_packets.reserve(32);
|
||||||
|
float t = 0;
|
||||||
|
for (int i = 0; i < 32; ++i) {
|
||||||
|
auto samples =
|
||||||
|
std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]);
|
||||||
|
const int num_samples = samples->cols();
|
||||||
|
input_packets.push_back(mediapipe::Adopt(samples.release())
|
||||||
|
.At(mediapipe::Timestamp::FromSeconds(t)));
|
||||||
|
t += num_samples / kSampleRate;
|
||||||
|
}
|
||||||
|
// Initialize graph.
|
||||||
|
mediapipe::CalculatorGraph graph;
|
||||||
|
CHECK_OK(graph.Initialize(config));
|
||||||
|
// Prepare input header.
|
||||||
|
auto header = std::make_unique<mediapipe::TimeSeriesHeader>();
|
||||||
|
header->set_sample_rate(kSampleRate);
|
||||||
|
header->set_num_channels(kNumChannels);
|
||||||
|
|
||||||
|
state.ResumeTiming(); // Resume benchmark timing.
|
||||||
|
|
||||||
|
CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}}));
|
||||||
|
for (auto& packet : input_packets) {
|
||||||
|
CHECK_OK(graph.AddPacketToInputStream("input", packet));
|
||||||
|
}
|
||||||
|
CHECK(!graph.HasError());
|
||||||
|
CHECK_OK(graph.CloseAllInputStreams());
|
||||||
|
CHECK_OK(graph.WaitUntilIdle());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
BENCHMARK(BM_TimeSeriesFramerCalculator);
|
||||||
|
|
||||||
|
BENCHMARK_MAIN();
|
|
@ -117,6 +117,7 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/framework/formats:classification_proto",
|
"//mediapipe/framework/formats:classification_proto",
|
||||||
"//mediapipe/framework/formats:landmark_proto",
|
"//mediapipe/framework/formats:landmark_proto",
|
||||||
|
"//mediapipe/framework/formats:matrix_data_proto",
|
||||||
"//mediapipe/framework/formats:time_series_header_proto",
|
"//mediapipe/framework/formats:time_series_header_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -289,6 +290,7 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:node",
|
"//mediapipe/framework/api2:node",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
|
@ -1167,6 +1169,7 @@ cc_library(
|
||||||
"//mediapipe/framework:collection_item_id",
|
"//mediapipe/framework:collection_item_id",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
|
"//mediapipe/framework/formats:matrix_data_cc_proto",
|
||||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/framework/port/integral_types.h"
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
@ -104,4 +105,7 @@ typedef ConcatenateVectorCalculator<mediapipe::RenderData>
|
||||||
ConcatenateRenderDataVectorCalculator;
|
ConcatenateRenderDataVectorCalculator;
|
||||||
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
|
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
|
||||||
|
|
||||||
|
typedef ConcatenateVectorCalculator<mediapipe::Image>
|
||||||
|
ConcatenateImageVectorCalculator;
|
||||||
|
MEDIAPIPE_REGISTER_NODE(ConcatenateImageVectorCalculator);
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "mediapipe/framework/collection_item_id.h"
|
#include "mediapipe/framework/collection_item_id.h"
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
|
#include "mediapipe/framework/formats/matrix_data.pb.h"
|
||||||
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
#include "mediapipe/framework/formats/time_series_header.pb.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
#include "mediapipe/framework/port/canonical_errors.h"
|
||||||
#include "mediapipe/framework/port/integral_types.h"
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
@ -85,8 +86,12 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
||||||
packet.Set<LandmarkList>();
|
packet.Set<LandmarkList>();
|
||||||
} else if (packet_options.has_double_value()) {
|
} else if (packet_options.has_double_value()) {
|
||||||
packet.Set<double>();
|
packet.Set<double>();
|
||||||
|
} else if (packet_options.has_matrix_data_value()) {
|
||||||
|
packet.Set<MatrixData>();
|
||||||
} else if (packet_options.has_time_series_header_value()) {
|
} else if (packet_options.has_time_series_header_value()) {
|
||||||
packet.Set<TimeSeriesHeader>();
|
packet.Set<TimeSeriesHeader>();
|
||||||
|
} else if (packet_options.has_int64_value()) {
|
||||||
|
packet.Set<int64_t>();
|
||||||
} else {
|
} else {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"None of supported values were specified in options.");
|
"None of supported values were specified in options.");
|
||||||
|
@ -121,9 +126,13 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
||||||
MakePacket<LandmarkList>(packet_options.landmark_list_value()));
|
MakePacket<LandmarkList>(packet_options.landmark_list_value()));
|
||||||
} else if (packet_options.has_double_value()) {
|
} else if (packet_options.has_double_value()) {
|
||||||
packet.Set(MakePacket<double>(packet_options.double_value()));
|
packet.Set(MakePacket<double>(packet_options.double_value()));
|
||||||
|
} else if (packet_options.has_matrix_data_value()) {
|
||||||
|
packet.Set(MakePacket<MatrixData>(packet_options.matrix_data_value()));
|
||||||
} else if (packet_options.has_time_series_header_value()) {
|
} else if (packet_options.has_time_series_header_value()) {
|
||||||
packet.Set(MakePacket<TimeSeriesHeader>(
|
packet.Set(MakePacket<TimeSeriesHeader>(
|
||||||
packet_options.time_series_header_value()));
|
packet_options.time_series_header_value()));
|
||||||
|
} else if (packet_options.has_int64_value()) {
|
||||||
|
packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
|
||||||
} else {
|
} else {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"None of supported values were specified in options.");
|
"None of supported values were specified in options.");
|
||||||
|
|
|
@ -19,6 +19,7 @@ package mediapipe;
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/framework/formats/classification.proto";
|
import "mediapipe/framework/formats/classification.proto";
|
||||||
import "mediapipe/framework/formats/landmark.proto";
|
import "mediapipe/framework/formats/landmark.proto";
|
||||||
|
import "mediapipe/framework/formats/matrix_data.proto";
|
||||||
import "mediapipe/framework/formats/time_series_header.proto";
|
import "mediapipe/framework/formats/time_series_header.proto";
|
||||||
|
|
||||||
message ConstantSidePacketCalculatorOptions {
|
message ConstantSidePacketCalculatorOptions {
|
||||||
|
@ -29,14 +30,16 @@ message ConstantSidePacketCalculatorOptions {
|
||||||
message ConstantSidePacket {
|
message ConstantSidePacket {
|
||||||
oneof value {
|
oneof value {
|
||||||
int32 int_value = 1;
|
int32 int_value = 1;
|
||||||
|
uint64 uint64_value = 5;
|
||||||
|
int64 int64_value = 11;
|
||||||
float float_value = 2;
|
float float_value = 2;
|
||||||
|
double double_value = 9;
|
||||||
bool bool_value = 3;
|
bool bool_value = 3;
|
||||||
string string_value = 4;
|
string string_value = 4;
|
||||||
uint64 uint64_value = 5;
|
|
||||||
ClassificationList classification_list_value = 6;
|
ClassificationList classification_list_value = 6;
|
||||||
LandmarkList landmark_list_value = 7;
|
LandmarkList landmark_list_value = 7;
|
||||||
double double_value = 9;
|
|
||||||
TimeSeriesHeader time_series_header_value = 10;
|
TimeSeriesHeader time_series_header_value = 10;
|
||||||
|
MatrixData matrix_data_value = 12;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
|
||||||
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
|
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
|
||||||
DoTestSingleSidePacket("{ bool_value: true }", true);
|
DoTestSingleSidePacket("{ bool_value: true }", true);
|
||||||
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
|
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
|
||||||
|
DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
|
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
|
||||||
|
|
|
@ -228,7 +228,6 @@ cc_library(
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -280,7 +279,6 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
|
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
|
@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
|
||||||
input_tensors.reserve(kNumInputTensorsForBert);
|
input_tensors.reserve(kNumInputTensorsForBert);
|
||||||
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
|
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
|
||||||
input_tensors.push_back(
|
input_tensors.push_back(
|
||||||
{Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
|
{Tensor::ElementType::kInt32,
|
||||||
|
Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
|
||||||
}
|
}
|
||||||
std::memcpy(input_tensors[input_ids_tensor_index_]
|
std::memcpy(input_tensors[input_ids_tensor_index_]
|
||||||
.GetCpuWriteView()
|
.GetCpuWriteView()
|
||||||
|
|
|
@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
||||||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
|
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
|
||||||
// Read CPU input into tensors.
|
// Read CPU input into tensors.
|
||||||
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
|
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
|
||||||
|
|
||||||
|
// If the input tensors have dynamic shape, then the tensors need to be
|
||||||
|
// resized and reallocated before we can copy the tensor values.
|
||||||
|
bool resized_tensor_shapes = false;
|
||||||
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
|
if (input_tensors[i].shape().is_dynamic) {
|
||||||
|
interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
|
||||||
|
resized_tensor_shapes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Reallocation is needed for memory sanity.
|
||||||
|
if (resized_tensor_shapes) interpreter_->AllocateTensors();
|
||||||
|
|
||||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
const TfLiteType input_tensor_type =
|
const TfLiteType input_tensor_type =
|
||||||
interpreter_->tensor(interpreter_->inputs()[i])->type;
|
interpreter_->tensor(interpreter_->inputs()[i])->type;
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
|
||||||
// not found in the tokenizer vocab.
|
// not found in the tokenizer vocab.
|
||||||
std::vector<Tensor> result;
|
std::vector<Tensor> result;
|
||||||
result.push_back(
|
result.push_back(
|
||||||
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
|
{Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
|
||||||
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
|
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
|
||||||
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
|
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
|
||||||
kTensorsOut(cc).Send(std::move(result));
|
kTensorsOut(cc).Send(std::move(result));
|
||||||
|
|
|
@ -1077,6 +1077,7 @@ cc_test(
|
||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
deps = [
|
deps = [
|
||||||
":tensor_to_image_frame_calculator",
|
":tensor_to_image_frame_calculator",
|
||||||
|
":tensor_to_image_frame_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
|
|
@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
float scale_factor_;
|
float scale_factor_;
|
||||||
|
bool scale_per_frame_min_max_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
||||||
|
@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
|
||||||
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
|
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
|
||||||
scale_factor_ =
|
scale_factor_ =
|
||||||
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
|
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
|
||||||
|
scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
|
||||||
|
.scale_per_frame_min_max();
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
|
||||||
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
||||||
const int32_t total_size = height * width * depth;
|
const int32_t total_size = height * width * depth;
|
||||||
|
|
||||||
|
if (scale_per_frame_min_max_) {
|
||||||
|
RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
|
||||||
|
<< "Setting scale_per_frame_min_max requires FLOAT input tensors.";
|
||||||
|
}
|
||||||
::std::unique_ptr<const ImageFrame> output;
|
::std::unique_ptr<const ImageFrame> output;
|
||||||
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
||||||
// Allocate buffer with alignments.
|
// Allocate buffer with alignments.
|
||||||
std::unique_ptr<uint8_t[]> buffer(
|
std::unique_ptr<uint8_t[]> buffer(
|
||||||
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
||||||
auto data = input_tensor.flat<float>().data();
|
auto data = input_tensor.flat<float>().data();
|
||||||
|
float min = 1e23;
|
||||||
|
float max = -1e23;
|
||||||
|
if (scale_per_frame_min_max_) {
|
||||||
|
for (int i = 0; i < total_size; ++i) {
|
||||||
|
float d = scale_factor_ * data[i];
|
||||||
|
if (d < min) {
|
||||||
|
min = d;
|
||||||
|
}
|
||||||
|
if (d > max) {
|
||||||
|
max = d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int i = 0; i < total_size; ++i) {
|
for (int i = 0; i < total_size; ++i) {
|
||||||
float d = scale_factor_ * data[i];
|
float d = data[i];
|
||||||
if (d < 0) d = 0;
|
if (scale_per_frame_min_max_) {
|
||||||
if (d > 255) d = 255;
|
d = 255 * (d - min) / (max - min + 1e-9);
|
||||||
|
} else {
|
||||||
|
d = scale_factor_ * d;
|
||||||
|
if (d < 0) d = 0;
|
||||||
|
if (d > 255) d = 255;
|
||||||
|
}
|
||||||
buffer[i] = d;
|
buffer[i] = d;
|
||||||
}
|
}
|
||||||
output = ::absl::make_unique<ImageFrame>(
|
output = ::absl::make_unique<ImageFrame>(
|
||||||
|
|
|
@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
|
||||||
// Multiples floating point tensor outputs by this value before converting to
|
// Multiples floating point tensor outputs by this value before converting to
|
||||||
// uint8. This is useful for converting from range [0, 1] to [0, 255]
|
// uint8. This is useful for converting from range [0, 1] to [0, 255]
|
||||||
optional float scale_factor = 1 [default = 1.0];
|
optional float scale_factor = 1 [default = 1.0];
|
||||||
|
|
||||||
|
// If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
|
||||||
|
// per frame. This overrides any explicit scale_factor.
|
||||||
|
optional bool scale_per_frame_min_max = 2 [default = false];
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,9 @@
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
|
||||||
template <class TypeParam>
|
template <class TypeParam>
|
||||||
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUpRunner() {
|
void SetUpRunner(bool scale_per_frame_min_max = false) {
|
||||||
CalculatorGraphConfig::Node config;
|
CalculatorGraphConfig::Node config;
|
||||||
config.set_calculator("TensorToImageFrameCalculator");
|
config.set_calculator("TensorToImageFrameCalculator");
|
||||||
config.add_input_stream("TENSOR:input_tensor");
|
config.add_input_stream("TENSOR:input_tensor");
|
||||||
config.add_output_stream("IMAGE:output_image");
|
config.add_output_stream("IMAGE:output_image");
|
||||||
|
config.mutable_options()
|
||||||
|
->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
|
||||||
|
->set_scale_per_frame_min_max(scale_per_frame_min_max);
|
||||||
runner_ = absl::make_unique<CalculatorRunner>(config);
|
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||||
|
Converts3DTensorToImageFrame2DGrayWithScaling) {
|
||||||
|
this->SetUpRunner(true);
|
||||||
|
auto& runner = this->runner_;
|
||||||
|
constexpr int kWidth = 16;
|
||||||
|
constexpr int kHeight = 8;
|
||||||
|
const tf::TensorShape tensor_shape{kHeight, kWidth};
|
||||||
|
auto tensor = absl::make_unique<tf::Tensor>(
|
||||||
|
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||||
|
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||||
|
|
||||||
|
// Writing sequence of integers as floats which we want normalized.
|
||||||
|
tensor_vec[0] = 255;
|
||||||
|
for (int i = 1; i < kWidth * kHeight; ++i) {
|
||||||
|
tensor_vec[i] = 200;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t time = 1234;
|
||||||
|
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||||
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
if (!std::is_same<TypeParam, float>::value) {
|
||||||
|
EXPECT_FALSE(runner->Run().ok());
|
||||||
|
return; // Short circuit because does not apply to other types.
|
||||||
|
} else {
|
||||||
|
EXPECT_TRUE(runner->Run().ok());
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner->Outputs().Tag(kImage).packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
|
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||||
|
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||||
|
EXPECT_EQ(kWidth, output_image.Width());
|
||||||
|
EXPECT_EQ(kHeight, output_image.Height());
|
||||||
|
|
||||||
|
EXPECT_EQ(255, output_image.PixelData()[0]);
|
||||||
|
for (int i = 1; i < kWidth * kHeight; ++i) {
|
||||||
|
const uint8_t pixel_value = output_image.PixelData()[i];
|
||||||
|
ASSERT_EQ(0, pixel_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -1355,6 +1355,23 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "calculator_graph_summary_packet_test",
|
||||||
|
srcs = ["calculator_graph_summary_packet_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":calculator_framework",
|
||||||
|
":packet",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||||
|
"//mediapipe/framework/tool:sink",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "calculator_runner_test",
|
name = "calculator_runner_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
|
|
@ -32,7 +32,7 @@ template <class T>
|
||||||
struct dependent_false : std::false_type {};
|
struct dependent_false : std::false_type {};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) {
|
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, size_t index) {
|
||||||
auto& vec = *vecp;
|
auto& vec = *vecp;
|
||||||
if (vec.size() <= index) {
|
if (vec.size() <= index) {
|
||||||
vec.resize(index + 1);
|
vec.resize(index + 1);
|
||||||
|
|
|
@ -109,9 +109,20 @@ class CalculatorContext {
|
||||||
// use OutputStream::SetOffset() directly.
|
// use OutputStream::SetOffset() directly.
|
||||||
void SetOffset(TimestampDiff offset);
|
void SetOffset(TimestampDiff offset);
|
||||||
|
|
||||||
// Returns the status of the graph run.
|
// DEPRECATED: This was intended to get graph run status during
|
||||||
|
// `CalculatorBase::Close` call. However, `Close` can run simultaneously with
|
||||||
|
// other calculators `CalculatorBase::Process`, hence the actual graph
|
||||||
|
// status may change any time and returned graph status here does not
|
||||||
|
// necessarily reflect the actual graph status.
|
||||||
//
|
//
|
||||||
// NOTE: This method should only be called during CalculatorBase::Close().
|
// As an alternative, instead of checking graph status in `Close` and doing
|
||||||
|
// work for "done" state, you can enable timestamp bound processing for your
|
||||||
|
// calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger
|
||||||
|
// `Process` on timestamp bound updates and handle "done" state there.
|
||||||
|
// Check examples in:
|
||||||
|
// mediapipe/framework/calculator_graph_summary_packet_test.cc.
|
||||||
|
//
|
||||||
|
ABSL_DEPRECATED("Does not reflect the actual graph status.")
|
||||||
absl::Status GraphStatus() const { return graph_status_; }
|
absl::Status GraphStatus() const { return graph_status_; }
|
||||||
|
|
||||||
ProfilingContext* GetProfilingContext() const {
|
ProfilingContext* GetProfilingContext() const {
|
||||||
|
|
|
@ -839,6 +839,13 @@ absl::Status CalculatorGraph::PrepareForRun(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status CalculatorGraph::WaitUntilIdle() {
|
absl::Status CalculatorGraph::WaitUntilIdle() {
|
||||||
|
if (has_sources_) {
|
||||||
|
LOG_FIRST_N(WARNING, 1)
|
||||||
|
<< "WaitUntilIdle called on a graph with source nodes, which "
|
||||||
|
"is not fully supported at the moment. Source nodes: "
|
||||||
|
<< ListSourceNodes();
|
||||||
|
}
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle());
|
MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle());
|
||||||
VLOG(2) << "Scheduler idle.";
|
VLOG(2) << "Scheduler idle.";
|
||||||
absl::Status status = absl::OkStatus();
|
absl::Status status = absl::OkStatus();
|
||||||
|
@ -1368,6 +1375,16 @@ const OutputStreamManager* CalculatorGraph::FindOutputStreamManager(
|
||||||
.get()[validated_graph_->OutputStreamIndex(name)];
|
.get()[validated_graph_->OutputStreamIndex(name)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string CalculatorGraph::ListSourceNodes() const {
|
||||||
|
std::vector<std::string> sources;
|
||||||
|
for (auto& node : nodes_) {
|
||||||
|
if (node->IsSource()) {
|
||||||
|
sources.push_back(node->DebugName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return absl::StrJoin(sources, ", ");
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void PrintTimingToInfo(const std::string& label, int64_t timer_value) {
|
void PrintTimingToInfo(const std::string& label, int64_t timer_value) {
|
||||||
const int64_t total_seconds = timer_value / 1000000ll;
|
const int64_t total_seconds = timer_value / 1000000ll;
|
||||||
|
|
|
@ -229,8 +229,11 @@ class CalculatorGraph {
|
||||||
// Wait until the running graph is in the idle mode, which is when nothing can
|
// Wait until the running graph is in the idle mode, which is when nothing can
|
||||||
// be scheduled and nothing is running in the worker threads. This function
|
// be scheduled and nothing is running in the worker threads. This function
|
||||||
// can be called only after StartRun().
|
// can be called only after StartRun().
|
||||||
|
//
|
||||||
// NOTE: The graph must not have any source nodes because source nodes prevent
|
// NOTE: The graph must not have any source nodes because source nodes prevent
|
||||||
// the running graph from becoming idle until the source nodes are done.
|
// the running graph from becoming idle until the source nodes are done.
|
||||||
|
// Currently, `WaitUntilIdle` cannot be used reliably on graphs with any
|
||||||
|
// source nodes.
|
||||||
absl::Status WaitUntilIdle();
|
absl::Status WaitUntilIdle();
|
||||||
|
|
||||||
// Wait until a packet is emitted on one of the observed output streams.
|
// Wait until a packet is emitted on one of the observed output streams.
|
||||||
|
@ -594,6 +597,9 @@ class CalculatorGraph {
|
||||||
// status before taking any action.
|
// status before taking any action.
|
||||||
void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);
|
void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);
|
||||||
|
|
||||||
|
// Returns a comma-separated list of source nodes.
|
||||||
|
std::string ListSourceNodes() const;
|
||||||
|
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
// Owns the legacy GpuSharedData if we need to create one for backwards
|
// Owns the legacy GpuSharedData if we need to create one for backwards
|
||||||
// compatibility.
|
// compatibility.
|
||||||
|
|
430
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
430
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
|
@ -0,0 +1,430 @@
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Node;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::Eq;
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
using ::testing::IsEmpty;
|
||||||
|
using ::testing::Value;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
MATCHER_P2(IntPacket, value, timestamp, "") {
|
||||||
|
*result_listener << "where object is (value: " << arg.template Get<int>()
|
||||||
|
<< ", timestamp: " << arg.Timestamp() << ")";
|
||||||
|
return Value(arg.template Get<int>(), Eq(value)) &&
|
||||||
|
Value(arg.Timestamp(), Eq(timestamp));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculates and produces sum of all passed inputs when no more packets can be
|
||||||
|
// expected on the input stream.
|
||||||
|
class SummaryPacketCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"SUMMARY"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||||
|
// Makes sure there are no automatic timestamp bound updates when Process
|
||||||
|
// is called.
|
||||||
|
cc->SetTimestampOffset(TimestampDiff::Unset());
|
||||||
|
// Currently, only ImmediateInputStreamHandler supports "done" timestamp
|
||||||
|
// bound update. (ImmediateInputStreamhandler handles multiple input
|
||||||
|
// streams differently, so, in that case, calculator adjustments may be
|
||||||
|
// required.)
|
||||||
|
// TODO: update all input stream handlers to support "done"
|
||||||
|
// timestamp bound update.
|
||||||
|
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||||
|
// Enables processing timestamp bound updates. For this use case we are
|
||||||
|
// specifically interested in "done" timestamp bound update. (E.g. when
|
||||||
|
// all input packet sources are closed.)
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
if (!kIn(cc).IsEmpty()) {
|
||||||
|
value_ += kIn(cc).Get();
|
||||||
|
value_set_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kOut(cc).IsClosed()) {
|
||||||
|
// This can happen:
|
||||||
|
// 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
|
||||||
|
// source calculator finished generating packets sent to kIn) and
|
||||||
|
// HasNextAllowedInStream() == true (which is an often case).
|
||||||
|
// 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
|
||||||
|
// invoke Process() with Timestamp::Max to indicate "Done" timestamp
|
||||||
|
// bound update.
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: input stream holding a packet with timestamp that has
|
||||||
|
// no next timestamp allowed in stream should always result in
|
||||||
|
// InputStream::IsDone() == true.
|
||||||
|
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
|
||||||
|
// `Process` may or may not be invoked for "done" timestamp bound when
|
||||||
|
// upstream calculator fails in `Close`. Hence, extra care is needed to
|
||||||
|
// identify whether the calculator needs to send output.
|
||||||
|
// TODO: remove when "done" timestamp bound flakiness fixed.
|
||||||
|
if (value_set_) {
|
||||||
|
// kOut(cc).Send(value_) can be used here as well, however in the case
|
||||||
|
// of source calculator sending inputs into kIn the resulting timestamp
|
||||||
|
// is not well defined (e.g. it can be the last packet timestamp or
|
||||||
|
// Timestamp::Max())
|
||||||
|
// TODO: last packet from source should always result in
|
||||||
|
// InputStream::IsDone() == true.
|
||||||
|
kOut(cc).Send(value_, Timestamp::Max());
|
||||||
|
}
|
||||||
|
kOut(cc).Close();
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int value_ = 0;
|
||||||
|
bool value_set_ = false;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnClosingAllPacketSources) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp(10));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
send_packet(20, Timestamp(11));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp(10));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
send_packet(20, Timestamp::Max());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnPreStreamTimestamp) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PreStream());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnPostStreamTimestamp) {
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PostStream());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
class IntGeneratorCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
kOut(cc).Send(20, Timestamp(0));
|
||||||
|
kOut(cc).Send(10, Timestamp(1000));
|
||||||
|
return tool::StatusStop();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnSourceCalculatorCompletion) {
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "IntGeneratorCalculator"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
}
|
||||||
|
|
||||||
|
class EmitOnCloseCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||||
|
|
||||||
|
absl::Status Close(CalculatorContext* cc) final {
|
||||||
|
kOut(cc).Send(20, Timestamp(0));
|
||||||
|
kOut(cc).Send(10, Timestamp(1000));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnAnotherCalculatorClosure) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
node {
|
||||||
|
calculator: "EmitOnCloseCalculator"
|
||||||
|
input_stream: "IN:input"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseInputStream("input"));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
class FailureInCloseCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||||
|
|
||||||
|
absl::Status Close(CalculatorContext* cc) final {
|
||||||
|
return absl::InternalError("error");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(FailureInCloseCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInClose) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
node {
|
||||||
|
calculator: "FailureInCloseCalculator"
|
||||||
|
input_stream: "IN:input"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseInputStream("input"));
|
||||||
|
EXPECT_THAT(graph.WaitUntilIdle(),
|
||||||
|
StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
class FailureInProcessCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
return absl::InternalError("error");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(FailureInProcessCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInProcess) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
node {
|
||||||
|
calculator: "FailureInProcessCalculator"
|
||||||
|
input_stream: "IN:input"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PostStream());
|
||||||
|
EXPECT_THAT(graph.WaitUntilIdle(),
|
||||||
|
StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -117,11 +117,18 @@ class Tensor {
|
||||||
Shape() = default;
|
Shape() = default;
|
||||||
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
||||||
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
|
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
|
||||||
|
Shape(std::initializer_list<int> dimensions, bool is_dynamic)
|
||||||
|
: dims(dimensions), is_dynamic(is_dynamic) {}
|
||||||
|
Shape(const std::vector<int>& dimensions, bool is_dynamic)
|
||||||
|
: dims(dimensions), is_dynamic(is_dynamic) {}
|
||||||
int num_elements() const {
|
int num_elements() const {
|
||||||
return std::accumulate(dims.begin(), dims.end(), 1,
|
return std::accumulate(dims.begin(), dims.end(), 1,
|
||||||
std::multiplies<int>());
|
std::multiplies<int>());
|
||||||
}
|
}
|
||||||
std::vector<int> dims;
|
std::vector<int> dims;
|
||||||
|
// The Tensor has dynamic rather than static shape so the TFLite interpreter
|
||||||
|
// needs to be reallocated. Only relevant for CPU.
|
||||||
|
bool is_dynamic = false;
|
||||||
};
|
};
|
||||||
// Quantization parameters corresponding to the zero_point and scale value
|
// Quantization parameters corresponding to the zero_point and scale value
|
||||||
// made available by TfLite quantized (uint8/int8) tensors.
|
// made available by TfLite quantized (uint8/int8) tensors.
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -34,6 +35,17 @@ TEST(General, TestDataTypes) {
|
||||||
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
|
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(General, TestDynamic) {
|
||||||
|
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true));
|
||||||
|
EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4);
|
||||||
|
EXPECT_TRUE(t1.shape().is_dynamic);
|
||||||
|
|
||||||
|
std::vector<int> t2_dims = {4, 3, 2, 3};
|
||||||
|
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true));
|
||||||
|
EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3);
|
||||||
|
EXPECT_TRUE(t2.shape().is_dynamic);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Cpu, TestMemoryAllocation) {
|
TEST(Cpu, TestMemoryAllocation) {
|
||||||
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
|
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
|
||||||
auto v1 = t1.GetCpuWriteView();
|
auto v1 = t1.GetCpuWriteView();
|
||||||
|
|
|
@ -273,8 +273,8 @@ absl::Status Scheduler::WaitForObservedOutput() {
|
||||||
// Idleness requires:
|
// Idleness requires:
|
||||||
// 1. either the graph has no source nodes or all source nodes are closed, and
|
// 1. either the graph has no source nodes or all source nodes are closed, and
|
||||||
// 2. no packets are added to graph input streams.
|
// 2. no packets are added to graph input streams.
|
||||||
// For simplicity, we only allow WaitUntilIdle() to be called on a graph with
|
// For simplicity, we only fully support WaitUntilIdle() to be called on a graph
|
||||||
// no source nodes. (This is enforced by CalculatorGraph::WaitUntilIdle().)
|
// with no source nodes.
|
||||||
// The application must ensure no other threads are adding packets to graph
|
// The application must ensure no other threads are adding packets to graph
|
||||||
// input streams while a WaitUntilIdle() call is in progress.
|
// input streams while a WaitUntilIdle() call is in progress.
|
||||||
absl::Status Scheduler::WaitUntilIdle() {
|
absl::Status Scheduler::WaitUntilIdle() {
|
||||||
|
|
|
@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
|
||||||
return *this + 1;
|
return *this + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Timestamp::HasNextAllowedInStream() const {
|
||||||
|
if (*this >= Max() || *this == PreStream()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
Timestamp Timestamp::PreviousAllowedInStream() const {
|
Timestamp Timestamp::PreviousAllowedInStream() const {
|
||||||
if (*this <= Min() || *this == PostStream()) {
|
if (*this <= Min() || *this == PostStream()) {
|
||||||
// Indicates that no previous timestamps may occur.
|
// Indicates that no previous timestamps may occur.
|
||||||
|
|
|
@ -186,6 +186,10 @@ class Timestamp {
|
||||||
// CHECKs that this->IsAllowedInStream().
|
// CHECKs that this->IsAllowedInStream().
|
||||||
Timestamp NextAllowedInStream() const;
|
Timestamp NextAllowedInStream() const;
|
||||||
|
|
||||||
|
// Returns true if there's a next timestamp in the range [Min .. Max] after
|
||||||
|
// this one.
|
||||||
|
bool HasNextAllowedInStream() const;
|
||||||
|
|
||||||
// Returns the previous timestamp in the range [Min .. Max], or
|
// Returns the previous timestamp in the range [Min .. Max], or
|
||||||
// Unstarted() if no Packets may preceed one with this timestamp.
|
// Unstarted() if no Packets may preceed one with this timestamp.
|
||||||
Timestamp PreviousAllowedInStream() const;
|
Timestamp PreviousAllowedInStream() const;
|
||||||
|
|
|
@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
|
||||||
Timestamp::PostStream().NextAllowedInStream());
|
Timestamp::PostStream().NextAllowedInStream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TimestampTest, HasNextAllowedInStream) {
|
||||||
|
EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
|
||||||
|
|
||||||
|
EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TimestampTest, SpecialValueDifferences) {
|
TEST(TimestampTest, SpecialValueDifferences) {
|
||||||
{ // Lower range
|
{ // Lower range
|
||||||
const std::vector<Timestamp> timestamps = {
|
const std::vector<Timestamp> timestamps = {
|
||||||
|
|
|
@ -530,6 +530,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
"""MediaPipe Task Library Helper Rules for iOS"""
|
"""MediaPipe Task Library Helper Rules for iOS"""
|
||||||
|
|
||||||
MPP_TASK_MINIMUM_OS_VERSION = "11.0"
|
MPP_TASK_MINIMUM_OS_VERSION = "12.0"
|
||||||
|
|
||||||
# When the static framework is built with bazel, the all header files are moved
|
# When the static framework is built with bazel, the all header files are moved
|
||||||
# to the "Headers" directory with no header path prefixes. This auxiliary rule
|
# to the "Headers" directory with no header path prefixes. This auxiliary rule
|
||||||
|
|
|
@ -50,6 +50,7 @@ def mediapipe_proto_library_impl(
|
||||||
def_cc_proto = True,
|
def_cc_proto = True,
|
||||||
def_py_proto = True,
|
def_py_proto = True,
|
||||||
def_java_lite_proto = True,
|
def_java_lite_proto = True,
|
||||||
|
def_kt_lite_proto = True,
|
||||||
def_objc_proto = True,
|
def_objc_proto = True,
|
||||||
def_java_proto = True,
|
def_java_proto = True,
|
||||||
def_jspb_proto = True,
|
def_jspb_proto = True,
|
||||||
|
@ -72,6 +73,7 @@ def mediapipe_proto_library_impl(
|
||||||
def_cc_proto: define the cc_proto_library target
|
def_cc_proto: define the cc_proto_library target
|
||||||
def_py_proto: define the py_proto_library target
|
def_py_proto: define the py_proto_library target
|
||||||
def_java_lite_proto: define the java_lite_proto_library target
|
def_java_lite_proto: define the java_lite_proto_library target
|
||||||
|
def_kt_lite_proto: define the kt_lite_proto_library target
|
||||||
def_objc_proto: define the objc_proto_library target
|
def_objc_proto: define the objc_proto_library target
|
||||||
def_java_proto: define the java_proto_library target
|
def_java_proto: define the java_proto_library target
|
||||||
def_jspb_proto: define the jspb_proto_library target
|
def_jspb_proto: define the jspb_proto_library target
|
||||||
|
@ -255,6 +257,7 @@ def mediapipe_proto_library(
|
||||||
def_cc_proto = True,
|
def_cc_proto = True,
|
||||||
def_py_proto = True,
|
def_py_proto = True,
|
||||||
def_java_lite_proto = True,
|
def_java_lite_proto = True,
|
||||||
|
def_kt_lite_proto = True,
|
||||||
def_portable_proto = True, # @unused
|
def_portable_proto = True, # @unused
|
||||||
def_objc_proto = True,
|
def_objc_proto = True,
|
||||||
def_java_proto = True,
|
def_java_proto = True,
|
||||||
|
@ -281,6 +284,7 @@ def mediapipe_proto_library(
|
||||||
def_cc_proto: define the cc_proto_library target
|
def_cc_proto: define the cc_proto_library target
|
||||||
def_py_proto: define the py_proto_library target
|
def_py_proto: define the py_proto_library target
|
||||||
def_java_lite_proto: define the java_lite_proto_library target
|
def_java_lite_proto: define the java_lite_proto_library target
|
||||||
|
def_kt_lite_proto: define the kt_lite_proto_library target
|
||||||
def_portable_proto: ignored since portable protos are gone
|
def_portable_proto: ignored since portable protos are gone
|
||||||
def_objc_proto: define the objc_proto_library target
|
def_objc_proto: define the objc_proto_library target
|
||||||
def_java_proto: define the java_proto_library target
|
def_java_proto: define the java_proto_library target
|
||||||
|
@ -304,6 +308,7 @@ def mediapipe_proto_library(
|
||||||
def_cc_proto = def_cc_proto,
|
def_cc_proto = def_cc_proto,
|
||||||
def_py_proto = def_py_proto,
|
def_py_proto = def_py_proto,
|
||||||
def_java_lite_proto = def_java_lite_proto,
|
def_java_lite_proto = def_java_lite_proto,
|
||||||
|
def_kt_lite_proto = def_kt_lite_proto,
|
||||||
def_objc_proto = def_objc_proto,
|
def_objc_proto = def_objc_proto,
|
||||||
def_java_proto = def_java_proto,
|
def_java_proto = def_java_proto,
|
||||||
def_jspb_proto = def_jspb_proto,
|
def_jspb_proto = def_jspb_proto,
|
||||||
|
@ -334,6 +339,7 @@ def mediapipe_proto_library(
|
||||||
def_cc_proto = def_cc_proto,
|
def_cc_proto = def_cc_proto,
|
||||||
def_py_proto = def_py_proto,
|
def_py_proto = def_py_proto,
|
||||||
def_java_lite_proto = def_java_lite_proto,
|
def_java_lite_proto = def_java_lite_proto,
|
||||||
|
def_kt_lite_proto = def_kt_lite_proto,
|
||||||
def_objc_proto = def_objc_proto,
|
def_objc_proto = def_objc_proto,
|
||||||
def_java_proto = def_java_proto,
|
def_java_proto = def_java_proto,
|
||||||
def_jspb_proto = def_jspb_proto,
|
def_jspb_proto = def_jspb_proto,
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
|
@ -1430,10 +1431,10 @@ std::vector<const FieldDescriptor*> GetFields(const Message* src) {
|
||||||
|
|
||||||
// Orders map entries in dst to match src.
|
// Orders map entries in dst to match src.
|
||||||
void OrderMapEntries(const Message* src, Message* dst,
|
void OrderMapEntries(const Message* src, Message* dst,
|
||||||
std::set<const Message*>* seen = nullptr) {
|
absl::flat_hash_set<const Message*>* seen = nullptr) {
|
||||||
std::unique_ptr<std::set<const Message*>> seen_owner;
|
std::unique_ptr<absl::flat_hash_set<const Message*>> seen_owner;
|
||||||
if (!seen) {
|
if (!seen) {
|
||||||
seen_owner = std::make_unique<std::set<const Message*>>();
|
seen_owner = std::make_unique<absl::flat_hash_set<const Message*>>();
|
||||||
seen = seen_owner.get();
|
seen = seen_owner.get();
|
||||||
}
|
}
|
||||||
if (seen->count(src) > 0) {
|
if (seen->count(src) > 0) {
|
||||||
|
|
|
@ -34,6 +34,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
import javax.microedition.khronos.egl.EGLConfig;
|
import javax.microedition.khronos.egl.EGLConfig;
|
||||||
import javax.microedition.khronos.opengles.GL10;
|
import javax.microedition.khronos.opengles.GL10;
|
||||||
|
|
||||||
|
@ -303,7 +304,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use this when the texture is not a SurfaceTexture.
|
// Use this when the texture is not a SurfaceTexture.
|
||||||
public void setNextFrame(TextureFrame frame) {
|
public void setNextFrame(@Nullable TextureFrame frame) {
|
||||||
if (surfaceTexture != null) {
|
if (surfaceTexture != null) {
|
||||||
Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */);
|
Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */);
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,6 @@ android_library(
|
||||||
"MediaPipeRunner.java",
|
"MediaPipeRunner.java",
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//java/com/google/android/libraries/camera/effects:__subpackages__",
|
|
||||||
"//mediapipe/java/com/google/mediapipe:__subpackages__",
|
"//mediapipe/java/com/google/mediapipe:__subpackages__",
|
||||||
],
|
],
|
||||||
exports = [
|
exports = [
|
||||||
|
|
|
@ -67,6 +67,7 @@ public class ExternalTextureRenderer {
|
||||||
private float[] textureTransformMatrix = new float[16];
|
private float[] textureTransformMatrix = new float[16];
|
||||||
private boolean flipY;
|
private boolean flipY;
|
||||||
private int rotation = Surface.ROTATION_0;
|
private int rotation = Surface.ROTATION_0;
|
||||||
|
private boolean doExplicitCpuSync = true;
|
||||||
|
|
||||||
/** Call this to setup the shader program before rendering. */
|
/** Call this to setup the shader program before rendering. */
|
||||||
public void setup() {
|
public void setup() {
|
||||||
|
@ -101,6 +102,14 @@ public class ExternalTextureRenderer {
|
||||||
this.rotation = rotation;
|
this.rotation = rotation;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configures whether the renderer should do an explicit CPU synchronization using glFinish upon
|
||||||
|
* each {@link #render} call. Defaults to true.
|
||||||
|
*/
|
||||||
|
public void setDoExplicitCpuSync(boolean doExplicitCpuSync) {
|
||||||
|
this.doExplicitCpuSync = doExplicitCpuSync;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Renders the surfaceTexture to the framebuffer with optional vertical flip.
|
* Renders the surfaceTexture to the framebuffer with optional vertical flip.
|
||||||
*
|
*
|
||||||
|
@ -150,8 +159,11 @@ public class ExternalTextureRenderer {
|
||||||
GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
|
GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
|
||||||
ShaderUtil.checkGlError("glBindTexture");
|
ShaderUtil.checkGlError("glBindTexture");
|
||||||
|
|
||||||
// TODO: add sync and go back to glFlush()
|
if (doExplicitCpuSync) {
|
||||||
GLES20.glFinish();
|
|
||||||
|
// TODO: add sync and go back to glFlush()
|
||||||
|
GLES20.glFinish();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -14,7 +14,10 @@
|
||||||
|
|
||||||
# Placeholder for internal Python strict library and test compatibility macro.
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
package(default_visibility = [
|
||||||
|
"//cloud/ml/applications/vision/model_garden/model_oss/mediapipe:__subpackages__",
|
||||||
|
"//mediapipe:__subpackages__",
|
||||||
|
])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._model: tf.keras.Model = None
|
self._model: tf.keras.Model = None
|
||||||
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
|
||||||
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
self._loss_function: Union[str, tf.keras.losses.Loss] = None
|
||||||
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
|
self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None
|
||||||
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
|
||||||
self._hparams: hp.BaseHParams = None
|
self._hparams: hp.BaseHParams = None
|
||||||
self._history: tf.keras.callbacks.History = None
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._model.compile(
|
self._model.compile(
|
||||||
optimizer=self._optimizer,
|
optimizer=self._optimizer,
|
||||||
loss=self._loss_function,
|
loss=self._loss_function,
|
||||||
metrics=[self._metric_function])
|
metrics=self._metric_functions,
|
||||||
|
)
|
||||||
|
|
||||||
latest_checkpoint = (
|
latest_checkpoint = (
|
||||||
tf.train.latest_checkpoint(checkpoint_path)
|
tf.train.latest_checkpoint(checkpoint_path)
|
||||||
|
|
|
@ -80,10 +80,30 @@ py_test(
|
||||||
deps = [":loss_functions"],
|
deps = [":loss_functions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
# Public target of the MediaPipe Model Maker Quantization Config.
|
||||||
|
|
||||||
|
# Quantization Config is used to export a quantized model. Please refer
|
||||||
|
# to the specific task documentations such as:
|
||||||
|
# https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize
|
||||||
|
# for usage information.
|
||||||
|
######################################################################
|
||||||
|
py_library(
|
||||||
|
name = "metrics",
|
||||||
|
srcs = ["metrics.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "metrics_test",
|
||||||
|
srcs = ["metrics_test.py"],
|
||||||
|
deps = [":metrics"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "quantization",
|
name = "quantization",
|
||||||
srcs = ["quantization.py"],
|
srcs = ["quantization.py"],
|
||||||
srcs_version = "PY3",
|
srcs_version = "PY3",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
|
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
104
mediapipe/model_maker/python/core/utils/metrics.py
Normal file
104
mediapipe/model_maker/python/core/utils/metrics.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Metrics utility library."""
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def _get_binary_sparse_metric(metric: tf.metrics.Metric):
|
||||||
|
"""Helper method to create a BinarySparse version of a tf.keras.Metric.
|
||||||
|
|
||||||
|
BinarySparse is an implementation where the update_state(y_true, y_pred) takes
|
||||||
|
in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only
|
||||||
|
supports the binary classification case, and that class_id=0 is the negative
|
||||||
|
class and class_id=1 is the positive class.
|
||||||
|
|
||||||
|
Currently supported tf.metric.Metric classes
|
||||||
|
1. BinarySparseRecallAtPrecision
|
||||||
|
2. BinarySparsePrecisionAtRecall
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: A tf.metric.Metric class for which we want to generate a
|
||||||
|
BinarySparse version of this metric.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A class for the BinarySparse version of the specified tf.metrics.Metric
|
||||||
|
"""
|
||||||
|
|
||||||
|
class BinarySparseMetric(metric):
|
||||||
|
"""A BinarySparse wrapper class for a tf.keras.Metric.
|
||||||
|
|
||||||
|
This class has the same parameters and functions as the underlying
|
||||||
|
metric class. For example, the parameters for BinarySparseRecallAtPrecision
|
||||||
|
is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint
|
||||||
|
is that class_id must be set to 1 (or not specified) for the Binary metric.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if 'class_id' in kwargs and kwargs['class_id'] != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f'Custom BinarySparseMetric for class:{metric.__name__} is '
|
||||||
|
'only supported for class_id=1, got class_id='
|
||||||
|
f'{kwargs["class_id"]} instead'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kwargs['class_id'] = 1
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
|
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
|
||||||
|
y_true_one_hot = tf.one_hot(y_true, 2)
|
||||||
|
return super().update_state(
|
||||||
|
y_true_one_hot, y_pred, sample_weight=sample_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
return BinarySparseMetric
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sparse_metric(metric: tf.metrics.Metric):
|
||||||
|
"""Helper method to create a Sparse version of a tf.keras.Metric.
|
||||||
|
|
||||||
|
Sparse is an implementation where the update_state(y_true, y_pred) takes in
|
||||||
|
shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes).
|
||||||
|
|
||||||
|
Currently supported tf.metrics.Metric classes:
|
||||||
|
1. tf.metrics.Recall
|
||||||
|
2. tf.metrics.Precision
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: A tf.metric.Metric class for which we want to generate a Sparse
|
||||||
|
version of this metric.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A class for the Sparse version of the specified tf.keras.Metric.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class SparseMetric(metric):
|
||||||
|
"""A Sparse wrapper class for a tf.keras.Metric."""
|
||||||
|
|
||||||
|
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||||
|
y_pred = tf.math.argmax(y_pred, axis=-1)
|
||||||
|
return super().update_state(y_true, y_pred, sample_weight=sample_weight)
|
||||||
|
|
||||||
|
return SparseMetric
|
||||||
|
|
||||||
|
|
||||||
|
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
|
||||||
|
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
|
||||||
|
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
|
||||||
|
tf.metrics.RecallAtPrecision
|
||||||
|
)
|
||||||
|
BinarySparsePrecisionAtRecall = _get_binary_sparse_metric(
|
||||||
|
tf.metrics.PrecisionAtRecall
|
||||||
|
)
|
74
mediapipe/model_maker/python/core/utils/metrics_test.py
Normal file
74
mediapipe/model_maker/python/core/utils/metrics_test.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.y_true = [0, 0, 1, 1, 0, 1]
|
||||||
|
self.y_pred = [
|
||||||
|
[0.9, 0.1], # 0, 0 y
|
||||||
|
[0.8, 0.2], # 0, 0 y
|
||||||
|
[0.7, 0.3], # 0, 1 n
|
||||||
|
[0.6, 0.4], # 0, 1 n
|
||||||
|
[0.3, 0.7], # 1, 0 y
|
||||||
|
[0.3, 0.7], # 1, 1 y
|
||||||
|
]
|
||||||
|
self.num_classes = 3
|
||||||
|
|
||||||
|
def _assert_metric_equals(self, metric, value):
|
||||||
|
metric.update_state(self.y_true, self.y_pred)
|
||||||
|
self.assertEqual(metric.result(), value)
|
||||||
|
|
||||||
|
def test_sparse_recall(self):
|
||||||
|
metric = metrics.SparseRecall()
|
||||||
|
self._assert_metric_equals(metric, 1 / 3)
|
||||||
|
|
||||||
|
def test_sparse_precision(self):
|
||||||
|
metric = metrics.SparsePrecision()
|
||||||
|
self._assert_metric_equals(metric, 1 / 2)
|
||||||
|
|
||||||
|
def test_binary_sparse_recall_at_precision(self):
|
||||||
|
metric = metrics.BinarySparseRecallAtPrecision(1.0)
|
||||||
|
self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1
|
||||||
|
metric = metrics.BinarySparseRecallAtPrecision(0.4)
|
||||||
|
self._assert_metric_equals(metric, 1.0)
|
||||||
|
|
||||||
|
def test_binary_sparse_precision_at_recall(self):
|
||||||
|
metric = metrics.BinarySparsePrecisionAtRecall(1.0)
|
||||||
|
self._assert_metric_equals(metric, 3 / 4)
|
||||||
|
metric = metrics.BinarySparsePrecisionAtRecall(0.7)
|
||||||
|
self._assert_metric_equals(metric, 3 / 4)
|
||||||
|
|
||||||
|
def test_binary_sparse_precision_at_recall_class_id_error(self):
|
||||||
|
# class_id=1 case should not error
|
||||||
|
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1)
|
||||||
|
# class_id=2 case should error
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
'Custom BinarySparseMetric for class:PrecisionAtRecall is only'
|
||||||
|
' supported for class_id=1, got class_id=2 instead',
|
||||||
|
):
|
||||||
|
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
|
@ -31,11 +31,11 @@ py_library(
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":text_classifier",
|
":text_classifier",
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,12 +45,18 @@ py_library(
|
||||||
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "hyperparameters",
|
||||||
|
srcs = ["hyperparameters.py"],
|
||||||
|
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "model_spec",
|
name = "model_spec",
|
||||||
srcs = ["model_spec.py"],
|
srcs = ["model_spec.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
"//mediapipe/model_maker/python/core/utils:file_util",
|
"//mediapipe/model_maker/python/core/utils:file_util",
|
||||||
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
"//mediapipe/model_maker/python/text/core:bert_model_spec",
|
||||||
],
|
],
|
||||||
|
@ -61,9 +67,9 @@ py_test(
|
||||||
srcs = ["model_spec_test.py"],
|
srcs = ["model_spec_test.py"],
|
||||||
tags = ["requires-net:external"],
|
tags = ["requires-net:external"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,9 +106,9 @@ py_library(
|
||||||
name = "text_classifier_options",
|
name = "text_classifier_options",
|
||||||
srcs = ["text_classifier_options.py"],
|
srcs = ["text_classifier_options.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -111,13 +117,14 @@ py_library(
|
||||||
srcs = ["text_classifier.py"],
|
srcs = ["text_classifier.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
":hyperparameters",
|
||||||
":model_options",
|
":model_options",
|
||||||
":model_spec",
|
":model_spec",
|
||||||
":preprocessor",
|
":preprocessor",
|
||||||
":text_classifier_options",
|
":text_classifier_options",
|
||||||
"//mediapipe/model_maker/python/core:hyperparameters",
|
|
||||||
"//mediapipe/model_maker/python/core/data:dataset",
|
"//mediapipe/model_maker/python/core/data:dataset",
|
||||||
"//mediapipe/model_maker/python/core/tasks:classifier",
|
"//mediapipe/model_maker/python/core/tasks:classifier",
|
||||||
|
"//mediapipe/model_maker/python/core/utils:metrics",
|
||||||
"//mediapipe/model_maker/python/core/utils:model_util",
|
"//mediapipe/model_maker/python/core/utils:model_util",
|
||||||
"//mediapipe/model_maker/python/core/utils:quantization",
|
"//mediapipe/model_maker/python/core/utils:quantization",
|
||||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||||
|
|
|
@ -13,19 +13,23 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""MediaPipe Public Python API for Text Classifier."""
|
"""MediaPipe Public Python API for Text Classifier."""
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters
|
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset
|
from mediapipe.model_maker.python.text.text_classifier import dataset
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options
|
from mediapipe.model_maker.python.text.text_classifier import model_options
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
from mediapipe.model_maker.python.text.text_classifier import model_spec
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier
|
||||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
|
||||||
|
|
||||||
HParams = hyperparameters.BaseHParams
|
|
||||||
|
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
|
||||||
|
AverageWordEmbeddingModelOptions = (
|
||||||
|
model_options.AverageWordEmbeddingModelOptions
|
||||||
|
)
|
||||||
|
BertOptimizer = hyperparameters.BertOptimizer
|
||||||
|
BertHParams = hyperparameters.BertHParams
|
||||||
|
BertModelOptions = model_options.BertModelOptions
|
||||||
CSVParams = dataset.CSVParameters
|
CSVParams = dataset.CSVParameters
|
||||||
Dataset = dataset.Dataset
|
Dataset = dataset.Dataset
|
||||||
AverageWordEmbeddingModelOptions = (
|
|
||||||
model_options.AverageWordEmbeddingModelOptions)
|
|
||||||
BertModelOptions = model_options.BertModelOptions
|
|
||||||
SupportedModels = model_spec.SupportedModels
|
SupportedModels = model_spec.SupportedModels
|
||||||
TextClassifier = text_classifier.TextClassifier
|
TextClassifier = text_classifier.TextClassifier
|
||||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
TextClassifierOptions = text_classifier_options.TextClassifierOptions
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Hyperparameters for training object detection models."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import enum
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core import hyperparameters as hp
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class AverageWordEmbeddingHParams(hp.BaseHParams):
|
||||||
|
"""The hyperparameters for an AverageWordEmbeddingClassifier."""
|
||||||
|
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class BertOptimizer(enum.Enum):
|
||||||
|
"""Supported Optimizers for Bert Text Classifier."""
|
||||||
|
|
||||||
|
ADAMW = "adamw"
|
||||||
|
LAMB = "lamb"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertHParams(hp.BaseHParams):
|
||||||
|
"""The hyperparameters for a Bert Classifier.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
learning_rate: Learning rate to use for gradient descent training.
|
||||||
|
batch_size: Batch size for training.
|
||||||
|
epochs: Number of training iterations over the dataset.
|
||||||
|
optimizer: Optimizer to use for training. Only supported values are "adamw"
|
||||||
|
and "lamb".
|
||||||
|
"""
|
||||||
|
|
||||||
|
learning_rate: float = 3e-5
|
||||||
|
batch_size: int = 48
|
||||||
|
epochs: int = 2
|
||||||
|
optimizer: BertOptimizer = BertOptimizer.ADAMW
|
||||||
|
|
||||||
|
|
||||||
|
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
|
|
@ -17,13 +17,11 @@ import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
|
||||||
from mediapipe.model_maker.python.core.utils import file_util
|
from mediapipe.model_maker.python.core.utils import file_util
|
||||||
from mediapipe.model_maker.python.text.core import bert_model_spec
|
from mediapipe.model_maker.python.text.core import bert_model_spec
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
|
|
||||||
# BERT-based text classifier spec inherited from BertModelSpec
|
|
||||||
BertClassifierSpec = bert_model_spec.BertModelSpec
|
|
||||||
|
|
||||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||||
'text_classifier/mobilebert_tiny',
|
'text_classifier/mobilebert_tiny',
|
||||||
|
@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
|
||||||
is_folder=True,
|
is_folder=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
EXBERT_FILES = file_util.DownloadedFiles(
|
||||||
|
'text_classifier/exbert',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
|
||||||
|
is_folder=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AverageWordEmbeddingClassifierSpec:
|
class AverageWordEmbeddingClassifierSpec:
|
||||||
|
@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# `learning_rate` is unused for the average word embedding model
|
# `learning_rate` is unused for the average word embedding model
|
||||||
hparams: hp.BaseHParams = hp.BaseHParams(
|
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0)
|
epochs=10, batch_size=32, learning_rate=0
|
||||||
|
)
|
||||||
model_options: mo.AverageWordEmbeddingModelOptions = (
|
model_options: mo.AverageWordEmbeddingModelOptions = (
|
||||||
mo.AverageWordEmbeddingModelOptions())
|
mo.AverageWordEmbeddingModelOptions())
|
||||||
name: str = 'AverageWordEmbedding'
|
name: str = 'AverageWordEmbedding'
|
||||||
|
|
||||||
|
|
||||||
average_word_embedding_classifier_spec = functools.partial(
|
average_word_embedding_classifier_spec = functools.partial(
|
||||||
AverageWordEmbeddingClassifierSpec)
|
AverageWordEmbeddingClassifierSpec)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BertClassifierSpec(bert_model_spec.BertModelSpec):
|
||||||
|
"""Specification for a Bert classifier model.
|
||||||
|
|
||||||
|
Only overrides the hparams attribute since the rest of the attributes are
|
||||||
|
inherited from the BertModelSpec.
|
||||||
|
"""
|
||||||
|
|
||||||
|
hparams: hp.BertHParams = hp.BertHParams()
|
||||||
|
|
||||||
|
|
||||||
mobilebert_classifier_spec = functools.partial(
|
mobilebert_classifier_spec = functools.partial(
|
||||||
BertClassifierSpec,
|
BertClassifierSpec,
|
||||||
downloaded_files=MOBILEBERT_TINY_FILES,
|
downloaded_files=MOBILEBERT_TINY_FILES,
|
||||||
hparams=hp.BaseHParams(
|
hparams=hp.BertHParams(
|
||||||
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||||
),
|
),
|
||||||
name='MobileBert',
|
name='MobileBert',
|
||||||
tflite_input_name={
|
tflite_input_name={
|
||||||
'ids': 'serving_default_input_1:0',
|
'ids': 'serving_default_input_1:0',
|
||||||
'mask': 'serving_default_input_3:0',
|
|
||||||
'segment_ids': 'serving_default_input_2:0',
|
'segment_ids': 'serving_default_input_2:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
exbert_classifier_spec = functools.partial(
|
||||||
|
BertClassifierSpec,
|
||||||
|
downloaded_files=EXBERT_FILES,
|
||||||
|
hparams=hp.BertHParams(
|
||||||
|
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
|
||||||
|
),
|
||||||
|
name='ExBert',
|
||||||
|
tflite_input_name={
|
||||||
|
'ids': 'serving_default_input_1:0',
|
||||||
|
'segment_ids': 'serving_default_input_2:0',
|
||||||
|
'mask': 'serving_default_input_3:0',
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
|
||||||
"""Predefined text classifier model specs supported by Model Maker."""
|
"""Predefined text classifier model specs supported by Model Maker."""
|
||||||
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
|
||||||
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
|
||||||
|
EXBERT_CLASSIFIER = exbert_classifier_spec
|
||||||
|
|
|
@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.hparams,
|
model_spec_obj.hparams,
|
||||||
hp.BaseHParams(
|
hp.BertHParams(
|
||||||
epochs=3,
|
epochs=3,
|
||||||
batch_size=48,
|
batch_size=48,
|
||||||
learning_rate=3e-5,
|
learning_rate=3e-5,
|
||||||
distribution_strategy='off'))
|
distribution_strategy='off',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def test_predefined_average_word_embedding_spec(self):
|
def test_predefined_average_word_embedding_spec(self):
|
||||||
model_spec_obj = (
|
model_spec_obj = (
|
||||||
|
@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
dropout_rate=0.2))
|
dropout_rate=0.2))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
model_spec_obj.hparams,
|
model_spec_obj.hparams,
|
||||||
hp.BaseHParams(
|
hp.AverageWordEmbeddingHParams(
|
||||||
epochs=10,
|
epochs=10,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
learning_rate=0,
|
learning_rate=0,
|
||||||
|
@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
custom_bert_classifier_options)
|
custom_bert_classifier_options)
|
||||||
|
|
||||||
def test_custom_average_word_embedding_spec(self):
|
def test_custom_average_word_embedding_spec(self):
|
||||||
custom_hparams = hp.BaseHParams(
|
custom_hparams = hp.AverageWordEmbeddingHParams(
|
||||||
learning_rate=0.4,
|
learning_rate=0.4,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
epochs=10,
|
epochs=10,
|
||||||
|
@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
|
||||||
export_dir='foo/bar',
|
export_dir='foo/bar',
|
||||||
distribution_strategy='mirrored',
|
distribution_strategy='mirrored',
|
||||||
num_gpus=3,
|
num_gpus=3,
|
||||||
tpu='tpu/address')
|
tpu='tpu/address',
|
||||||
|
)
|
||||||
custom_average_word_embedding_model_options = (
|
custom_average_word_embedding_model_options = (
|
||||||
classifier_model_options.AverageWordEmbeddingModelOptions(
|
classifier_model_options.AverageWordEmbeddingModelOptions(
|
||||||
seq_len=512,
|
seq_len=512,
|
||||||
|
|
|
@ -19,14 +19,16 @@ import tempfile
|
||||||
from typing import Any, Optional, Sequence, Tuple
|
from typing import Any, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow_addons import optimizers as tfa_optimizers
|
||||||
import tensorflow_hub as hub
|
import tensorflow_hub as hub
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
|
||||||
from mediapipe.model_maker.python.core.data import dataset as ds
|
from mediapipe.model_maker.python.core.data import dataset as ds
|
||||||
from mediapipe.model_maker.python.core.tasks import classifier
|
from mediapipe.model_maker.python.core.tasks import classifier
|
||||||
|
from mediapipe.model_maker.python.core.utils import metrics
|
||||||
from mediapipe.model_maker.python.core.utils import model_util
|
from mediapipe.model_maker.python.core.utils import model_util
|
||||||
from mediapipe.model_maker.python.core.utils import quantization
|
from mediapipe.model_maker.python.core.utils import quantization
|
||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
|
||||||
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
from mediapipe.model_maker.python.text.text_classifier import preprocessor
|
||||||
|
@ -54,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
|
||||||
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
|
||||||
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
|
||||||
f" got {options.supported_model}")
|
f" got {options.supported_model}")
|
||||||
if (isinstance(options.model_options, mo.BertModelOptions) and
|
if isinstance(options.model_options, mo.BertModelOptions) and (
|
||||||
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
|
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
|
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
|
||||||
|
f"{options.supported_model}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextClassifier(classifier.Classifier):
|
class TextClassifier(classifier.Classifier):
|
||||||
"""API for creating and training a text classification model."""
|
"""API for creating and training a text classification model."""
|
||||||
|
|
||||||
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
|
def __init__(
|
||||||
label_names: Sequence[str]):
|
self, model_spec: Any, label_names: Sequence[str], shuffle: bool
|
||||||
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
|
model_spec=model_spec, label_names=label_names, shuffle=shuffle
|
||||||
|
)
|
||||||
self._model_spec = model_spec
|
self._model_spec = model_spec
|
||||||
self._hparams = hparams
|
|
||||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
|
||||||
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -106,7 +112,10 @@ class TextClassifier(classifier.Classifier):
|
||||||
if options.hparams is None:
|
if options.hparams is None:
|
||||||
options.hparams = options.supported_model.value().hparams
|
options.hparams = options.supported_model.value().hparams
|
||||||
|
|
||||||
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
|
if (
|
||||||
|
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
|
||||||
|
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
|
||||||
|
):
|
||||||
text_classifier = (
|
text_classifier = (
|
||||||
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
_BertClassifier.create_bert_classifier(train_data, validation_data,
|
||||||
options,
|
options,
|
||||||
|
@ -123,12 +132,24 @@ class TextClassifier(classifier.Classifier):
|
||||||
|
|
||||||
return text_classifier
|
return text_classifier
|
||||||
|
|
||||||
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(
|
||||||
|
self,
|
||||||
|
data: ds.Dataset,
|
||||||
|
batch_size: int = 32,
|
||||||
|
desired_precisions: Optional[Sequence[float]] = None,
|
||||||
|
desired_recalls: Optional[Sequence[float]] = None,
|
||||||
|
) -> Any:
|
||||||
"""Overrides Classifier.evaluate().
|
"""Overrides Classifier.evaluate().
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
data: Evaluation dataset. Must be a TextClassifier Dataset.
|
||||||
batch_size: Number of samples per evaluation step.
|
batch_size: Number of samples per evaluation step.
|
||||||
|
desired_precisions: If specified, adds a RecallAtPrecision metric per
|
||||||
|
desired_precisions[i] entry which tracks the recall given the constraint
|
||||||
|
on precision. Only supported for binary classification.
|
||||||
|
desired_recalls: If specified, adds a PrecisionAtRecall metric per
|
||||||
|
desired_recalls[i] entry which tracks the precision given the constraint
|
||||||
|
on recall. Only supported for binary classification.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The loss value and accuracy.
|
The loss value and accuracy.
|
||||||
|
@ -144,6 +165,28 @@ class TextClassifier(classifier.Classifier):
|
||||||
|
|
||||||
processed_data = self._text_preprocessor.preprocess(data)
|
processed_data = self._text_preprocessor.preprocess(data)
|
||||||
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
|
||||||
|
|
||||||
|
additional_metrics = []
|
||||||
|
if desired_precisions and len(data.label_names) == 2:
|
||||||
|
for precision in desired_precisions:
|
||||||
|
additional_metrics.append(
|
||||||
|
metrics.BinarySparseRecallAtPrecision(
|
||||||
|
precision, name=f"recall_at_precision_{precision}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if desired_recalls and len(data.label_names) == 2:
|
||||||
|
for recall in desired_recalls:
|
||||||
|
additional_metrics.append(
|
||||||
|
metrics.BinarySparsePrecisionAtRecall(
|
||||||
|
recall, name=f"precision_at_recall_{recall}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metric_functions = self._metric_functions + additional_metrics
|
||||||
|
self._model.compile(
|
||||||
|
optimizer=self._optimizer,
|
||||||
|
loss=self._loss_function,
|
||||||
|
metrics=metric_functions,
|
||||||
|
)
|
||||||
return self._model.evaluate(dataset)
|
return self._model.evaluate(dataset)
|
||||||
|
|
||||||
def export_model(
|
def export_model(
|
||||||
|
@ -161,9 +204,8 @@ class TextClassifier(classifier.Classifier):
|
||||||
path is {self._hparams.export_dir}/{model_name}.
|
path is {self._hparams.export_dir}/{model_name}.
|
||||||
quantization_config: The configuration for model quantization.
|
quantization_config: The configuration for model quantization.
|
||||||
"""
|
"""
|
||||||
if not tf.io.gfile.exists(self._hparams.export_dir):
|
|
||||||
tf.io.gfile.makedirs(self._hparams.export_dir)
|
|
||||||
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
tflite_file = os.path.join(self._hparams.export_dir, model_name)
|
||||||
|
tf.io.gfile.makedirs(os.path.dirname(tflite_file))
|
||||||
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
|
||||||
|
|
||||||
tflite_model = model_util.convert_to_tflite(
|
tflite_model = model_util.convert_to_tflite(
|
||||||
|
@ -174,7 +216,7 @@ class TextClassifier(classifier.Classifier):
|
||||||
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
|
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
|
||||||
tflite_model_with_metadata, metadata_json = writer.populate()
|
tflite_model_with_metadata, metadata_json = writer.populate()
|
||||||
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
|
||||||
with open(metadata_file, "w") as f:
|
with tf.io.gfile.GFile(metadata_file, "w") as f:
|
||||||
f.write(metadata_json)
|
f.write(metadata_json)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -191,13 +233,23 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
|
||||||
|
|
||||||
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
_DELIM_REGEX_PATTERN = r"[^\w\']+"
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
def __init__(
|
||||||
model_options: mo.AverageWordEmbeddingModelOptions,
|
self,
|
||||||
hparams: hp.BaseHParams, label_names: Sequence[str]):
|
model_spec: ms.AverageWordEmbeddingClassifierSpec,
|
||||||
super().__init__(model_spec, hparams, label_names)
|
model_options: mo.AverageWordEmbeddingModelOptions,
|
||||||
|
hparams: hp.AverageWordEmbeddingHParams,
|
||||||
|
label_names: Sequence[str],
|
||||||
|
):
|
||||||
|
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
|
self._hparams = hparams
|
||||||
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
self._loss_function = "sparse_categorical_crossentropy"
|
self._loss_function = "sparse_categorical_crossentropy"
|
||||||
self._metric_function = "accuracy"
|
self._metric_functions = [
|
||||||
|
"accuracy",
|
||||||
|
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||||
|
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||||
|
]
|
||||||
self._text_preprocessor: (
|
self._text_preprocessor: (
|
||||||
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
|
||||||
|
|
||||||
|
@ -306,16 +358,26 @@ class _BertClassifier(TextClassifier):
|
||||||
|
|
||||||
_INITIALIZER_RANGE = 0.02
|
_INITIALIZER_RANGE = 0.02
|
||||||
|
|
||||||
def __init__(self, model_spec: ms.BertClassifierSpec,
|
def __init__(
|
||||||
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
|
self,
|
||||||
label_names: Sequence[str]):
|
model_spec: ms.BertClassifierSpec,
|
||||||
super().__init__(model_spec, hparams, label_names)
|
model_options: mo.BertModelOptions,
|
||||||
|
hparams: hp.BertHParams,
|
||||||
|
label_names: Sequence[str],
|
||||||
|
):
|
||||||
|
super().__init__(model_spec, label_names, hparams.shuffle)
|
||||||
|
self._hparams = hparams
|
||||||
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
with self._hparams.get_strategy().scope():
|
with self._hparams.get_strategy().scope():
|
||||||
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||||
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
|
self._metric_functions = [
|
||||||
"test_accuracy", dtype=tf.float32
|
tf.keras.metrics.SparseCategoricalAccuracy(
|
||||||
)
|
"test_accuracy", dtype=tf.float32
|
||||||
|
),
|
||||||
|
metrics.SparsePrecision(name="precision", dtype=tf.float32),
|
||||||
|
metrics.SparseRecall(name="recall", dtype=tf.float32),
|
||||||
|
]
|
||||||
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -438,11 +500,26 @@ class _BertClassifier(TextClassifier):
|
||||||
initial_learning_rate=initial_lr,
|
initial_learning_rate=initial_lr,
|
||||||
decay_schedule_fn=lr_schedule,
|
decay_schedule_fn=lr_schedule,
|
||||||
warmup_steps=warmup_steps)
|
warmup_steps=warmup_steps)
|
||||||
|
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
|
||||||
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
self._optimizer = tf.keras.optimizers.experimental.AdamW(
|
||||||
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
|
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
|
||||||
self._optimizer.exclude_from_weight_decay(
|
)
|
||||||
var_names=["LayerNorm", "layer_norm", "bias"])
|
self._optimizer.exclude_from_weight_decay(
|
||||||
|
var_names=["LayerNorm", "layer_norm", "bias"]
|
||||||
|
)
|
||||||
|
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
|
||||||
|
self._optimizer = tfa_optimizers.LAMB(
|
||||||
|
lr_schedule,
|
||||||
|
weight_decay_rate=0.01,
|
||||||
|
epsilon=1e-6,
|
||||||
|
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
|
||||||
|
global_clipnorm=1.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"BertHParams.optimizer must be set to ADAM or "
|
||||||
|
f"LAMB. Got {self._hparams.optimizer}."
|
||||||
|
)
|
||||||
|
|
||||||
def _save_vocab(self, vocab_filepath: str):
|
def _save_vocab(self, vocab_filepath: str):
|
||||||
tf.io.gfile.copy(
|
tf.io.gfile.copy(
|
||||||
|
|
|
@ -66,14 +66,16 @@ def run(data_dir,
|
||||||
quantization_config = None
|
quantization_config = None
|
||||||
if (supported_model ==
|
if (supported_model ==
|
||||||
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
|
||||||
hparams = text_classifier.HParams(
|
hparams = text_classifier.AverageWordEmbeddingHParams(
|
||||||
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
|
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
|
||||||
|
)
|
||||||
# Warning: This takes extremely long to run on CPU
|
# Warning: This takes extremely long to run on CPU
|
||||||
elif (
|
elif (
|
||||||
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
|
||||||
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
quantization_config = quantization.QuantizationConfig.for_dynamic()
|
||||||
hparams = text_classifier.HParams(
|
hparams = text_classifier.BertHParams(
|
||||||
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
|
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
|
||||||
|
)
|
||||||
|
|
||||||
# Fine-tunes the model.
|
# Fine-tunes the model.
|
||||||
options = text_classifier.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
|
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
|
||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
|
||||||
|
|
||||||
|
@ -34,5 +34,5 @@ class TextClassifierOptions:
|
||||||
architecture of the `supported_model`.
|
architecture of the `supported_model`.
|
||||||
"""
|
"""
|
||||||
supported_model: ms.SupportedModels
|
supported_model: ms.SupportedModels
|
||||||
hparams: Optional[hp.BaseHParams] = None
|
hparams: Optional[hp.HParams] = None
|
||||||
model_options: Optional[mo.TextClassifierModelOptions] = None
|
model_options: Optional[mo.TextClassifierModelOptions] = None
|
||||||
|
|
|
@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
|
|
||||||
def test_create_and_train_average_word_embedding_model(self):
|
def test_create_and_train_average_word_embedding_model(self):
|
||||||
train_data, validation_data = self._get_data()
|
train_data, validation_data = self._get_data()
|
||||||
options = (
|
options = text_classifier.TextClassifierOptions(
|
||||||
text_classifier.TextClassifierOptions(
|
supported_model=(
|
||||||
supported_model=(text_classifier.SupportedModels
|
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
|
||||||
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
|
),
|
||||||
hparams=text_classifier.HParams(
|
hparams=text_classifier.AverageWordEmbeddingHParams(
|
||||||
epochs=1, batch_size=1, learning_rate=0)))
|
epochs=1, batch_size=1, learning_rate=0
|
||||||
|
),
|
||||||
|
)
|
||||||
average_word_embedding_classifier = (
|
average_word_embedding_classifier = (
|
||||||
text_classifier.TextClassifier.create(train_data, validation_data,
|
text_classifier.TextClassifier.create(train_data, validation_data,
|
||||||
options))
|
options))
|
||||||
|
@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
|
||||||
options = text_classifier.TextClassifierOptions(
|
options = text_classifier.TextClassifierOptions(
|
||||||
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
|
||||||
model_options=text_classifier.BertModelOptions(
|
model_options=text_classifier.BertModelOptions(
|
||||||
do_fine_tuning=False, seq_len=2),
|
do_fine_tuning=False, seq_len=2
|
||||||
hparams=text_classifier.HParams(
|
),
|
||||||
|
hparams=text_classifier.BertHParams(
|
||||||
epochs=1,
|
epochs=1,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
learning_rate=3e-5,
|
learning_rate=3e-5,
|
||||||
distribution_strategy='off'))
|
distribution_strategy='off',
|
||||||
|
),
|
||||||
|
)
|
||||||
bert_classifier = text_classifier.TextClassifier.create(
|
bert_classifier = text_classifier.TextClassifier.create(
|
||||||
train_data, validation_data, options)
|
train_data, validation_data, options)
|
||||||
|
|
||||||
|
|
|
@ -20,13 +20,6 @@ licenses(["notice"])
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe:__subpackages__"])
|
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "testdata",
|
|
||||||
srcs = glob([
|
|
||||||
"testdata/**",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "constants",
|
name = "constants",
|
||||||
srcs = ["constants.py"],
|
srcs = ["constants.py"],
|
||||||
|
@ -72,18 +65,11 @@ py_library(
|
||||||
name = "dataset",
|
name = "dataset",
|
||||||
srcs = ["dataset.py"],
|
srcs = ["dataset.py"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":constants",
|
||||||
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
"//mediapipe/model_maker/python/vision/core:image_utils",
|
"//mediapipe/python:_framework_bindings",
|
||||||
],
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
)
|
"//mediapipe/tasks/python/vision:face_aligner",
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "dataset_test",
|
|
||||||
srcs = ["dataset_test.py"],
|
|
||||||
data = [":testdata"],
|
|
||||||
deps = [
|
|
||||||
":dataset",
|
|
||||||
"//mediapipe/tasks/python/test:test_utils",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
|
||||||
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
|
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
|
||||||
|
'face_stylizer/face_landmarker_v2.task',
|
||||||
|
'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
|
||||||
|
is_folder=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Dimension of the input style vector to the decoder
|
# Dimension of the input style vector to the decoder
|
||||||
STYLE_DIM = 512
|
STYLE_DIM = 512
|
||||||
|
|
|
@ -13,13 +13,37 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Face stylizer dataset library."""
|
"""Face stylizer dataset library."""
|
||||||
|
|
||||||
|
from typing import Sequence
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from mediapipe.model_maker.python.core.data import classification_dataset
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
from mediapipe.model_maker.python.vision.core import image_utils
|
from mediapipe.model_maker.python.vision.face_stylizer import constants
|
||||||
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.vision import face_aligner
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_face_dataset(
|
||||||
|
all_image_paths: Sequence[str],
|
||||||
|
) -> Sequence[tf.Tensor]:
|
||||||
|
"""Preprocess face image dataset by aligning the face."""
|
||||||
|
path = constants.FACE_ALIGNER_TASK_FILES.get_path()
|
||||||
|
base_options = base_options_module.BaseOptions(model_asset_path=path)
|
||||||
|
options = face_aligner.FaceAlignerOptions(base_options=base_options)
|
||||||
|
aligner = face_aligner.FaceAligner.create_from_options(options)
|
||||||
|
|
||||||
|
preprocessed_images = []
|
||||||
|
for path in all_image_paths:
|
||||||
|
tf.compat.v1.logging.info('Preprocess image %s', path)
|
||||||
|
image = image_module.Image.create_from_file(path)
|
||||||
|
aligned_image = aligner.align(image)
|
||||||
|
aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
|
||||||
|
preprocessed_images.append(aligned_image_tensor)
|
||||||
|
|
||||||
|
return preprocessed_images
|
||||||
|
|
||||||
|
|
||||||
# TODO: Change to a unlabeled dataset if it makes sense.
|
# TODO: Change to a unlabeled dataset if it makes sense.
|
||||||
|
@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
):
|
):
|
||||||
raise ValueError('No images found under given directory')
|
raise ValueError('No images found under given directory')
|
||||||
|
|
||||||
|
image_data = _preprocess_face_dataset(all_image_paths)
|
||||||
label_names = sorted(
|
label_names = sorted(
|
||||||
name
|
name
|
||||||
for name in os.listdir(data_root)
|
for name in os.listdir(data_root)
|
||||||
|
@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
|
||||||
for path in all_image_paths
|
for path in all_image_paths
|
||||||
]
|
]
|
||||||
|
|
||||||
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
|
||||||
|
|
||||||
image_ds = path_ds.map(
|
|
||||||
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load label
|
# Load label
|
||||||
label_ds = tf.data.Dataset.from_tensor_slices(
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
|
|
@ -12,8 +12,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self._test_data_dirname = 'input/style'
|
|
||||||
|
|
||||||
def test_from_folder(self):
|
def test_from_folder(self):
|
||||||
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
|
test_data_dirname = 'input/style'
|
||||||
|
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
|
||||||
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
self.assertEqual(data.num_classes, 2)
|
self.assertEqual(data.num_classes, 2)
|
||||||
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
"""APIs to train face stylization model."""
|
"""APIs to train face stylization model."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -54,7 +54,6 @@ class FaceStylizer(object):
|
||||||
self._model_spec = model_spec
|
self._model_spec = model_spec
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
# TODO: Support face alignment in image preprocessor.
|
|
||||||
self._preprocessor = image_preprocessing.Preprocessor(
|
self._preprocessor = image_preprocessing.Preprocessor(
|
||||||
input_shape=self._model_spec.input_image_shape,
|
input_shape=self._model_spec.input_image_shape,
|
||||||
num_classes=1,
|
num_classes=1,
|
||||||
|
@ -128,7 +127,7 @@ class FaceStylizer(object):
|
||||||
def _train_model(
|
def _train_model(
|
||||||
self,
|
self,
|
||||||
train_data: classification_ds.ClassificationDataset,
|
train_data: classification_ds.ClassificationDataset,
|
||||||
preprocessor: Optional[Callable[..., bool]] = None,
|
preprocessor: Optional[Callable[..., Any]] = None,
|
||||||
):
|
):
|
||||||
"""Trains the face stylizer model.
|
"""Trains the face stylizer model.
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier):
|
||||||
self._model_options = model_options
|
self._model_options = model_options
|
||||||
self._hparams = hparams
|
self._hparams = hparams
|
||||||
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
|
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
|
||||||
self._metric_function = 'categorical_accuracy'
|
self._metric_functions = ['categorical_accuracy']
|
||||||
self._optimizer = 'adam'
|
self._optimizer = 'adam'
|
||||||
self._callbacks = self._get_callbacks()
|
self._callbacks = self._get_callbacks()
|
||||||
self._history = None
|
self._history = None
|
||||||
|
|
|
@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier):
|
||||||
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
|
||||||
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
|
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
|
||||||
label_smoothing=self._hparams.label_smoothing)
|
label_smoothing=self._hparams.label_smoothing)
|
||||||
self._metric_function = 'accuracy'
|
self._metric_functions = ['accuracy']
|
||||||
self._history = None # Training history returned from `keras_model.fit`.
|
self._history = None # Training history returned from `keras_model.fit`.
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
)
|
)
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
def _build_model(self) -> tf.keras.Model:
|
def _build_model(self, omit_l2=False) -> tf.keras.Model:
|
||||||
"""Builds a RetinaNet object detector model."""
|
"""Builds a RetinaNet object detector model."""
|
||||||
input_specs = tf.keras.layers.InputSpec(
|
input_specs = tf.keras.layers.InputSpec(
|
||||||
shape=[None] + self._model_spec.input_image_shape
|
shape=[None] + self._model_spec.input_image_shape
|
||||||
)
|
)
|
||||||
l2_regularizer = tf.keras.regularizers.l2(
|
if omit_l2:
|
||||||
self._model_options.l2_weight_decay / 2.0
|
l2_regularizer = None
|
||||||
)
|
else:
|
||||||
|
l2_regularizer = tf.keras.regularizers.l2(
|
||||||
|
self._model_options.l2_weight_decay / 2.0
|
||||||
|
)
|
||||||
model_config = self._get_model_config()
|
model_config = self._get_model_config()
|
||||||
|
|
||||||
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
|
||||||
|
@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
|
||||||
|
|
||||||
def convert_to_qat(self) -> None:
|
def convert_to_qat(self) -> None:
|
||||||
"""Converts the model to a QAT RetinaNet model."""
|
"""Converts the model to a QAT RetinaNet model."""
|
||||||
model = self._build_model()
|
model = self._build_model(omit_l2=True)
|
||||||
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
|
||||||
model(dummy_input, training=True)
|
model(dummy_input, training=True)
|
||||||
model.set_weights(self._model.get_weights())
|
model.set_weights(self._model.get_weights())
|
||||||
|
|
|
@ -43,6 +43,7 @@ cc_library(
|
||||||
":base_audio_task_api",
|
":base_audio_task_api",
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -60,13 +61,8 @@ class AudioTaskApiFactory {
|
||||||
"Task graph config should only contain one task subgraph node.",
|
"Task graph config should only contain one task subgraph node.",
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
} else {
|
} else {
|
||||||
if (!node.options().HasExtension(Options::ext)) {
|
MP_RETURN_IF_ERROR(
|
||||||
return CreateStatusWithPayload(
|
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat(node.calculator(),
|
|
||||||
" is missing the required task options field."),
|
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
|
||||||
}
|
|
||||||
found_task_subgraph = true;
|
found_task_subgraph = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
|
"@com_google_absl//absl/log",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
|
|
@ -17,15 +17,56 @@ limitations under the License.
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include "absl/log/log.h"
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace core {
|
namespace core {
|
||||||
|
|
||||||
|
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
|
||||||
|
const BaseOptions::CpuOptions& options) {
|
||||||
|
proto::Acceleration acceleration_proto = proto::Acceleration();
|
||||||
|
acceleration_proto.mutable_tflite();
|
||||||
|
return acceleration_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
|
||||||
|
const BaseOptions::GpuOptions& options) {
|
||||||
|
proto::Acceleration acceleration_proto = proto::Acceleration();
|
||||||
|
auto* gpu = acceleration_proto.mutable_gpu();
|
||||||
|
gpu->set_use_advanced_gpu_api(true);
|
||||||
|
gpu->set_cached_kernel_path(options.cached_kernel_path);
|
||||||
|
gpu->set_serialized_model_dir(options.serialized_model_dir);
|
||||||
|
gpu->set_model_token(options.model_token);
|
||||||
|
return acceleration_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void SetDelegateOptionsOrDie(const BaseOptions* base_options,
|
||||||
|
proto::BaseOptions& base_options_proto) {
|
||||||
|
if (base_options->delegate_options.has_value()) {
|
||||||
|
if (!std::holds_alternative<T>(*base_options->delegate_options)) {
|
||||||
|
LOG(FATAL) << "Specified Delegate type does not match the provided "
|
||||||
|
"delegate options.";
|
||||||
|
} else {
|
||||||
|
std::visit(
|
||||||
|
[&base_options_proto](const auto& delegate_options) {
|
||||||
|
proto::Acceleration acceleration_proto =
|
||||||
|
ConvertDelegateOptionsToAccelerationProto(delegate_options);
|
||||||
|
base_options_proto.mutable_acceleration()->Swap(
|
||||||
|
&acceleration_proto);
|
||||||
|
},
|
||||||
|
*base_options->delegate_options);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
|
proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
|
||||||
proto::BaseOptions base_options_proto;
|
proto::BaseOptions base_options_proto;
|
||||||
if (!base_options->model_asset_path.empty()) {
|
if (!base_options->model_asset_path.empty()) {
|
||||||
|
@ -53,11 +94,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
|
||||||
switch (base_options->delegate) {
|
switch (base_options->delegate) {
|
||||||
case BaseOptions::Delegate::CPU:
|
case BaseOptions::Delegate::CPU:
|
||||||
base_options_proto.mutable_acceleration()->mutable_tflite();
|
base_options_proto.mutable_acceleration()->mutable_tflite();
|
||||||
|
SetDelegateOptionsOrDie<BaseOptions::CpuOptions>(base_options,
|
||||||
|
base_options_proto);
|
||||||
break;
|
break;
|
||||||
case BaseOptions::Delegate::GPU:
|
case BaseOptions::Delegate::GPU:
|
||||||
base_options_proto.mutable_acceleration()
|
base_options_proto.mutable_acceleration()
|
||||||
->mutable_gpu()
|
->mutable_gpu()
|
||||||
->set_use_advanced_gpu_api(true);
|
->set_use_advanced_gpu_api(true);
|
||||||
|
SetDelegateOptionsOrDie<BaseOptions::GpuOptions>(base_options,
|
||||||
|
base_options_proto);
|
||||||
break;
|
break;
|
||||||
case BaseOptions::Delegate::EDGETPU_NNAPI:
|
case BaseOptions::Delegate::EDGETPU_NNAPI:
|
||||||
base_options_proto.mutable_acceleration()
|
base_options_proto.mutable_acceleration()
|
||||||
|
@ -65,7 +110,6 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
|
||||||
->set_accelerator_name("google-edgetpu");
|
->set_accelerator_name("google-edgetpu");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return base_options_proto;
|
return base_options_proto;
|
||||||
}
|
}
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -17,7 +17,9 @@ limitations under the License.
|
||||||
#define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_
|
#define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||||
|
@ -38,7 +40,8 @@ struct BaseOptions {
|
||||||
std::string model_asset_path = "";
|
std::string model_asset_path = "";
|
||||||
|
|
||||||
// The delegate to run MediaPipe. If the delegate is not set, the default
|
// The delegate to run MediaPipe. If the delegate is not set, the default
|
||||||
// delegate CPU is used.
|
// delegate CPU is used. Use `delegate_options` to configure advanced
|
||||||
|
// features of the selected delegate."
|
||||||
enum Delegate {
|
enum Delegate {
|
||||||
CPU = 0,
|
CPU = 0,
|
||||||
GPU = 1,
|
GPU = 1,
|
||||||
|
@ -48,6 +51,30 @@ struct BaseOptions {
|
||||||
|
|
||||||
Delegate delegate = CPU;
|
Delegate delegate = CPU;
|
||||||
|
|
||||||
|
// Options for CPU.
|
||||||
|
struct CpuOptions {};
|
||||||
|
|
||||||
|
// Options for GPU.
|
||||||
|
struct GpuOptions {
|
||||||
|
// Load pre-compiled serialized binary cache to accelerate init process.
|
||||||
|
// Only available on Android. Kernel caching will only be enabled if this
|
||||||
|
// path is set. NOTE: binary cache usage may be skipped if valid serialized
|
||||||
|
// model, specified by "serialized_model_dir", exists.
|
||||||
|
std::string cached_kernel_path;
|
||||||
|
|
||||||
|
// A dir to load from and save to a pre-compiled serialized model used to
|
||||||
|
// accelerate init process.
|
||||||
|
// NOTE: serialized model takes precedence over binary cache
|
||||||
|
// specified by "cached_kernel_path", which still can be used if
|
||||||
|
// serialized model is invalid or missing.
|
||||||
|
std::string serialized_model_dir;
|
||||||
|
|
||||||
|
// Unique token identifying the model. Used in conjunction with
|
||||||
|
// "serialized_model_dir". It is the caller's responsibility to ensure
|
||||||
|
// there is no clash of the tokens.
|
||||||
|
std::string model_token;
|
||||||
|
};
|
||||||
|
|
||||||
// The file descriptor to a file opened with open(2), with optional additional
|
// The file descriptor to a file opened with open(2), with optional additional
|
||||||
// offset and length information.
|
// offset and length information.
|
||||||
struct FileDescriptorMeta {
|
struct FileDescriptorMeta {
|
||||||
|
@ -67,6 +94,10 @@ struct BaseOptions {
|
||||||
// built-in Ops.
|
// built-in Ops.
|
||||||
std::unique_ptr<tflite::OpResolver> op_resolver =
|
std::unique_ptr<tflite::OpResolver> op_resolver =
|
||||||
absl::make_unique<MediaPipeBuiltinOpResolver>();
|
absl::make_unique<MediaPipeBuiltinOpResolver>();
|
||||||
|
|
||||||
|
// Options for the chosen delegate. If not set, the default delegate options
|
||||||
|
// is used.
|
||||||
|
std::optional<std::variant<CpuOptions, GpuOptions>> delegate_options;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts a BaseOptions to a BaseOptionsProto.
|
// Converts a BaseOptions to a BaseOptionsProto.
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
@ -11,6 +14,8 @@
|
||||||
|
|
||||||
constexpr char kTestModelBundlePath[] =
|
constexpr char kTestModelBundlePath[] =
|
||||||
"mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
|
"mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
|
||||||
|
constexpr char kCachedModelDir[] = "/data/local/tmp";
|
||||||
|
constexpr char kModelToken[] = "dummy_model_token";
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
|
@ -40,6 +45,44 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) {
|
||||||
EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu");
|
EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DelegateOptionsTest, SucceedCpuOptions) {
|
||||||
|
BaseOptions base_options;
|
||||||
|
base_options.delegate = BaseOptions::Delegate::CPU;
|
||||||
|
BaseOptions::CpuOptions cpu_options;
|
||||||
|
base_options.delegate_options = cpu_options;
|
||||||
|
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
|
||||||
|
EXPECT_TRUE(proto.acceleration().has_tflite());
|
||||||
|
ASSERT_FALSE(proto.acceleration().has_gpu());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DelegateOptionsTest, SucceedGpuOptions) {
|
||||||
|
BaseOptions base_options;
|
||||||
|
base_options.delegate = BaseOptions::Delegate::GPU;
|
||||||
|
BaseOptions::GpuOptions gpu_options;
|
||||||
|
gpu_options.cached_kernel_path = kCachedModelDir;
|
||||||
|
gpu_options.model_token = kModelToken;
|
||||||
|
base_options.delegate_options = gpu_options;
|
||||||
|
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
|
||||||
|
ASSERT_TRUE(proto.acceleration().has_gpu());
|
||||||
|
ASSERT_FALSE(proto.acceleration().has_tflite());
|
||||||
|
EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
|
||||||
|
EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir);
|
||||||
|
EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) {
|
||||||
|
BaseOptions base_options;
|
||||||
|
base_options.delegate = BaseOptions::Delegate::CPU;
|
||||||
|
BaseOptions::GpuOptions gpu_options;
|
||||||
|
gpu_options.cached_kernel_path = kCachedModelDir;
|
||||||
|
gpu_options.model_token = kModelToken;
|
||||||
|
base_options.delegate_options = gpu_options;
|
||||||
|
ASSERT_DEATH(
|
||||||
|
{ proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); },
|
||||||
|
"Specified Delegate type does not match the provided "
|
||||||
|
"delegate options.");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
|
|
|
@ -81,7 +81,6 @@ class TaskApiFactory {
|
||||||
return std::make_unique<T>(std::move(runner));
|
return std::make_unique<T>(std::move(runner));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename Options>
|
template <typename Options>
|
||||||
static absl::Status CheckHasValidOptions(
|
static absl::Status CheckHasValidOptions(
|
||||||
const CalculatorGraphConfig::Node& node) {
|
const CalculatorGraphConfig::Node& node) {
|
||||||
|
|
|
@ -86,10 +86,9 @@ cc_test(
|
||||||
"//mediapipe/tasks/cc/components/containers:classification_result",
|
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||||
"@com_google_absl//absl/flags:flag",
|
"@com_google_absl//absl/flags:flag",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:cord",
|
"@com_google_absl//absl/strings:cord",
|
||||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
"@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep
|
||||||
"@org_tensorflow//tensorflow/lite:test_util",
|
"@org_tensorflow//tensorflow/lite:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -15,8 +15,6 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -24,7 +22,6 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/flags/flag.h"
|
#include "absl/flags/flag.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/cord.h"
|
#include "absl/strings/cord.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
|
|
@ -45,7 +45,7 @@ constexpr char kUniversalSentenceEncoderModel[] =
|
||||||
// Tolerance for embedding vector coordinate values.
|
// Tolerance for embedding vector coordinate values.
|
||||||
constexpr float kEpsilon = 1e-4;
|
constexpr float kEpsilon = 1e-4;
|
||||||
// Tolerancy for cosine similarity evaluation.
|
// Tolerancy for cosine similarity evaluation.
|
||||||
constexpr double kSimilarityTolerancy = 1e-6;
|
constexpr double kSimilarityTolerancy = 2e-2;
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
|
@ -79,6 +79,8 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
|
||||||
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
|
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
|
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
|
||||||
|
#elif defined(__FMA__)
|
||||||
|
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.3605f, kEpsilon);
|
||||||
#else
|
#else
|
||||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
|
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
@ -87,7 +89,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
|
||||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||||
ASSERT_EQ(result1.embeddings.size(), 1);
|
ASSERT_EQ(result1.embeddings.size(), 1);
|
||||||
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
|
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
|
||||||
|
#ifdef __FMA__
|
||||||
|
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 21.254150f, kEpsilon);
|
||||||
|
#else
|
||||||
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
|
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
|
||||||
|
#endif
|
||||||
|
|
||||||
// Check cosine similarity.
|
// Check cosine similarity.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
|
|
@ -43,6 +43,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:rect",
|
"//mediapipe/tasks/cc/components/containers:rect",
|
||||||
"//mediapipe/tasks/cc/core:base_task_api",
|
"//mediapipe/tasks/cc/core:base_task_api",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
|
@ -58,6 +59,7 @@ cc_library(
|
||||||
":base_vision_task_api",
|
":base_vision_task_api",
|
||||||
"//mediapipe/calculators/core:flow_limiter_calculator",
|
"//mediapipe/calculators/core:flow_limiter_calculator",
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/core:task_api_factory",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/framework/calculator.pb.h"
|
#include "mediapipe/framework/calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
|
||||||
|
@ -60,13 +61,8 @@ class VisionTaskApiFactory {
|
||||||
"Task graph config should only contain one task subgraph node.",
|
"Task graph config should only contain one task subgraph node.",
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
||||||
} else {
|
} else {
|
||||||
if (!node.options().HasExtension(Options::ext)) {
|
MP_RETURN_IF_ERROR(
|
||||||
return CreateStatusWithPayload(
|
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
|
||||||
absl::StatusCode::kInvalidArgument,
|
|
||||||
absl::StrCat(node.calculator(),
|
|
||||||
" is missing the required task options field."),
|
|
||||||
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
|
|
||||||
}
|
|
||||||
found_task_subgraph = true;
|
found_task_subgraph = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -153,6 +153,8 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: open source hand joints graph
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "hand_landmarker_result",
|
name = "hand_landmarker_result",
|
||||||
srcs = ["hand_landmarker_result.cc"],
|
srcs = ["hand_landmarker_result.cc"],
|
||||||
|
|
|
@ -41,3 +41,5 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: open source hand joints graph
|
||||||
|
|
|
@ -52,6 +52,7 @@ cc_library(
|
||||||
name = "interactive_segmenter_graph",
|
name = "interactive_segmenter_graph",
|
||||||
srcs = ["interactive_segmenter_graph.cc"],
|
srcs = ["interactive_segmenter_graph.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/calculators/image:image_transformation_calculator",
|
||||||
"//mediapipe/calculators/image:set_alpha_calculator",
|
"//mediapipe/calculators/image:set_alpha_calculator",
|
||||||
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
"//mediapipe/calculators/util:annotation_overlay_calculator",
|
||||||
"//mediapipe/calculators/util:flat_color_image_calculator",
|
"//mediapipe/calculators/util:flat_color_image_calculator",
|
||||||
|
@ -60,6 +61,7 @@ cc_library(
|
||||||
"//mediapipe/calculators/util:to_image_calculator",
|
"//mediapipe/calculators/util:to_image_calculator",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
|
|
@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
@ -35,6 +37,51 @@ namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace interactive_segmenter {
|
namespace interactive_segmenter {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// A calculator to add thickness to the render data according to the image size,
|
||||||
|
// so that the render data is scale invariant to the image size. If the render
|
||||||
|
// data already has thickness, it will be kept as is.
|
||||||
|
class AddThicknessToRenderDataCalculator : public api2::Node {
|
||||||
|
public:
|
||||||
|
static constexpr api2::Input<Image> kImageIn{"IMAGE"};
|
||||||
|
static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{
|
||||||
|
"RENDER_DATA"};
|
||||||
|
static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{
|
||||||
|
"RENDER_DATA"};
|
||||||
|
|
||||||
|
static constexpr int kModelInputTensorWidth = 512;
|
||||||
|
static constexpr int kModelInputTensorHeight = 512;
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
mediapipe::RenderData render_data = kRenderDataIn(cc).Get();
|
||||||
|
Image image = kImageIn(cc).Get();
|
||||||
|
double thickness = std::max(
|
||||||
|
std::max(image.width() / static_cast<double>(kModelInputTensorWidth),
|
||||||
|
image.height() / static_cast<double>(kModelInputTensorHeight)),
|
||||||
|
1.0);
|
||||||
|
|
||||||
|
for (auto& annotation : *render_data.mutable_render_annotations()) {
|
||||||
|
if (!annotation.has_thickness()) {
|
||||||
|
annotation.set_thickness(thickness);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
kRenderDataOut(cc).Send(render_data);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
|
||||||
|
// moved to next line.
|
||||||
|
// clang-format off
|
||||||
|
MEDIAPIPE_REGISTER_NODE(
|
||||||
|
::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator);
|
||||||
|
// clang-format on
|
||||||
|
// NOLINTEND
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -59,6 +106,7 @@ constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
|
||||||
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
|
||||||
constexpr absl::string_view kRoiTag{"ROI"};
|
constexpr absl::string_view kRoiTag{"ROI"};
|
||||||
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
|
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
|
||||||
|
constexpr absl::string_view kRenderDataTag{"RENDER_DATA"};
|
||||||
|
|
||||||
// Updates the graph to return `roi` stream which has same dimension as
|
// Updates the graph to return `roi` stream which has same dimension as
|
||||||
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
|
||||||
|
@ -69,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||||
const absl::string_view image_tag_with_suffix =
|
const absl::string_view image_tag_with_suffix =
|
||||||
use_gpu ? kImageGpuTag : kImageCpuTag;
|
use_gpu ? kImageGpuTag : kImageCpuTag;
|
||||||
|
|
||||||
|
// Adds thickness to the render data so that the render data is scale
|
||||||
|
// invariant to the input image size.
|
||||||
|
auto& add_thickness = graph.AddNode(
|
||||||
|
"mediapipe::tasks::vision::interactive_segmenter::internal::"
|
||||||
|
"AddThicknessToRenderDataCalculator");
|
||||||
|
image >> add_thickness.In(kImageTag);
|
||||||
|
roi >> add_thickness.In(kRenderDataTag);
|
||||||
|
auto roi_with_thickness = add_thickness.Out(kRenderDataTag);
|
||||||
|
|
||||||
// Generates a blank canvas with same size as input image.
|
// Generates a blank canvas with same size as input image.
|
||||||
auto& flat_color = graph.AddNode("FlatColorImageCalculator");
|
auto& flat_color = graph.AddNode("FlatColorImageCalculator");
|
||||||
auto& flat_color_options =
|
auto& flat_color_options =
|
||||||
flat_color.GetOptions<FlatColorImageCalculatorOptions>();
|
flat_color.GetOptions<FlatColorImageCalculatorOptions>();
|
||||||
// SetAlphaCalculator only takes 1st channel.
|
// SetAlphaCalculator only takes 1st channel.
|
||||||
flat_color_options.mutable_color()->set_r(0);
|
flat_color_options.mutable_color()->set_r(0);
|
||||||
image >> flat_color.In(kImageTag)[0];
|
image >> flat_color.In(kImageTag);
|
||||||
auto blank_canvas = flat_color.Out(kImageTag)[0];
|
auto blank_canvas = flat_color.Out(kImageTag);
|
||||||
|
|
||||||
auto& from_mp_image = graph.AddNode("FromImageCalculator");
|
auto& from_mp_image = graph.AddNode("FromImageCalculator");
|
||||||
blank_canvas >> from_mp_image.In(kImageTag);
|
blank_canvas >> from_mp_image.In(kImageTag);
|
||||||
|
@ -85,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
|
||||||
auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
|
auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
|
||||||
blank_canvas_in_cpu_or_gpu >>
|
blank_canvas_in_cpu_or_gpu >>
|
||||||
roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
|
||||||
roi >> roi_to_alpha.In(0);
|
roi_with_thickness >> roi_to_alpha.In(0);
|
||||||
auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
|
||||||
|
|
||||||
return alpha;
|
return alpha;
|
||||||
|
@ -163,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
|
||||||
image >> from_mp_image.In(kImageTag);
|
image >> from_mp_image.In(kImageTag);
|
||||||
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
|
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
|
||||||
|
|
||||||
|
// Creates an RGBA image with model input tensor size.
|
||||||
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
|
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
|
||||||
|
|
||||||
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
|
auto& set_alpha = graph.AddNode("SetAlphaCalculator");
|
||||||
|
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/tool/test_util.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
|
@ -70,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
|
||||||
// Golden mask for the dogs in cats_and_dogs.jpg.
|
// Golden mask for the dogs in cats_and_dogs.jpg.
|
||||||
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
|
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
|
||||||
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
|
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
|
||||||
|
constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"};
|
||||||
|
constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"};
|
||||||
|
constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"};
|
||||||
|
constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"};
|
||||||
|
|
||||||
constexpr float kGoldenMaskSimilarity = 0.97;
|
constexpr float kGoldenMaskSimilarity = 0.97;
|
||||||
|
|
||||||
|
@ -183,6 +188,7 @@ struct InteractiveSegmenterTestParams {
|
||||||
std::string test_name;
|
std::string test_name;
|
||||||
RegionOfInterest::Format format;
|
RegionOfInterest::Format format;
|
||||||
std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
|
std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
|
||||||
|
absl::string_view input_image_file;
|
||||||
absl::string_view golden_mask_file;
|
absl::string_view golden_mask_file;
|
||||||
float similarity_threshold;
|
float similarity_threshold;
|
||||||
};
|
};
|
||||||
|
@ -220,8 +226,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
||||||
const InteractiveSegmenterTestParams& params = GetParam();
|
const InteractiveSegmenterTestParams& params = GetParam();
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
Image image,
|
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
params.input_image_file)));
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
|
@ -244,6 +250,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
|
||||||
EXPECT_THAT(actual_mask,
|
EXPECT_THAT(actual_mask,
|
||||||
SimilarToUint8Mask(expected_mask, params.similarity_threshold,
|
SimilarToUint8Mask(expected_mask, params.similarity_threshold,
|
||||||
kGoldenMaskMagnificationFactor));
|
kGoldenMaskMagnificationFactor));
|
||||||
|
|
||||||
|
cv::Mat visualized_mask;
|
||||||
|
actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
|
||||||
|
ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
|
||||||
|
visualized_mask.cols, visualized_mask.rows,
|
||||||
|
visualized_mask.step, visualized_mask.data,
|
||||||
|
[visualized_mask](uint8_t[]) {});
|
||||||
|
MP_EXPECT_OK(SavePngTestOutput(
|
||||||
|
visualized_image, absl::StrFormat("%s_category_mask", params.test_name)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
|
@ -252,8 +267,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
const InteractiveSegmenterTestParams& params = GetParam();
|
const InteractiveSegmenterTestParams& params = GetParam();
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
Image image,
|
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
|
params.input_image_file)));
|
||||||
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
auto options = std::make_unique<InteractiveSegmenterOptions>();
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kPtmModel);
|
JoinPath("./", kTestDataDirectory, kPtmModel);
|
||||||
|
@ -275,6 +290,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
|
||||||
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
|
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
|
||||||
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
|
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
|
||||||
params.similarity_threshold));
|
params.similarity_threshold));
|
||||||
|
cv::Mat visualized_mask;
|
||||||
|
actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
|
||||||
|
ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
|
||||||
|
visualized_mask.cols, visualized_mask.rows,
|
||||||
|
visualized_mask.step, visualized_mask.data,
|
||||||
|
[visualized_mask](uint8_t[]) {});
|
||||||
|
MP_EXPECT_OK(SavePngTestOutput(
|
||||||
|
visualized_image,
|
||||||
|
absl::StrFormat("%s_confidence_mask", params.test_name)));
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
@ -282,21 +306,28 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
::testing::ValuesIn<InteractiveSegmenterTestParams>(
|
||||||
{// Keypoint input.
|
{// Keypoint input.
|
||||||
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
|
||||||
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
|
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1,
|
||||||
|
0.84f},
|
||||||
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
|
||||||
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
|
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2,
|
||||||
kGoldenMaskSimilarity},
|
kGoldenMaskSimilarity},
|
||||||
|
{"PenguinsSmall", RegionOfInterest::Format::kKeyPoint,
|
||||||
|
NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask,
|
||||||
|
0.9f},
|
||||||
|
{"PenguinsLarge", RegionOfInterest::Format::kKeyPoint,
|
||||||
|
NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask,
|
||||||
|
0.9f},
|
||||||
// Scribble input.
|
// Scribble input.
|
||||||
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
|
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
|
||||||
std::vector{NormalizedKeypoint{0.44, 0.70},
|
std::vector{NormalizedKeypoint{0.44, 0.70},
|
||||||
NormalizedKeypoint{0.44, 0.71},
|
NormalizedKeypoint{0.44, 0.71},
|
||||||
NormalizedKeypoint{0.44, 0.72}},
|
NormalizedKeypoint{0.44, 0.72}},
|
||||||
kCatsAndDogsMaskDog1, 0.84f},
|
kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f},
|
||||||
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
|
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
|
||||||
std::vector{NormalizedKeypoint{0.66, 0.66},
|
std::vector{NormalizedKeypoint{0.66, 0.66},
|
||||||
NormalizedKeypoint{0.66, 0.67},
|
NormalizedKeypoint{0.66, 0.67},
|
||||||
NormalizedKeypoint{0.66, 0.68}},
|
NormalizedKeypoint{0.66, 0.68}},
|
||||||
kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
|
kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
|
||||||
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
|
||||||
info) { return info.param.test_name; });
|
info) { return info.param.test_name; });
|
||||||
|
|
||||||
|
|
|
@ -60,3 +60,9 @@ objc_library(
|
||||||
srcs = ["sources/MPPLandmark.m"],
|
srcs = ["sources/MPPLandmark.m"],
|
||||||
hdrs = ["sources/MPPLandmark.h"],
|
hdrs = ["sources/MPPLandmark.h"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPConnection",
|
||||||
|
srcs = ["sources/MPPConnection.m"],
|
||||||
|
hdrs = ["sources/MPPConnection.h"],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
/** The value class representing a landmark connection. */
|
||||||
|
NS_SWIFT_NAME(Connection)
|
||||||
|
@interface MPPConnection : NSObject
|
||||||
|
|
||||||
|
@property(nonatomic, readonly) NSUInteger start;
|
||||||
|
|
||||||
|
@property(nonatomic, readonly) NSUInteger end;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes a new `MPPConnection` with the start and end landmarks integer constants.
|
||||||
|
*
|
||||||
|
* @param start The integer representing the starting landmark of the connection.
|
||||||
|
* @param end The integer representing the ending landmark of the connection.
|
||||||
|
*
|
||||||
|
* @return An instance of `MPPConnection` initialized with the given start and end landmarks integer
|
||||||
|
* constants.
|
||||||
|
*/
|
||||||
|
- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end NS_DESIGNATED_INITIALIZER;
|
||||||
|
|
||||||
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
|
||||||
|
|
||||||
|
@implementation MPPConnection
|
||||||
|
|
||||||
|
- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end {
|
||||||
|
self = [super init];
|
||||||
|
if (self) {
|
||||||
|
_start = start;
|
||||||
|
_end = end;
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -54,3 +54,20 @@ ios_unit_test(
|
||||||
":MPPImageObjcTestLibrary",
|
":MPPImageObjcTestLibrary",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPMaskObjcTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["MPPMaskTests.m"],
|
||||||
|
deps = ["//mediapipe/tasks/ios/vision/core:MPPMask"],
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPMaskObjcTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||||
|
deps = [
|
||||||
|
":MPPMaskObjcTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
127
mediapipe/tasks/ios/test/vision/core/MPPMaskTests.m
Normal file
127
mediapipe/tasks/ios/test/vision/core/MPPMaskTests.m
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
|
||||||
|
|
||||||
|
#import <XCTest/XCTest.h>
|
||||||
|
|
||||||
|
/** Unit tests for `MPPMask`. */
|
||||||
|
@interface MPPMaskTests : XCTestCase
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPMaskTests
|
||||||
|
|
||||||
|
#pragma mark - Tests
|
||||||
|
|
||||||
|
- (void)testInitWithUInt8ArrayNoCopySucceeds {
|
||||||
|
|
||||||
|
NSInteger width = 2;
|
||||||
|
NSInteger height = 3;
|
||||||
|
|
||||||
|
UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
|
||||||
|
float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};
|
||||||
|
|
||||||
|
MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:NO];
|
||||||
|
|
||||||
|
XCTAssertEqual(mask.width, width);
|
||||||
|
XCTAssertEqual(mask.height, height);
|
||||||
|
|
||||||
|
// Test if UInt8 mask is not copied.
|
||||||
|
XCTAssertEqual(mask.uint8Data, (const UInt8*)uint8Data);
|
||||||
|
XCTAssertNotEqual(mask.float32Data, NULL);
|
||||||
|
|
||||||
|
for (int i = 0 ; i < width * height ; i ++) {
|
||||||
|
XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f, @"index i = %d", i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if repeated Float32 mask accesses return the same array in memory.
|
||||||
|
XCTAssertEqual(mask.float32Data, mask.float32Data);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testInitWithUInt8ArrayCopySucceeds {
|
||||||
|
|
||||||
|
NSInteger width = 2;
|
||||||
|
NSInteger height = 3;
|
||||||
|
|
||||||
|
UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
|
||||||
|
float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};
|
||||||
|
|
||||||
|
MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:YES];
|
||||||
|
|
||||||
|
XCTAssertEqual(mask.width, width);
|
||||||
|
XCTAssertEqual(mask.height, height);
|
||||||
|
|
||||||
|
// Test if UInt8 mask is copied.
|
||||||
|
XCTAssertNotEqual(mask.uint8Data, (const UInt8*)uint8Data);
|
||||||
|
XCTAssertNotEqual(mask.float32Data, NULL);
|
||||||
|
|
||||||
|
for (int i = 0 ; i < width * height ; i ++) {
|
||||||
|
XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if repeated Float32 mask accesses return the same array in memory.
|
||||||
|
XCTAssertEqual(mask.float32Data, mask.float32Data);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testInitWithFloat32ArrayNoCopySucceeds {
|
||||||
|
|
||||||
|
NSInteger width = 2;
|
||||||
|
NSInteger height = 3;
|
||||||
|
|
||||||
|
UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
|
||||||
|
float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
|
||||||
|
MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:NO];
|
||||||
|
|
||||||
|
XCTAssertEqual(mask.width, width);
|
||||||
|
XCTAssertEqual(mask.height, height);
|
||||||
|
|
||||||
|
// Test if Float32 mask is not copied.
|
||||||
|
XCTAssertEqual(mask.float32Data, (const float*)float32Data);
|
||||||
|
XCTAssertNotEqual(mask.uint8Data, NULL);
|
||||||
|
|
||||||
|
for (int i = 0 ; i < width * height ; i ++) {
|
||||||
|
XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if repeated UInt8 mask accesses return the same array in memory.
|
||||||
|
XCTAssertEqual(mask.uint8Data, mask.uint8Data);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testInitWithFloat32ArrayCopySucceeds {
|
||||||
|
|
||||||
|
NSInteger width = 2;
|
||||||
|
NSInteger height = 3;
|
||||||
|
|
||||||
|
UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
|
||||||
|
float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
|
||||||
|
|
||||||
|
MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:YES];
|
||||||
|
|
||||||
|
XCTAssertEqual(mask.width, width);
|
||||||
|
XCTAssertEqual(mask.height, height);
|
||||||
|
|
||||||
|
// Test if Float32 mask is copied.
|
||||||
|
XCTAssertNotEqual(mask.float32Data, (const float*)float32Data);
|
||||||
|
XCTAssertNotEqual(mask.uint8Data, NULL);
|
||||||
|
|
||||||
|
for (int i = 0 ; i < width * height ; i ++) {
|
||||||
|
XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if repeated UInt8 mask accesses return the same array in memory.
|
||||||
|
XCTAssertEqual(mask.uint8Data, mask.uint8Data);
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -155,12 +155,12 @@ static const float kKeypointErrorThreshold = 1e-2;
|
||||||
NSInteger iterationCount = 100;
|
NSInteger iterationCount = 100;
|
||||||
|
|
||||||
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
|
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
|
||||||
// normal expectation will fail if expectation.fullfill() is not called
|
// normal expectation will fail if expectation.fulfill() is not called
|
||||||
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
||||||
// only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
// Since it is not possible to predict how many times the expectation is supposed to be
|
// Since it is not possible to predict how many times the expectation is supposed to be
|
||||||
// fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
// fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
// `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
|
// `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
|
||||||
// `iterationCount` times.
|
// `iterationCount` times.
|
||||||
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||||
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
||||||
|
@ -385,13 +385,13 @@ static const float kKeypointErrorThreshold = 1e-2;
|
||||||
NSInteger iterationCount = 100;
|
NSInteger iterationCount = 100;
|
||||||
|
|
||||||
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
|
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
|
||||||
// normal expectation will fail if expectation.fullfill() is not called times. An normal
|
// normal expectation will fail if expectation.fulfill() is not called times. An normal
|
||||||
// expectation will fail if expectation.fullfill() is not called
|
// expectation will fail if expectation.fulfill() is not called
|
||||||
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
||||||
// only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
// Since it it not possible to determine how many times the expectation is supposed to be
|
// Since it it not possible to determine how many times the expectation is supposed to be
|
||||||
// fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
// fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
// `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
|
// `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
|
||||||
// `iterationCount` times.
|
// `iterationCount` times.
|
||||||
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||||
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
||||||
|
|
|
@ -174,12 +174,12 @@ constexpr float kFacialTransformationMatrixErrorThreshold = 0.2f;
|
||||||
NSInteger iterationCount = 100;
|
NSInteger iterationCount = 100;
|
||||||
|
|
||||||
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
|
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
|
||||||
// normal expectation will fail if expectation.fullfill() is not called
|
// normal expectation will fail if expectation.fulfill() is not called
|
||||||
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
||||||
// only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
|
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
// Since it is not possible to predict how many times the expectation is supposed to be
|
// Since it is not possible to predict how many times the expectation is supposed to be
|
||||||
// fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
// fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
// `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
|
// `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
|
||||||
// `iterationCount` times.
|
// `iterationCount` times.
|
||||||
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||||
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
|
||||||
|
|
62
mediapipe/tasks/ios/test/vision/gesture_recognizer/BUILD
Normal file
62
mediapipe/tasks/ios/test/vision/gesture_recognizer/BUILD
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||||
|
load(
|
||||||
|
"//mediapipe/framework/tool:ios.bzl",
|
||||||
|
"MPP_TASK_MINIMUM_OS_VERSION",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||||
|
"tflite_ios_lab_runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||||
|
TFL_DEFAULT_TAGS = [
|
||||||
|
"apple",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Following sanitizer tests are not supported by iOS test targets.
|
||||||
|
TFL_DISABLED_SANITIZER_TAGS = [
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
]
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPGestureRecognizerObjcTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["MPPGestureRecognizerTests.m"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
"-x objective-c++",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/vision:gesture_recognizer.task",
|
||||||
|
"//mediapipe/tasks/testdata/vision:test_images",
|
||||||
|
"//mediapipe/tasks/testdata/vision:test_protos",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/gesture_recognizer/utils:MPPGestureRecognizerResultProtobufHelpers",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
|
||||||
|
"//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer",
|
||||||
|
] + select({
|
||||||
|
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
|
||||||
|
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
|
||||||
|
"//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
|
||||||
|
"//conditions:default": ["@ios_opencv//:OpencvFramework"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPGestureRecognizerObjcTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||||
|
deps = [
|
||||||
|
":MPPGestureRecognizerObjcTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,706 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <XCTest/XCTest.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizer.h"
|
||||||
|
|
||||||
|
static NSString *const kPbFileExtension = @"pbtxt";
|
||||||
|
|
||||||
|
typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;
|
||||||
|
|
||||||
|
static ResourceFileInfo *const kGestureRecognizerBundleAssetFile =
|
||||||
|
@{@"name" : @"gesture_recognizer", @"type" : @"task"};
|
||||||
|
|
||||||
|
static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kFistImage = @{@"name" : @"fist", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kPointingUpRotatedImage =
|
||||||
|
@{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};
|
||||||
|
|
||||||
|
static ResourceFileInfo *const kExpectedFistLandmarksFile =
|
||||||
|
@{@"name" : @"fist_landmarks", @"type" : kPbFileExtension};
|
||||||
|
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
|
||||||
|
@{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
|
||||||
|
|
||||||
|
static NSString *const kFistLabel = @"Closed_Fist";
|
||||||
|
static NSString *const kExpectedThumbUpLabel = @"Thumb_Up";
|
||||||
|
static NSString *const kExpectedPointingUpLabel = @"Pointing_Up";
|
||||||
|
static NSString *const kRockLabel = @"Rock";
|
||||||
|
|
||||||
|
static const NSInteger kGestureExpectedIndex = -1;
|
||||||
|
|
||||||
|
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
|
|
||||||
|
static NSString *const kLiveStreamTestsDictGestureRecognizerKey = @"gesture_recognizer";
|
||||||
|
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
|
|
||||||
|
#define AssertEqualErrors(error, expectedError) \
|
||||||
|
XCTAssertNotNil(error); \
|
||||||
|
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||||
|
XCTAssertEqual(error.code, expectedError.code); \
|
||||||
|
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||||
|
|
||||||
|
#define AssertEqualGestures(gesture, expectedGesture, handIndex, gestureIndex) \
|
||||||
|
XCTAssertEqual(gesture.index, kGestureExpectedIndex, @"hand index = %d gesture index j = %d", \
|
||||||
|
handIndex, gestureIndex); \
|
||||||
|
XCTAssertEqualObjects(gesture.categoryName, expectedGesture.categoryName, \
|
||||||
|
@"hand index = %d gesture index j = %d", handIndex, gestureIndex);
|
||||||
|
|
||||||
|
#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex) \
|
||||||
|
XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance, \
|
||||||
|
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
|
||||||
|
XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance, \
|
||||||
|
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
|
||||||
|
|
||||||
|
#define AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult) \
|
||||||
|
XCTAssertTrue(gestureRecognizerResult.gestures.count == 0); \
|
||||||
|
XCTAssertTrue(gestureRecognizerResult.handedness.count == 0); \
|
||||||
|
XCTAssertTrue(gestureRecognizerResult.landmarks.count == 0); \
|
||||||
|
XCTAssertTrue(gestureRecognizerResult.worldLandmarks.count == 0);
|
||||||
|
|
||||||
|
@interface MPPGestureRecognizerTests : XCTestCase <MPPGestureRecognizerLiveStreamDelegate> {
|
||||||
|
NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
|
||||||
|
NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
|
||||||
|
}
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPGestureRecognizerTests
|
||||||
|
|
||||||
|
#pragma mark Expected Results
|
||||||
|
|
||||||
|
+ (MPPGestureRecognizerResult *)emptyGestureRecognizerResult {
|
||||||
|
return [[MPPGestureRecognizerResult alloc] initWithGestures:@[]
|
||||||
|
handedness:@[]
|
||||||
|
landmarks:@[]
|
||||||
|
worldLandmarks:@[]
|
||||||
|
timestampInMilliseconds:0];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (MPPGestureRecognizerResult *)thumbUpGestureRecognizerResult {
|
||||||
|
NSString *filePath =
|
||||||
|
[MPPGestureRecognizerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
|
||||||
|
|
||||||
|
return [MPPGestureRecognizerResult
|
||||||
|
gestureRecognizerResultsFromProtobufFileWithName:filePath
|
||||||
|
gestureLabel:kExpectedThumbUpLabel
|
||||||
|
shouldRemoveZPosition:YES];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (MPPGestureRecognizerResult *)fistGestureRecognizerResultWithLabel:(NSString *)gestureLabel {
|
||||||
|
NSString *filePath = [MPPGestureRecognizerTests filePathWithFileInfo:kExpectedFistLandmarksFile];
|
||||||
|
|
||||||
|
return [MPPGestureRecognizerResult gestureRecognizerResultsFromProtobufFileWithName:filePath
|
||||||
|
gestureLabel:gestureLabel
|
||||||
|
shouldRemoveZPosition:YES];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Assert Gesture Recognizer Results
|
||||||
|
|
||||||
|
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandLandmarks:
|
||||||
|
(NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
|
||||||
|
XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
|
||||||
|
if (multiHandLandmarks.count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
|
||||||
|
NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];
|
||||||
|
|
||||||
|
XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
|
||||||
|
for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
|
||||||
|
MPPNormalizedLandmark *landmark = topHandLandmarks[i];
|
||||||
|
XCTAssertNotNil(landmark);
|
||||||
|
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
|
||||||
|
(NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
|
||||||
|
XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
|
||||||
|
if (expectedMultiHandWorldLandmarks.count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
|
||||||
|
NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];
|
||||||
|
|
||||||
|
XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
|
||||||
|
for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
|
||||||
|
MPPLandmark *landmark = topHandWorldLandmarks[i];
|
||||||
|
XCTAssertNotNil(landmark);
|
||||||
|
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertMultiHandGestures:(NSArray<NSArray<MPPCategory *> *> *)multiHandGestures
|
||||||
|
areApproximatelyEqualToExpectedMultiHandGestures:
|
||||||
|
(NSArray<NSArray<MPPCategory *> *> *)expectedMultiHandGestures {
|
||||||
|
XCTAssertEqual(multiHandGestures.count, expectedMultiHandGestures.count);
|
||||||
|
if (multiHandGestures.count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
NSArray<MPPCategory *> *topHandGestures = multiHandGestures[0];
|
||||||
|
NSArray<MPPCategory *> *expectedTopHandGestures = expectedMultiHandGestures[0];
|
||||||
|
|
||||||
|
XCTAssertEqual(topHandGestures.count, expectedTopHandGestures.count);
|
||||||
|
for (int i = 0; i < expectedTopHandGestures.count; i++) {
|
||||||
|
MPPCategory *gesture = topHandGestures[i];
|
||||||
|
XCTAssertNotNil(gesture);
|
||||||
|
AssertEqualGestures(gesture, expectedTopHandGestures[i], 0, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertGestureRecognizerResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:
|
||||||
|
(MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
|
||||||
|
[self assertMultiHandLandmarks:gestureRecognizerResult.landmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandLandmarks:expectedGestureRecognizerResult.landmarks];
|
||||||
|
[self assertMultiHandWorldLandmarks:gestureRecognizerResult.worldLandmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedGestureRecognizerResult
|
||||||
|
.worldLandmarks];
|
||||||
|
[self assertMultiHandGestures:gestureRecognizerResult.gestures
|
||||||
|
areApproximatelyEqualToExpectedMultiHandGestures:expectedGestureRecognizerResult.gestures];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertResultsOfRecognizeImageWithFileInfo:(ResourceFileInfo *)fileInfo
|
||||||
|
usingGestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
|
||||||
|
approximatelyEqualsGestureRecognizerResult:
|
||||||
|
(MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult =
|
||||||
|
[self recognizeImageWithFileInfo:fileInfo usingGestureRecognizer:gestureRecognizer];
|
||||||
|
[self assertGestureRecognizerResult:gestureRecognizerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:expectedGestureRecognizerResult];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark File
|
||||||
|
|
||||||
|
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
|
||||||
|
NSString *filePath = [MPPGestureRecognizerTests filePathWithName:fileInfo[@"name"]
|
||||||
|
extension:fileInfo[@"type"]];
|
||||||
|
return filePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
|
||||||
|
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
|
||||||
|
ofType:extension];
|
||||||
|
return filePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Gesture Recognizer Initializers
|
||||||
|
|
||||||
|
- (MPPGestureRecognizerOptions *)gestureRecognizerOptionsWithModelFileInfo:
|
||||||
|
(ResourceFileInfo *)modelFileInfo {
|
||||||
|
NSString *modelPath = [MPPGestureRecognizerTests filePathWithFileInfo:modelFileInfo];
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[[MPPGestureRecognizerOptions alloc] init];
|
||||||
|
gestureRecognizerOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
|
return gestureRecognizerOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPGestureRecognizer *)createGestureRecognizerWithOptionsSucceeds:
|
||||||
|
(MPPGestureRecognizerOptions *)gestureRecognizerOptions {
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:nil];
|
||||||
|
XCTAssertNotNil(gestureRecognizer);
|
||||||
|
|
||||||
|
return gestureRecognizer;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertCreateGestureRecognizerWithOptions:
|
||||||
|
(MPPGestureRecognizerOptions *)gestureRecognizerOptions
|
||||||
|
failsWithExpectedError:(NSError *)expectedError {
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:&error];
|
||||||
|
|
||||||
|
XCTAssertNil(gestureRecognizer);
|
||||||
|
AssertEqualErrors(error, expectedError);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Recognize Helpers
|
||||||
|
|
||||||
|
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
|
||||||
|
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
|
||||||
|
fileName:fileInfo[@"name"]
|
||||||
|
ofType:fileInfo[@"type"]];
|
||||||
|
XCTAssertNotNil(image);
|
||||||
|
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
|
||||||
|
orientation:(UIImageOrientation)orientation {
|
||||||
|
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
|
||||||
|
fileName:fileInfo[@"name"]
|
||||||
|
ofType:fileInfo[@"type"]
|
||||||
|
orientation:orientation];
|
||||||
|
XCTAssertNotNil(image);
|
||||||
|
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPGestureRecognizerResult *)recognizeImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
|
||||||
|
usingGestureRecognizer:
|
||||||
|
(MPPGestureRecognizer *)gestureRecognizer {
|
||||||
|
MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
|
||||||
|
error:nil];
|
||||||
|
XCTAssertNotNil(gestureRecognizerResult);
|
||||||
|
|
||||||
|
return gestureRecognizerResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark General Tests
|
||||||
|
|
||||||
|
- (void)testRecognizeWithModelPathSucceeds {
|
||||||
|
NSString *modelPath =
|
||||||
|
[MPPGestureRecognizerTests filePathWithFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[[MPPGestureRecognizer alloc] initWithModelPath:modelPath error:nil];
|
||||||
|
XCTAssertNotNil(gestureRecognizer);
|
||||||
|
|
||||||
|
[self assertResultsOfRecognizeImageWithFileInfo:kThumbUpImage
|
||||||
|
usingGestureRecognizer:gestureRecognizer
|
||||||
|
approximatelyEqualsGestureRecognizerResult:[MPPGestureRecognizerTests
|
||||||
|
thumbUpGestureRecognizerResult]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithEmptyResultsSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult =
|
||||||
|
[self recognizeImageWithFileInfo:kNoHandsImage usingGestureRecognizer:gestureRecognizer];
|
||||||
|
AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithScoreThresholdSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult =
|
||||||
|
[self recognizeImageWithFileInfo:kThumbUpImage usingGestureRecognizer:gestureRecognizer];
|
||||||
|
|
||||||
|
MPPGestureRecognizerResult *expectedGestureRecognizerResult =
|
||||||
|
[MPPGestureRecognizerTests thumbUpGestureRecognizerResult];
|
||||||
|
|
||||||
|
XCTAssertTrue(gestureRecognizerResult.gestures.count == 1);
|
||||||
|
AssertEqualGestures(gestureRecognizerResult.gestures[0][0],
|
||||||
|
expectedGestureRecognizerResult.gestures[0][0], 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithNumHandsSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
const NSInteger numHands = 2;
|
||||||
|
gestureRecognizerOptions.numHands = numHands;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult =
|
||||||
|
[self recognizeImageWithFileInfo:kTwoHandsImage usingGestureRecognizer:gestureRecognizer];
|
||||||
|
|
||||||
|
XCTAssertTrue(gestureRecognizerResult.handedness.count == numHands);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithRotationSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
gestureRecognizerOptions.numHands = 1;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
|
||||||
|
orientation:UIImageOrientationRight];
|
||||||
|
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
|
||||||
|
error:nil];
|
||||||
|
|
||||||
|
XCTAssertNotNil(gestureRecognizerResult);
|
||||||
|
|
||||||
|
XCTAssertEqual(gestureRecognizerResult.gestures.count, 1);
|
||||||
|
XCTAssertEqualObjects(gestureRecognizerResult.gestures[0][0].categoryName,
|
||||||
|
kExpectedPointingUpLabel);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithCannedGestureFistSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
gestureRecognizerOptions.numHands = 1;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
|
||||||
|
[self assertResultsOfRecognizeImageWithFileInfo:kFistImage
|
||||||
|
usingGestureRecognizer:gestureRecognizer
|
||||||
|
approximatelyEqualsGestureRecognizerResult:
|
||||||
|
[MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithAllowGestureFistSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
|
||||||
|
|
||||||
|
gestureRecognizerOptions.numHands = 1;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
|
||||||
|
[self assertResultsOfRecognizeImageWithFileInfo:kFistImage
|
||||||
|
usingGestureRecognizer:gestureRecognizer
|
||||||
|
approximatelyEqualsGestureRecognizerResult:
|
||||||
|
[MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithDenyGestureFistSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];
|
||||||
|
|
||||||
|
gestureRecognizerOptions.numHands = 1;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult =
|
||||||
|
[self recognizeImageWithFileInfo:kFistImage usingGestureRecognizer:gestureRecognizer];
|
||||||
|
AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithPreferAllowlistOverDenylistSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *gestureRecognizerOptions =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
|
||||||
|
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];
|
||||||
|
|
||||||
|
gestureRecognizerOptions.numHands = 1;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
|
||||||
|
|
||||||
|
[self assertResultsOfRecognizeImageWithFileInfo:kFistImage
|
||||||
|
usingGestureRecognizer:gestureRecognizer
|
||||||
|
approximatelyEqualsGestureRecognizerResult:
|
||||||
|
[MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Running Mode Tests
|
||||||
|
|
||||||
|
- (void)testCreateGestureRecognizerFailsWithDelegateInNonLiveStreamMode {
|
||||||
|
MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
|
||||||
|
for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
options.runningMode = runningModesToTest[i];
|
||||||
|
options.gestureRecognizerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
[self assertCreateGestureRecognizerWithOptions:options
|
||||||
|
failsWithExpectedError:
|
||||||
|
[NSError
|
||||||
|
errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"The vision task is in image or video mode. The "
|
||||||
|
@"delegate must not be set in the task's options."
|
||||||
|
}]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testCreateGestureRecognizerFailsWithMissingDelegateInLiveStreamMode {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
|
||||||
|
[self
|
||||||
|
assertCreateGestureRecognizerWithOptions:options
|
||||||
|
failsWithExpectedError:
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"The vision task is in live stream mode. An "
|
||||||
|
@"object must be set as the delegate of the task "
|
||||||
|
@"in its options to ensure asynchronous delivery "
|
||||||
|
@"of results."
|
||||||
|
}]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeFailsWithCallingWrongApiInImageMode {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kFistImage];
|
||||||
|
|
||||||
|
NSError *liveStreamApiCallError;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&liveStreamApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedLiveStreamApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
|
||||||
|
@"stream mode. Current Running Mode: Image"
|
||||||
|
}];
|
||||||
|
|
||||||
|
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
||||||
|
|
||||||
|
NSError *videoApiCallError;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&videoApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedVideoApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"video mode. Current Running Mode: Image"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeFailsWithCallingWrongApiInVideoMode {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeVideo;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kFistImage];
|
||||||
|
|
||||||
|
NSError *liveStreamApiCallError;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&liveStreamApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedLiveStreamApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
|
||||||
|
@"stream mode. Current Running Mode: Video"
|
||||||
|
}];
|
||||||
|
|
||||||
|
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
||||||
|
|
||||||
|
NSError *imageApiCallError;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedImageApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"image mode. Current Running Mode: Video"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeFailsWithCallingWrongApiInLiveStreamMode {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
options.gestureRecognizerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kFistImage];
|
||||||
|
|
||||||
|
NSError *imageApiCallError;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedImageApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"image mode. Current Running Mode: Live Stream"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
||||||
|
|
||||||
|
NSError *videoApiCallError;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&videoApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedVideoApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"video mode. Current Running Mode: Live Stream"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithVideoModeSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeVideo;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
MPPGestureRecognizerResult *gestureRecognizerResult =
|
||||||
|
[gestureRecognizer recognizeVideoFrame:image timestampInMilliseconds:i error:nil];
|
||||||
|
[self assertGestureRecognizerResult:gestureRecognizerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
|
||||||
|
thumbUpGestureRecognizerResult]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithOutOfOrderTimestampsAndLiveStreamModeFails {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
options.gestureRecognizerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||||
|
initWithDescription:@"recognizeWithOutOfOrderTimestampsAndLiveStream"];
|
||||||
|
|
||||||
|
expectation.expectedFulfillmentCount = 1;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
_outOfOrderTimestampTestDict = @{
|
||||||
|
kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
|
||||||
|
kLiveStreamTestsDictExpectationKey : expectation
|
||||||
|
};
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image timestampInMilliseconds:1 error:nil]);
|
||||||
|
|
||||||
|
NSError *error;
|
||||||
|
XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&error]);
|
||||||
|
|
||||||
|
NSError *expectedError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(error, expectedError);
|
||||||
|
|
||||||
|
NSTimeInterval timeout = 0.5f;
|
||||||
|
[self waitForExpectations:@[ expectation ] timeout:timeout];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testRecognizeWithLiveStreamModeSucceeds {
|
||||||
|
MPPGestureRecognizerOptions *options =
|
||||||
|
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
options.gestureRecognizerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
NSInteger iterationCount = 100;
|
||||||
|
|
||||||
|
// Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
|
||||||
|
// times. An normal expectation will fail if expectation.fulfill() is not called
|
||||||
|
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
||||||
|
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
|
// Since in our case we cannot predict how many times the expectation is supposed to be fulfilled
|
||||||
|
// setting, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
|
// `expectation.isInverted = true` ensures that test succeeds ifexpectation is fulfilled <=
|
||||||
|
// `iterationCount` times.
|
||||||
|
XCTestExpectation *expectation =
|
||||||
|
[[XCTestExpectation alloc] initWithDescription:@"recognizeWithLiveStream"];
|
||||||
|
|
||||||
|
expectation.expectedFulfillmentCount = iterationCount + 1;
|
||||||
|
expectation.inverted = YES;
|
||||||
|
|
||||||
|
MPPGestureRecognizer *gestureRecognizer =
|
||||||
|
[self createGestureRecognizerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
_liveStreamSucceedsTestDict = @{
|
||||||
|
kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
|
||||||
|
kLiveStreamTestsDictExpectationKey : expectation
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
|
||||||
|
// with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
|
||||||
|
// `CMSampleBuffer`.
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
for (int i = 0; i < iterationCount; i++) {
|
||||||
|
XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image
|
||||||
|
timestampInMilliseconds:i
|
||||||
|
error:nil]);
|
||||||
|
}
|
||||||
|
|
||||||
|
NSTimeInterval timeout = 0.5f;
|
||||||
|
[self waitForExpectations:@[ expectation ] timeout:timeout];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)gestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
|
||||||
|
didFinishRecognitionWithResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
|
||||||
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
|
error:(NSError *)error {
|
||||||
|
[self assertGestureRecognizerResult:gestureRecognizerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
|
||||||
|
thumbUpGestureRecognizerResult]];
|
||||||
|
|
||||||
|
if (gestureRecognizer == _outOfOrderTimestampTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
|
||||||
|
[_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
|
||||||
|
} else if (gestureRecognizer ==
|
||||||
|
_liveStreamSucceedsTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
|
||||||
|
[_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -0,0 +1,22 @@
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPGestureRecognizerResultProtobufHelpers",
|
||||||
|
srcs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.mm"],
|
||||||
|
hdrs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
"-x objective-c++",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
|
||||||
|
"//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizerResult",
|
||||||
|
"//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
@interface MPPGestureRecognizerResult (ProtobufHelpers)
|
||||||
|
|
||||||
|
+ (MPPGestureRecognizerResult *)
|
||||||
|
gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
|
||||||
|
gestureLabel:(NSString *)gestureLabel
|
||||||
|
shouldRemoveZPosition:(BOOL)removeZPosition;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,65 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+Helpers.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
|
||||||
|
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ClassificationListProto = ::mediapipe::ClassificationList;
|
||||||
|
using ClassificationProto = ::mediapipe::Classification;
|
||||||
|
using LandmarksDetectionResultProto =
|
||||||
|
::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
|
||||||
|
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
@implementation MPPGestureRecognizerResult (ProtobufHelpers)
|
||||||
|
|
||||||
|
+ (MPPGestureRecognizerResult *)
|
||||||
|
gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
|
||||||
|
gestureLabel:(NSString *)gestureLabel
|
||||||
|
shouldRemoveZPosition:(BOOL)removeZPosition {
|
||||||
|
LandmarksDetectionResultProto landmarkDetectionResultProto;
|
||||||
|
|
||||||
|
if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (removeZPosition) {
|
||||||
|
// Remove z position of landmarks, because they are not used in correctness testing. For video
|
||||||
|
// or live stream mode, the z positions varies a lot during tracking from frame to frame.
|
||||||
|
for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
|
||||||
|
auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
|
||||||
|
landmark.clear_z();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ClassificationListProto gesturesProto;
|
||||||
|
ClassificationProto *classificationProto = gesturesProto.add_classification();
|
||||||
|
classificationProto->set_label([gestureLabel UTF8String]);
|
||||||
|
|
||||||
|
return [MPPGestureRecognizerResult
|
||||||
|
gestureRecognizerResultWithHandGesturesProto:{gesturesProto}
|
||||||
|
handednessProto:{landmarkDetectionResultProto.classifications()}
|
||||||
|
handLandmarksProto:{landmarkDetectionResultProto.landmarks()}
|
||||||
|
worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
|
||||||
|
timestampInMilliseconds:0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
62
mediapipe/tasks/ios/test/vision/hand_landmarker/BUILD
Normal file
62
mediapipe/tasks/ios/test/vision/hand_landmarker/BUILD
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
|
||||||
|
load(
|
||||||
|
"//mediapipe/framework/tool:ios.bzl",
|
||||||
|
"MPP_TASK_MINIMUM_OS_VERSION",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
|
||||||
|
"tflite_ios_lab_runner",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
|
||||||
|
TFL_DEFAULT_TAGS = [
|
||||||
|
"apple",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Following sanitizer tests are not supported by iOS test targets.
|
||||||
|
TFL_DISABLED_SANITIZER_TAGS = [
|
||||||
|
"noasan",
|
||||||
|
"nomsan",
|
||||||
|
"notsan",
|
||||||
|
]
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPHandLandmarkerObjcTestLibrary",
|
||||||
|
testonly = 1,
|
||||||
|
srcs = ["MPPHandLandmarkerTests.m"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
"-x objective-c++",
|
||||||
|
],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/vision:hand_landmarker.task",
|
||||||
|
"//mediapipe/tasks/testdata/vision:test_images",
|
||||||
|
"//mediapipe/tasks/testdata/vision:test_protos",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/hand_landmarker/utils:MPPHandLandmarkerResultProtobufHelpers",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
|
||||||
|
"//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker",
|
||||||
|
] + select({
|
||||||
|
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
|
||||||
|
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
|
||||||
|
"//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
|
||||||
|
"//conditions:default": ["@ios_opencv//:OpencvFramework"],
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
ios_unit_test(
|
||||||
|
name = "MPPHandLandmarkerObjcTest",
|
||||||
|
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
|
||||||
|
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||||
|
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
|
||||||
|
deps = [
|
||||||
|
":MPPHandLandmarkerObjcTestLibrary",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,557 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <XCTest/XCTest.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarker.h"
|
||||||
|
|
||||||
|
static NSString *const kPbFileExtension = @"pbtxt";
|
||||||
|
|
||||||
|
typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;
|
||||||
|
|
||||||
|
static ResourceFileInfo *const kHandLandmarkerBundleAssetFile =
|
||||||
|
@{@"name" : @"hand_landmarker", @"type" : @"task"};
|
||||||
|
|
||||||
|
static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
|
||||||
|
static ResourceFileInfo *const kPointingUpRotatedImage =
|
||||||
|
@{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};
|
||||||
|
|
||||||
|
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
|
||||||
|
@{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
|
||||||
|
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
|
||||||
|
@{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};
|
||||||
|
|
||||||
|
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
|
|
||||||
|
static NSString *const kLiveStreamTestsDictHandLandmarkerKey = @"gesture_recognizer";
|
||||||
|
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
|
|
||||||
|
#define AssertEqualErrors(error, expectedError) \
|
||||||
|
XCTAssertNotNil(error); \
|
||||||
|
XCTAssertEqualObjects(error.domain, expectedError.domain); \
|
||||||
|
XCTAssertEqual(error.code, expectedError.code); \
|
||||||
|
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
|
||||||
|
|
||||||
|
#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex) \
|
||||||
|
XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance, \
|
||||||
|
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
|
||||||
|
XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance, \
|
||||||
|
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
|
||||||
|
|
||||||
|
#define AssertHandLandmarkerResultIsEmpty(handLandmarkerResult) \
|
||||||
|
XCTAssertTrue(handLandmarkerResult.handedness.count == 0); \
|
||||||
|
XCTAssertTrue(handLandmarkerResult.landmarks.count == 0); \
|
||||||
|
XCTAssertTrue(handLandmarkerResult.worldLandmarks.count == 0);
|
||||||
|
|
||||||
|
@interface MPPHandLandmarkerTests : XCTestCase <MPPHandLandmarkerLiveStreamDelegate> {
|
||||||
|
NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
|
||||||
|
NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
|
||||||
|
}
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPHandLandmarkerTests
|
||||||
|
|
||||||
|
#pragma mark Results
|
||||||
|
|
||||||
|
+ (MPPHandLandmarkerResult *)emptyHandLandmarkerResult {
|
||||||
|
return [[MPPHandLandmarkerResult alloc] initWithLandmarks:@[]
|
||||||
|
worldLandmarks:@[]
|
||||||
|
handedness:@[]
|
||||||
|
|
||||||
|
timestampInMilliseconds:0];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
|
||||||
|
NSString *filePath = [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
|
||||||
|
|
||||||
|
return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
|
||||||
|
shouldRemoveZPosition:YES];
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (MPPHandLandmarkerResult *)pointingUpRotatedHandLandmarkerResult {
|
||||||
|
NSString *filePath =
|
||||||
|
[MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];
|
||||||
|
|
||||||
|
return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
|
||||||
|
shouldRemoveZPosition:YES];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandLandmarks:
|
||||||
|
(NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
|
||||||
|
XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
|
||||||
|
if (multiHandLandmarks.count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
|
||||||
|
NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];
|
||||||
|
|
||||||
|
XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
|
||||||
|
for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
|
||||||
|
MPPNormalizedLandmark *landmark = topHandLandmarks[i];
|
||||||
|
XCTAssertNotNil(landmark);
|
||||||
|
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
|
||||||
|
(NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
|
||||||
|
XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
|
||||||
|
if (expectedMultiHandWorldLandmarks.count == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
|
||||||
|
NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];
|
||||||
|
|
||||||
|
XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
|
||||||
|
for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
|
||||||
|
MPPLandmark *landmark = topHandWorldLandmarks[i];
|
||||||
|
XCTAssertNotNil(landmark);
|
||||||
|
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
|
||||||
|
[self assertMultiHandLandmarks:handLandmarkerResult.landmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
|
||||||
|
[self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
|
||||||
|
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedHandLandmarkerResult
|
||||||
|
.worldLandmarks];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark File
|
||||||
|
|
||||||
|
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
|
||||||
|
NSString *filePath = [MPPHandLandmarkerTests filePathWithName:fileInfo[@"name"]
|
||||||
|
extension:fileInfo[@"type"]];
|
||||||
|
return filePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
|
||||||
|
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
|
||||||
|
ofType:extension];
|
||||||
|
return filePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Hand Landmarker Initializers
|
||||||
|
|
||||||
|
- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
|
||||||
|
(ResourceFileInfo *)modelFileInfo {
|
||||||
|
NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
|
||||||
|
MPPHandLandmarkerOptions *handLandmarkerOptions = [[MPPHandLandmarkerOptions alloc] init];
|
||||||
|
handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
|
return handLandmarkerOptions;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
|
||||||
|
(MPPHandLandmarkerOptions *)handLandmarkerOptions {
|
||||||
|
NSError *error;
|
||||||
|
MPPHandLandmarker *handLandmarker =
|
||||||
|
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
|
||||||
|
XCTAssertNotNil(handLandmarker);
|
||||||
|
XCTAssertNil(error);
|
||||||
|
|
||||||
|
return handLandmarker;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertCreateHandLandmarkerWithOptions:(MPPHandLandmarkerOptions *)handLandmarkerOptions
|
||||||
|
failsWithExpectedError:(NSError *)expectedError {
|
||||||
|
NSError *error = nil;
|
||||||
|
MPPHandLandmarker *handLandmarker =
|
||||||
|
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
|
||||||
|
|
||||||
|
XCTAssertNil(handLandmarker);
|
||||||
|
AssertEqualErrors(error, expectedError);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Assert Hand Landmarker Results
|
||||||
|
|
||||||
|
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
|
||||||
|
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
|
||||||
|
fileName:fileInfo[@"name"]
|
||||||
|
ofType:fileInfo[@"type"]];
|
||||||
|
XCTAssertNotNil(image);
|
||||||
|
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
|
||||||
|
orientation:(UIImageOrientation)orientation {
|
||||||
|
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
|
||||||
|
fileName:fileInfo[@"name"]
|
||||||
|
ofType:fileInfo[@"type"]
|
||||||
|
orientation:orientation];
|
||||||
|
XCTAssertNotNil(image);
|
||||||
|
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
|
||||||
|
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker {
|
||||||
|
MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
|
||||||
|
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
|
||||||
|
XCTAssertNotNil(handLandmarkerResult);
|
||||||
|
|
||||||
|
return handLandmarkerResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)assertResultsOfDetectInImageWithFileInfo:(ResourceFileInfo *)fileInfo
|
||||||
|
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
|
||||||
|
approximatelyEqualsHandLandmarkerResult:
|
||||||
|
(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
|
||||||
|
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:fileInfo
|
||||||
|
usingHandLandmarker:handLandmarker];
|
||||||
|
[self assertHandLandmarkerResult:handLandmarkerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark General Tests
|
||||||
|
|
||||||
|
- (void)testDetectWithModelPathSucceeds {
|
||||||
|
NSString *modelPath =
|
||||||
|
[MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
MPPHandLandmarker *handLandmarker = [[MPPHandLandmarker alloc] initWithModelPath:modelPath
|
||||||
|
error:nil];
|
||||||
|
XCTAssertNotNil(handLandmarker);
|
||||||
|
|
||||||
|
[self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
|
||||||
|
usingHandLandmarker:handLandmarker
|
||||||
|
approximatelyEqualsHandLandmarkerResult:[MPPHandLandmarkerTests
|
||||||
|
thumbUpHandLandmarkerResult]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectWithEmptyResultsSucceeds {
|
||||||
|
MPPHandLandmarkerOptions *handLandmarkerOptions =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker =
|
||||||
|
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
||||||
|
|
||||||
|
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kNoHandsImage
|
||||||
|
usingHandLandmarker:handLandmarker];
|
||||||
|
AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectWithNumHandsSucceeds {
|
||||||
|
MPPHandLandmarkerOptions *handLandmarkerOptions =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
|
||||||
|
const NSInteger numHands = 2;
|
||||||
|
handLandmarkerOptions.numHands = numHands;
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker =
|
||||||
|
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
||||||
|
|
||||||
|
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kTwoHandsImage
|
||||||
|
usingHandLandmarker:handLandmarker];
|
||||||
|
|
||||||
|
XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectWithRotationSucceeds {
|
||||||
|
MPPHandLandmarkerOptions *handLandmarkerOptions =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker =
|
||||||
|
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
||||||
|
|
||||||
|
MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
|
||||||
|
orientation:UIImageOrientationRight];
|
||||||
|
|
||||||
|
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
|
||||||
|
|
||||||
|
[self assertHandLandmarkerResult:handLandmarkerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests
|
||||||
|
pointingUpRotatedHandLandmarkerResult]];
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma mark Running Mode Tests
|
||||||
|
|
||||||
|
- (void)testCreateHandLandmarkerFailsWithDelegateInNonLiveStreamMode {
|
||||||
|
MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
|
||||||
|
for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
|
||||||
|
options.runningMode = runningModesToTest[i];
|
||||||
|
options.handLandmarkerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
[self
|
||||||
|
assertCreateHandLandmarkerWithOptions:options
|
||||||
|
failsWithExpectedError:
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"The vision task is in image or video mode. The "
|
||||||
|
@"delegate must not be set in the task's options."
|
||||||
|
}]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testCreateHandLandmarkerFailsWithMissingDelegateInLiveStreamMode {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
|
||||||
|
[self assertCreateHandLandmarkerWithOptions:options
|
||||||
|
failsWithExpectedError:
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"The vision task is in live stream mode. An "
|
||||||
|
@"object must be set as the delegate of the task "
|
||||||
|
@"in its options to ensure asynchronous delivery "
|
||||||
|
@"of results."
|
||||||
|
}]];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectFailsWithCallingWrongApiInImageMode {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
NSError *liveStreamApiCallError;
|
||||||
|
XCTAssertFalse([handLandmarker detectAsyncInImage:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&liveStreamApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedLiveStreamApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
|
||||||
|
@"stream mode. Current Running Mode: Image"
|
||||||
|
}];
|
||||||
|
|
||||||
|
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
||||||
|
|
||||||
|
NSError *videoApiCallError;
|
||||||
|
XCTAssertFalse([handLandmarker detectInVideoFrame:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&videoApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedVideoApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"video mode. Current Running Mode: Image"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectFailsWithCallingWrongApiInVideoMode {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeVideo;
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
NSError *liveStreamApiCallError;
|
||||||
|
XCTAssertFalse([handLandmarker detectAsyncInImage:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&liveStreamApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedLiveStreamApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
|
||||||
|
@"stream mode. Current Running Mode: Video"
|
||||||
|
}];
|
||||||
|
|
||||||
|
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
||||||
|
|
||||||
|
NSError *imageApiCallError;
|
||||||
|
XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedImageApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"image mode. Current Running Mode: Video"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectFailsWithCallingWrongApiInLiveStreamMode {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
options.handLandmarkerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
NSError *imageApiCallError;
|
||||||
|
XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedImageApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"image mode. Current Running Mode: Live Stream"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
||||||
|
|
||||||
|
NSError *videoApiCallError;
|
||||||
|
XCTAssertFalse([handLandmarker detectInVideoFrame:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&videoApiCallError]);
|
||||||
|
|
||||||
|
NSError *expectedVideoApiCallError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
|
||||||
|
@"video mode. Current Running Mode: Live Stream"
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectWithVideoModeSucceeds {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeVideo;
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInVideoFrame:image
|
||||||
|
timestampInMilliseconds:i
|
||||||
|
error:nil];
|
||||||
|
[self assertHandLandmarkerResult:handLandmarkerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectWithOutOfOrderTimestampsAndLiveStreamModeFails {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
options.handLandmarkerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
XCTestExpectation *expectation = [[XCTestExpectation alloc]
|
||||||
|
initWithDescription:@"detectWiththOutOfOrderTimestampsAndLiveStream"];
|
||||||
|
|
||||||
|
expectation.expectedFulfillmentCount = 1;
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
_outOfOrderTimestampTestDict = @{
|
||||||
|
kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
|
||||||
|
kLiveStreamTestsDictExpectationKey : expectation
|
||||||
|
};
|
||||||
|
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:1 error:nil]);
|
||||||
|
|
||||||
|
NSError *error;
|
||||||
|
XCTAssertFalse([handLandmarker detectAsyncInImage:image timestampInMilliseconds:0 error:&error]);
|
||||||
|
|
||||||
|
NSError *expectedError =
|
||||||
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
code:MPPTasksErrorCodeInvalidArgumentError
|
||||||
|
userInfo:@{
|
||||||
|
NSLocalizedDescriptionKey :
|
||||||
|
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
|
||||||
|
}];
|
||||||
|
AssertEqualErrors(error, expectedError);
|
||||||
|
|
||||||
|
NSTimeInterval timeout = 0.5f;
|
||||||
|
[self waitForExpectations:@[ expectation ] timeout:timeout];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)testDetectWithLiveStreamModeSucceeds {
|
||||||
|
MPPHandLandmarkerOptions *options =
|
||||||
|
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
|
options.runningMode = MPPRunningModeLiveStream;
|
||||||
|
options.handLandmarkerLiveStreamDelegate = self;
|
||||||
|
|
||||||
|
NSInteger iterationCount = 100;
|
||||||
|
|
||||||
|
// Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
|
||||||
|
// times. An normal expectation will fail if expectation.fulfill() is not called
|
||||||
|
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
|
||||||
|
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
|
// Since in our case we cannot predict how many times the expectation is supposed to be fullfilled
|
||||||
|
// setting, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
|
// `expectation.isInverted = true` ensures that test succeeds ifexpectation is fullfilled <=
|
||||||
|
// `iterationCount` times.
|
||||||
|
XCTestExpectation *expectation =
|
||||||
|
[[XCTestExpectation alloc] initWithDescription:@"detectWithLiveStream"];
|
||||||
|
|
||||||
|
expectation.expectedFulfillmentCount = iterationCount + 1;
|
||||||
|
expectation.inverted = YES;
|
||||||
|
|
||||||
|
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
|
||||||
|
|
||||||
|
_liveStreamSucceedsTestDict = @{
|
||||||
|
kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
|
||||||
|
kLiveStreamTestsDictExpectationKey : expectation
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
|
||||||
|
// with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
|
||||||
|
// `CMSampleBuffer`.
|
||||||
|
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
|
||||||
|
|
||||||
|
for (int i = 0; i < iterationCount; i++) {
|
||||||
|
XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:i error:nil]);
|
||||||
|
}
|
||||||
|
|
||||||
|
NSTimeInterval timeout = 0.5f;
|
||||||
|
[self waitForExpectations:@[ expectation ] timeout:timeout];
|
||||||
|
}
|
||||||
|
|
||||||
|
- (void)handLandmarker:(MPPHandLandmarker *)handLandmarker
|
||||||
|
didFinishDetectionWithResult:(MPPHandLandmarkerResult *)handLandmarkerResult
|
||||||
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
|
error:(NSError *)error {
|
||||||
|
[self assertHandLandmarkerResult:handLandmarkerResult
|
||||||
|
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
|
||||||
|
|
||||||
|
if (handLandmarker == _outOfOrderTimestampTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
|
||||||
|
[_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
|
||||||
|
} else if (handLandmarker == _liveStreamSucceedsTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
|
||||||
|
[_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
22
mediapipe/tasks/ios/test/vision/hand_landmarker/utils/BUILD
Normal file
22
mediapipe/tasks/ios/test/vision/hand_landmarker/utils/BUILD
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPHandLandmarkerResultProtobufHelpers",
|
||||||
|
srcs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.mm"],
|
||||||
|
hdrs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
"-x objective-c++",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
|
||||||
|
"//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarkerResult",
|
||||||
|
"//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,26 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarkerResult.h"
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
@interface MPPHandLandmarkerResult (ProtobufHelpers)
|
||||||
|
|
||||||
|
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
|
||||||
|
shouldRemoveZPosition:(BOOL)removeZPosition;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
|
@ -0,0 +1,58 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
|
||||||
|
#import "mediapipe/tasks/ios/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+Helpers.h"
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
|
||||||
|
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using ClassificationListProto = ::mediapipe::ClassificationList;
|
||||||
|
using ClassificationProto = ::mediapipe::Classification;
|
||||||
|
using LandmarksDetectionResultProto =
|
||||||
|
::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
|
||||||
|
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
@implementation MPPHandLandmarkerResult (ProtobufHelpers)
|
||||||
|
|
||||||
|
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
|
||||||
|
shouldRemoveZPosition:(BOOL)removeZPosition {
|
||||||
|
LandmarksDetectionResultProto landmarkDetectionResultProto;
|
||||||
|
|
||||||
|
if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (removeZPosition) {
|
||||||
|
// Remove z position of landmarks, because they are not used in correctness testing. For video
|
||||||
|
// or live stream mode, the z positions varies a lot during tracking from frame to frame.
|
||||||
|
for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
|
||||||
|
auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
|
||||||
|
landmark.clear_z();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [MPPHandLandmarkerResult
|
||||||
|
handLandmarkerResultWithLandmarksProto:{landmarkDetectionResultProto.landmarks()}
|
||||||
|
worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
|
||||||
|
handednessProto:{landmarkDetectionResultProto.classifications()}
|
||||||
|
timestampInMilliSeconds:0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -673,10 +673,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
|
||||||
// If `expectation.isInverted = true`, the test will only succeed if
|
// If `expectation.isInverted = true`, the test will only succeed if
|
||||||
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
|
||||||
// Since in our case we cannot predict how many times the expectation is
|
// Since in our case we cannot predict how many times the expectation is
|
||||||
// supposed to be fullfilled setting,
|
// supposed to be fulfilled setting,
|
||||||
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
|
||||||
// `expectation.isInverted = true` ensures that test succeeds if
|
// `expectation.isInverted = true` ensures that test succeeds if
|
||||||
// expectation is fullfilled <= `iterationCount` times.
|
// expectation is fulfilled <= `iterationCount` times.
|
||||||
XCTestExpectation *expectation =
|
XCTestExpectation *expectation =
|
||||||
[[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
|
[[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
|
||||||
|
|
||||||
|
|
|
@ -64,3 +64,17 @@ objc_library(
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPMask",
|
||||||
|
srcs = ["sources/MPPMask.mm"],
|
||||||
|
hdrs = ["sources/MPPMask.h"],
|
||||||
|
copts = [
|
||||||
|
"-ObjC++",
|
||||||
|
"-std=c++17",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/ios/common:MPPCommon",
|
||||||
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
118
mediapipe/tasks/ios/vision/core/sources/MPPMask.h
Normal file
118
mediapipe/tasks/ios/vision/core/sources/MPPMask.h
Normal file
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
|
/** The underlying type of the segmentation mask. */
|
||||||
|
typedef NS_ENUM(NSUInteger, MPPMaskDataType) {
|
||||||
|
|
||||||
|
/** Represents the native `UInt8 *` type. */
|
||||||
|
MPPMaskDataTypeUInt8,
|
||||||
|
|
||||||
|
/** Represents the native `float *` type. */
|
||||||
|
MPPMaskDataTypeFloat32,
|
||||||
|
|
||||||
|
} NS_SWIFT_NAME(MaskDataType);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The wrapper class for MediaPipe segmentation masks.
|
||||||
|
*
|
||||||
|
* Masks are stored as `UInt8 *` or `float *` objects.
|
||||||
|
* Every mask has an underlying type which can be accessed using `dataType`. You can access the
|
||||||
|
* mask as any other type using the appropriate properties. For example, if the underlying type is
|
||||||
|
* `MPPMaskDataTypeUInt8`, in addition to accessing the mask using `uint8Data`, you can access
|
||||||
|
* `float32Data` to get the 32 bit float data (with values ranging from 0.0 to 1.0). The first
|
||||||
|
* time you access the data as a type different from the underlying type, an expensive type
|
||||||
|
* conversion is performed. Subsequent accesses return a pointer to the memory location fo the same
|
||||||
|
* type converted array. As type conversions can be expensive, it is recommended to limit the
|
||||||
|
* accesses to data of types different from the underlying type.
|
||||||
|
*
|
||||||
|
* Masks that are returned from a MediaPipe Tasks are owned by by the underlying C++ Task. If you
|
||||||
|
* need to extend the lifetime of these objects, you can invoke the `[MPPMask copy:]` method.
|
||||||
|
*/
|
||||||
|
NS_SWIFT_NAME(Mask)
|
||||||
|
@interface MPPMask : NSObject <NSCopying>
|
||||||
|
|
||||||
|
/** The width of the mask. */
|
||||||
|
@property(nonatomic, readonly) NSInteger width;
|
||||||
|
|
||||||
|
/** The height of the mask. */
|
||||||
|
@property(nonatomic, readonly) NSInteger height;
|
||||||
|
|
||||||
|
/** The data type of the mask. */
|
||||||
|
@property(nonatomic, readonly) MPPMaskDataType dataType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The pointer to the memory location where the underlying mask as a single channel `UInt8` array is
|
||||||
|
* stored. Uint8 values use the full value range and range from 0 to 255.
|
||||||
|
*/
|
||||||
|
@property(nonatomic, readonly, assign) const UInt8 *uint8Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The pointer to the memory location where the underlying mask as a single channel float 32 array
|
||||||
|
* is stored. Float values range from 0.0 to 1.0.
|
||||||
|
*/
|
||||||
|
@property(nonatomic, readonly, assign) const float *float32Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes an `MPPMask` object of type `MPPMaskDataTypeUInt8` with the given `UInt8*` data,
|
||||||
|
* width and height.
|
||||||
|
*
|
||||||
|
* If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
|
||||||
|
* `uint8Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
|
||||||
|
* the `MPPMask` must outlive the passed in `uint8Data`.
|
||||||
|
*
|
||||||
|
* @param uint8Data A pointer to the memory location of the `UInt8` data array.
|
||||||
|
* @param width The width of the mask.
|
||||||
|
* @param height The height of the mask.
|
||||||
|
* @param shouldCopy The height of the mask.
|
||||||
|
*
|
||||||
|
* @return A new `MPPMask` instance with the given `UInt8*` data, width and height.
|
||||||
|
*/
|
||||||
|
- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
|
||||||
|
width:(NSInteger)width
|
||||||
|
height:(NSInteger)height
|
||||||
|
shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes an `MPPMask` object of type `MPPMaskDataTypeFloat32` with the given `float*` data,
|
||||||
|
* width and height.
|
||||||
|
*
|
||||||
|
* If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
|
||||||
|
* `float32Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
|
||||||
|
* the `MPPMask` must outlive the passed in `float32Data`.
|
||||||
|
*
|
||||||
|
* @param float32Data A pointer to the memory location of the `float` data array.
|
||||||
|
* @param width The width of the mask.
|
||||||
|
* @param height The height of the mask.
|
||||||
|
*
|
||||||
|
* @return A new `MPPMask` instance with the given `float*` data, width and height.
|
||||||
|
*/
|
||||||
|
- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
|
||||||
|
width:(NSInteger)width
|
||||||
|
height:(NSInteger)height
|
||||||
|
shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
|
||||||
|
|
||||||
|
// TODO: Add methods for CVPixelBuffer conversion.
|
||||||
|
|
||||||
|
/** Unavailable. */
|
||||||
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
NS_ASSUME_NONNULL_END
|
135
mediapipe/tasks/ios/vision/core/sources/MPPMask.mm
Normal file
135
mediapipe/tasks/ios/vision/core/sources/MPPMask.mm
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
|
||||||
|
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||||
|
|
||||||
|
@interface MPPMask () {
|
||||||
|
const UInt8 *_uint8Data;
|
||||||
|
const float *_float32Data;
|
||||||
|
std::unique_ptr<UInt8[]> _uint8DataPtr;
|
||||||
|
std::unique_ptr<float[]> _float32DataPtr;
|
||||||
|
}
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation MPPMask
|
||||||
|
|
||||||
|
- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
|
||||||
|
width:(NSInteger)width
|
||||||
|
height:(NSInteger)height
|
||||||
|
shouldCopy:(BOOL)shouldCopy {
|
||||||
|
|
||||||
|
self = [super init];
|
||||||
|
if (self) {
|
||||||
|
_width = width;
|
||||||
|
_height = height;
|
||||||
|
_dataType = MPPMaskDataTypeUInt8;
|
||||||
|
|
||||||
|
if (shouldCopy) {
|
||||||
|
size_t length = _width * _height;
|
||||||
|
_uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
|
||||||
|
_uint8Data = _uint8DataPtr.get();
|
||||||
|
memcpy((UInt8 *)_uint8Data, uint8Data, length * sizeof(UInt8));
|
||||||
|
} else {
|
||||||
|
_uint8Data = uint8Data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
|
||||||
|
width:(NSInteger)width
|
||||||
|
height:(NSInteger)height
|
||||||
|
shouldCopy:(BOOL)shouldCopy {
|
||||||
|
self = [super init];
|
||||||
|
if (self) {
|
||||||
|
_width = width;
|
||||||
|
_height = height;
|
||||||
|
_dataType = MPPMaskDataTypeFloat32;
|
||||||
|
|
||||||
|
if (shouldCopy) {
|
||||||
|
size_t length = _width * _height;
|
||||||
|
_float32DataPtr = std::unique_ptr<float[]>(new float[length]);
|
||||||
|
_float32Data = _float32DataPtr.get();
|
||||||
|
memcpy((float *)_float32Data, float32Data, length * sizeof(float));
|
||||||
|
} else {
|
||||||
|
_float32Data = float32Data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
- (const UInt8 *)uint8Data {
|
||||||
|
switch (_dataType) {
|
||||||
|
case MPPMaskDataTypeUInt8: {
|
||||||
|
return _uint8Data;
|
||||||
|
}
|
||||||
|
case MPPMaskDataTypeFloat32: {
|
||||||
|
if (_uint8DataPtr) {
|
||||||
|
return _uint8DataPtr.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t length = _width * _height;
|
||||||
|
_uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
|
||||||
|
UInt8 *data = _uint8DataPtr.get();
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
data[i] = _float32Data[i] * 255;
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (const float *)float32Data {
|
||||||
|
switch (_dataType) {
|
||||||
|
case MPPMaskDataTypeUInt8: {
|
||||||
|
if (_float32DataPtr) {
|
||||||
|
return _float32DataPtr.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t length = _width * _height;
|
||||||
|
_float32DataPtr = std::unique_ptr<float[]>(new float[length]);
|
||||||
|
float *data = _float32DataPtr.get();
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
data[i] = (float)_uint8Data[i] / 255;
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
case MPPMaskDataTypeFloat32: {
|
||||||
|
return _float32Data;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
- (id)copyWithZone:(NSZone *)zone {
|
||||||
|
switch (_dataType) {
|
||||||
|
case MPPMaskDataTypeUInt8:
|
||||||
|
return [[MPPMask alloc] initWithUInt8Data:self.uint8Data
|
||||||
|
width:self.width
|
||||||
|
height:self.height
|
||||||
|
shouldCopy:YES];
|
||||||
|
case MPPMaskDataTypeFloat32:
|
||||||
|
return [[MPPMask alloc] initWithFloat32Data:self.float32Data
|
||||||
|
width:self.width
|
||||||
|
height:self.height
|
||||||
|
shouldCopy:YES];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@end
|
|
@ -165,7 +165,7 @@ static NSString *const kTaskPrefix = @"com.mediapipe.tasks.vision";
|
||||||
// For 90° and 270° rotations, we need to swap width and height.
|
// For 90° and 270° rotations, we need to swap width and height.
|
||||||
// This is due to the internal behavior of ImageToTensorCalculator, which:
|
// This is due to the internal behavior of ImageToTensorCalculator, which:
|
||||||
// - first denormalizes the provided rect by multiplying the rect width or height by the image
|
// - first denormalizes the provided rect by multiplying the rect width or height by the image
|
||||||
// width or height, repectively.
|
// width or height, respectively.
|
||||||
// - then rotates this by denormalized rect by the provided rotation, and uses this for cropping,
|
// - then rotates this by denormalized rect by the provided rotation, and uses this for cropping,
|
||||||
// - then finally rotates this back.
|
// - then finally rotates this back.
|
||||||
if (rotationDegrees % 180 == 0) {
|
if (rotationDegrees % 180 == 0) {
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
- (id)copyWithZone:(NSZone *)zone {
|
- (id)copyWithZone:(NSZone *)zone {
|
||||||
MPPFaceDetectorOptions *faceDetectorOptions = [super copyWithZone:zone];
|
MPPFaceDetectorOptions *faceDetectorOptions = [super copyWithZone:zone];
|
||||||
|
|
||||||
|
faceDetectorOptions.runningMode = self.runningMode;
|
||||||
faceDetectorOptions.minDetectionConfidence = self.minDetectionConfidence;
|
faceDetectorOptions.minDetectionConfidence = self.minDetectionConfidence;
|
||||||
faceDetectorOptions.minSuppressionThreshold = self.minSuppressionThreshold;
|
faceDetectorOptions.minSuppressionThreshold = self.minSuppressionThreshold;
|
||||||
faceDetectorOptions.faceDetectorLiveStreamDelegate = self.faceDetectorLiveStreamDelegate;
|
faceDetectorOptions.faceDetectorLiveStreamDelegate = self.faceDetectorLiveStreamDelegate;
|
||||||
|
|
|
@ -43,9 +43,16 @@ objc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "MPPFaceLandmarksConnections",
|
||||||
|
hdrs = ["sources/MPPFaceLandmarksConnections.h"],
|
||||||
|
module_name = "MPPFaceLandmarksConnections",
|
||||||
|
deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"],
|
||||||
|
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "MPPFaceLandmarker",
|
name = "MPPFaceLandmarker",
|
||||||
srcs = ["sources/MPPFaceLandmarker.m"],
|
srcs = ["sources/MPPFaceLandmarker.mm"],
|
||||||
hdrs = ["sources/MPPFaceLandmarker.h"],
|
hdrs = ["sources/MPPFaceLandmarker.h"],
|
||||||
copts = [
|
copts = [
|
||||||
"-ObjC++",
|
"-ObjC++",
|
||||||
|
@ -55,9 +62,11 @@ objc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":MPPFaceLandmarkerOptions",
|
":MPPFaceLandmarkerOptions",
|
||||||
":MPPFaceLandmarkerResult",
|
":MPPFaceLandmarkerResult",
|
||||||
|
":MPPFaceLandmarksConnections",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/ios/components/containers:MPPConnection",
|
||||||
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
"//mediapipe/tasks/ios/core:MPPTaskInfo",
|
||||||
"//mediapipe/tasks/ios/vision/core:MPPImage",
|
"//mediapipe/tasks/ios/vision/core:MPPImage",
|
||||||
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
|
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
|
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
|
||||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
|
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
|
||||||
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h"
|
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h"
|
||||||
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerResult.h"
|
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerResult.h"
|
||||||
|
@ -147,6 +148,83 @@ NS_SWIFT_NAME(FaceLandmarker)
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
|
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the lips.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the lips.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)lipsConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the left eye.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the left eye.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)leftEyeConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the left eyebrow.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the left eyebrow.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)leftEyebrowConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the left iris.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the left iris.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)leftIrisConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the right eye.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the right eyr.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)rightEyeConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the right eyebrow.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the right eyebrow.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)rightEyebrowConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the right iris.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the right iris.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)rightIrisConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks of the face oval.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks of the face oval.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)faceOvalConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between making up the contours of the face.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the contours of the face.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)contoursConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks making up the tesselation of the face.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks making up the tesselation of the face.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)tesselationConnections;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the connections between all the landmarks in the face.
|
||||||
|
*
|
||||||
|
* @return An array of connections between all the landmarks in the face.
|
||||||
|
*/
|
||||||
|
+ (NSArray<MPPConnection *> *)faceConnections;
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
+ (instancetype)new NS_UNAVAILABLE;
|
+ (instancetype)new NS_UNAVAILABLE;
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user