Merge remote-tracking branch 'origin/master' into nguyencse/facemeshioslib

nguyencse 2023-06-28 11:54:36 +07:00
commit 5093b7a2a9
145 changed files with 8379 additions and 507 deletions

View File

@ -45,12 +45,13 @@ http_archive(
)
http_archive(
name = "rules_foreign_cc",
strip_prefix = "rules_foreign_cc-0.1.0",
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip",
name = "rules_foreign_cc",
sha256 = "2a4d07cd64b0719b39a7c12218a3e507672b82a97b98c6a89d38565894cf7c51",
strip_prefix = "rules_foreign_cc-0.9.0",
url = "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.9.0.tar.gz",
)
load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies")
rules_foreign_cc_dependencies()
@ -492,9 +493,10 @@ http_archive(
)
# TensorFlow repo should always go after the other external dependencies.
# TF on 2023-05-26.
_TENSORFLOW_GIT_COMMIT = "67d5c561981edc45daf3f9d73ddd1a77963733ca"
_TENSORFLOW_SHA256 = "0c8326285e9cb695313e194b97d388eea70bf8bf5b13e8f0962ca8eed5179ece"
# TF on 2023-06-13.
_TENSORFLOW_GIT_COMMIT = "491681a5620e41bf079a582ac39c585cc86878b9"
# curl -L https://github.com/tensorflow/tensorflow/archive/<TENSORFLOW_GIT_COMMIT>.tar.gz | shasum -a 256
_TENSORFLOW_SHA256 = "9f76389af7a2835e68413322c1eaabfadc912f02a76d71dc16be507f9ca3d3ac"
http_archive(
name = "org_tensorflow",
urls = [

View File

@ -219,12 +219,10 @@ cc_library(
deps = [
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util:time_series_util",
"@com_google_audio_tools//audio/dsp:window_functions",
"@eigen_archive//:eigen3",
@ -319,6 +317,20 @@ cc_test(
],
)
cc_binary(
name = "time_series_framer_calculator_benchmark",
srcs = ["time_series_framer_calculator_benchmark.cc"],
deps = [
":time_series_framer_calculator",
":time_series_framer_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"@com_google_benchmark//:benchmark",
],
)
cc_test(
name = "time_series_framer_calculator_test",
srcs = ["time_series_framer_calculator_test.cc"],

View File

@ -15,9 +15,7 @@
// Defines TimeSeriesFramerCalculator.
#include <math.h>
#include <deque>
#include <memory>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "audio/dsp/window_functions.h"
@ -25,9 +23,8 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
@ -88,11 +85,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
absl::Status Close(CalculatorContext* cc) override;
private:
// Adds input data to the internal buffer.
void EnqueueInput(CalculatorContext* cc);
// Constructs and emits framed output packets.
void FrameOutput(CalculatorContext* cc);
Timestamp CurrentOutputTimestamp() {
if (use_local_timestamp_) {
return current_timestamp_;
@ -106,14 +98,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
Timestamp::kTimestampUnitsPerSecond);
}
// Returns the timestamp of a sample relative to a base, which is usually the
// timestamp of a packet.
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
int64_t number_of_samples) {
return timestamp_base + round(number_of_samples / sample_rate_ *
Timestamp::kTimestampUnitsPerSecond);
}
// The number of input samples to advance after the current output frame is
// emitted.
int next_frame_step_samples() const {
@ -142,61 +126,174 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
Timestamp initial_input_timestamp_;
// The current timestamp is updated along with the incoming packets.
Timestamp current_timestamp_;
int num_channels_;
// Each entry in this deque consists of a single sample, i.e. a
// single column vector, and its timestamp.
std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
// Samples are buffered in a vector of sample blocks.
class SampleBlockBuffer {
public:
// Initializes the buffer.
void Init(double sample_rate, int num_channels) {
ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate;
num_channels_ = num_channels;
num_samples_ = 0;
first_block_offset_ = 0;
}
// Number of channels, equal to the number of rows in each Matrix.
int num_channels() const { return num_channels_; }
// Total number of available samples over all blocks.
int num_samples() const { return num_samples_; }
// Pushes a new block of samples onto the back of the buffer, with `timestamp`
// being the input timestamp of the packet containing the Matrix.
void Push(const Matrix& samples, Timestamp timestamp);
// Copies `count` samples from the front of the buffer. If there are fewer
// samples than this, the result is zero padded to have `count` samples.
// The timestamp of the last copied sample is written to *last_timestamp.
// This output is used below to update `current_timestamp_`, which is only
// used when `use_local_timestamp` is true.
Matrix CopySamples(int count, Timestamp* last_timestamp) const;
// Drops `count` samples from the front of the buffer. If `count` exceeds
// `num_samples()`, the buffer is emptied. Returns how many samples were
// dropped.
int DropSamples(int count);
private:
struct Block {
// Matrix of num_channels rows by num_samples columns, a block of possibly
// multiple samples.
Matrix samples;
// Timestamp of the first sample in the Block. This comes from the timestamp
// of the input packet that contains this Matrix.
Timestamp timestamp;
Block() : timestamp(Timestamp::Unstarted()) {}
Block(const Matrix& samples, Timestamp timestamp)
: samples(samples), timestamp(timestamp) {}
int num_samples() const { return samples.cols(); }
};
std::vector<Block> blocks_;
// Number of timestamp units per sample. Used to compute timestamps as
// nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n).
double ts_units_per_sample_;
// Number of rows in each Matrix.
int num_channels_;
// The total number of samples over all blocks, equal to
// (sum_i blocks_[i].num_samples()) - first_block_offset_.
int num_samples_;
// The number of samples in the first block that have been discarded. This
// way we can cheaply represent "partially discarding" a block.
int first_block_offset_;
} sample_buffer_;
bool use_window_;
Matrix window_;
Eigen::RowVectorXf window_;
bool use_local_timestamp_;
};
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
for (int i = 0; i < input_frame.cols(); ++i) {
sample_buffer_.emplace_back(std::make_pair(
input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
}
void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples,
Timestamp timestamp) {
num_samples_ += samples.cols();
blocks_.emplace_back(samples, timestamp);
}
void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
while (sample_buffer_.size() >=
Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples(
int count, Timestamp* last_timestamp) const {
Matrix copied(num_channels_, count);
if (!blocks_.empty()) {
int num_copied = 0;
// First block has an offset for samples that have been discarded.
int offset = first_block_offset_;
int n;
Timestamp last_block_ts;
int last_sample_index;
for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) {
n = std::min(it->num_samples() - offset, count);
// Copy `n` samples from the next block.
copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n);
count -= n;
num_copied += n;
last_block_ts = it->timestamp;
last_sample_index = offset + n - 1;
offset = 0; // No samples have been discarded in subsequent blocks.
}
// Compute the timestamp of the last copied sample.
*last_timestamp =
last_block_ts + std::round(ts_units_per_sample_ * last_sample_index);
}
if (count > 0) {
copied.rightCols(count).setZero(); // Zero pad if needed.
}
return copied;
}
int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) {
if (blocks_.empty()) {
return 0;
}
auto block_it = blocks_.begin();
if (first_block_offset_ + count < block_it->num_samples()) {
// `count` is less than the remaining samples in the first block.
first_block_offset_ += count;
num_samples_ -= count;
return count;
}
int num_samples_dropped = block_it->num_samples() - first_block_offset_;
count -= num_samples_dropped;
first_block_offset_ = 0;
for (++block_it; block_it != blocks_.end(); ++block_it) {
if (block_it->num_samples() > count) {
break;
}
num_samples_dropped += block_it->num_samples();
count -= block_it->num_samples();
}
blocks_.erase(blocks_.begin(), block_it); // Drop whole blocks.
if (!blocks_.empty()) {
first_block_offset_ = count; // Drop part of the next block.
num_samples_dropped += count;
}
num_samples_ -= num_samples_dropped;
return num_samples_dropped;
}
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
initial_input_timestamp_ = cc->InputTimestamp();
current_timestamp_ = initial_input_timestamp_;
}
// Add input data to the internal buffer.
sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(),
cc->InputTimestamp());
// Construct and emit framed output packets.
while (sample_buffer_.num_samples() >=
frame_duration_samples_ + samples_still_to_drop_) {
while (samples_still_to_drop_ > 0) {
sample_buffer_.pop_front();
--samples_still_to_drop_;
}
sample_buffer_.DropSamples(samples_still_to_drop_);
Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
&current_timestamp_);
const int frame_step_samples = next_frame_step_samples();
std::unique_ptr<Matrix> output_frame(
new Matrix(num_channels_, frame_duration_samples_));
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
++i) {
output_frame->col(i) = sample_buffer_.front().first;
current_timestamp_ = sample_buffer_.front().second;
sample_buffer_.pop_front();
}
const int frame_overlap_samples =
frame_duration_samples_ - frame_step_samples;
if (frame_overlap_samples > 0) {
for (int i = 0; i < frame_overlap_samples; ++i) {
output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
}
} else {
samples_still_to_drop_ = -frame_overlap_samples;
}
samples_still_to_drop_ = frame_step_samples;
if (use_window_) {
*output_frame = (output_frame->array() * window_.array()).matrix();
// Apply the window to each row of output_frame.
output_frame.array().rowwise() *= window_.array();
}
cc->Outputs().Index(0).Add(output_frame.release(),
CurrentOutputTimestamp());
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
.At(CurrentOutputTimestamp()));
++cumulative_output_frames_;
cumulative_completed_samples_ += frame_step_samples;
}
@ -206,35 +303,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
// fact to enable packet queueing optimizations.
cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
}
}
absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
if (initial_input_timestamp_ == Timestamp::Unstarted()) {
initial_input_timestamp_ = cc->InputTimestamp();
current_timestamp_ = initial_input_timestamp_;
}
EnqueueInput(cc);
FrameOutput(cc);
return absl::OkStatus();
}
absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
sample_buffer_.pop_front();
--samples_still_to_drop_;
}
if (!sample_buffer_.empty() && pad_final_packet_) {
std::unique_ptr<Matrix> output_frame(new Matrix);
output_frame->setZero(num_channels_, frame_duration_samples_);
for (int i = 0; i < sample_buffer_.size(); ++i) {
output_frame->col(i) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
}
sample_buffer_.DropSamples(samples_still_to_drop_);
cc->Outputs().Index(0).Add(output_frame.release(),
CurrentOutputTimestamp());
if (sample_buffer_.num_samples() > 0 && pad_final_packet_) {
Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
&current_timestamp_);
cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
.At(CurrentOutputTimestamp()));
}
return absl::OkStatus();
@ -258,7 +338,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
cc->Inputs().Index(0).Header(), &input_header));
sample_rate_ = input_header.sample_rate();
num_channels_ = input_header.num_channels();
sample_buffer_.Init(sample_rate_, input_header.num_channels());
frame_duration_samples_ = time_series_util::SecondsToSamples(
framer_options.frame_duration_seconds(), sample_rate_);
RET_CHECK_GT(frame_duration_samples_, 0)
@ -312,9 +392,8 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
}
if (use_window_) {
window_ = Matrix::Ones(num_channels_, 1) *
Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
frame_duration_samples_)
window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(),
frame_duration_samples_)
.cast<float>();
}
use_local_timestamp_ = framer_options.use_local_timestamp();
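The rewritten framer buffers whole input Matrix blocks and only tracks an offset into the first block, instead of splitting every sample into its own deque entry; that is what makes CopySamples and DropSamples cheap. Below is a minimal standalone sketch of that bookkeeping, using plain Eigen types and illustrative names rather than the MediaPipe classes above.
#include <algorithm>
#include <vector>
#include "Eigen/Core"
// Buffers whole blocks; only an offset into the first block tracks samples
// that were already consumed, so no per-sample copies are needed.
struct BlockBuffer {
  std::vector<Eigen::MatrixXf> blocks;  // each block: channels x samples
  int offset = 0;       // samples already dropped from blocks.front()
  int num_samples = 0;  // total samples still available
  void Push(const Eigen::MatrixXf& block) {
    num_samples += block.cols();
    blocks.push_back(block);
  }
  // Copies `count` samples from the front; zero pads if fewer are available.
  Eigen::MatrixXf Copy(int channels, int count) const {
    Eigen::MatrixXf out = Eigen::MatrixXf::Zero(channels, count);
    int copied = 0;
    int skip = offset;  // only the first block has an offset
    for (const auto& b : blocks) {
      if (copied == count) break;
      const int n = std::min<int>(b.cols() - skip, count - copied);
      out.middleCols(copied, n) = b.middleCols(skip, n);
      copied += n;
      skip = 0;
    }
    return out;
  }
  // Drops `count` samples from the front, erasing whole blocks where possible
  // and advancing `offset` for a partially consumed first block.
  void Drop(int count) {
    count = std::min(count, num_samples);
    num_samples -= count;
    while (count > 0 && !blocks.empty()) {
      const int available = blocks.front().cols() - offset;
      if (count < available) {
        offset += count;
        return;
      }
      count -= available;
      blocks.erase(blocks.begin());
      offset = 0;
    }
  }
};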

View File

@ -0,0 +1,92 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Benchmark for TimeSeriesFramerCalculator.
#include <memory>
#include <random>
#include <vector>
#include "benchmark/benchmark.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/packet.h"
using ::mediapipe::Matrix;
void BM_TimeSeriesFramerCalculator(benchmark::State& state) {
constexpr float kSampleRate = 32000.0;
constexpr int kNumChannels = 2;
constexpr float kFrameDurationSeconds = 5.0;
std::mt19937 rng(0 /*seed*/);
// Input around a half second's worth of samples at a time.
std::uniform_int_distribution<int> input_size_dist(15000, 17000);
// Generate a pool of random blocks of samples up front.
std::vector<Matrix> sample_pool;
sample_pool.reserve(20);
for (int i = 0; i < 20; ++i) {
sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng)));
}
std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1);
mediapipe::CalculatorGraphConfig config;
config.add_input_stream("input");
config.add_output_stream("output");
auto* node = config.add_node();
node->set_calculator("TimeSeriesFramerCalculator");
node->add_input_stream("input");
node->add_output_stream("output");
mediapipe::TimeSeriesFramerCalculatorOptions* options =
node->mutable_options()->MutableExtension(
mediapipe::TimeSeriesFramerCalculatorOptions::ext);
options->set_frame_duration_seconds(kFrameDurationSeconds);
for (auto _ : state) {
state.PauseTiming(); // Pause benchmark timing.
// Prepare input packets of random blocks of samples.
std::vector<mediapipe::Packet> input_packets;
input_packets.reserve(32);
float t = 0;
for (int i = 0; i < 32; ++i) {
auto samples =
std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]);
const int num_samples = samples->cols();
input_packets.push_back(mediapipe::Adopt(samples.release())
.At(mediapipe::Timestamp::FromSeconds(t)));
t += num_samples / kSampleRate;
}
// Initialize graph.
mediapipe::CalculatorGraph graph;
CHECK_OK(graph.Initialize(config));
// Prepare input header.
auto header = std::make_unique<mediapipe::TimeSeriesHeader>();
header->set_sample_rate(kSampleRate);
header->set_num_channels(kNumChannels);
state.ResumeTiming(); // Resume benchmark timing.
CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}}));
for (auto& packet : input_packets) {
CHECK_OK(graph.AddPacketToInputStream("input", packet));
}
CHECK(!graph.HasError());
CHECK_OK(graph.CloseAllInputStreams());
CHECK_OK(graph.WaitUntilIdle());
}
}
BENCHMARK(BM_TimeSeriesFramerCalculator);
BENCHMARK_MAIN();

View File

@ -117,6 +117,7 @@ mediapipe_proto_library(
"//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:classification_proto",
"//mediapipe/framework/formats:landmark_proto",
"//mediapipe/framework/formats:matrix_data_proto",
"//mediapipe/framework/formats:time_series_header_proto",
],
)
@ -289,6 +290,7 @@ cc_library(
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:integral_types",
@ -1167,6 +1169,7 @@ cc_library(
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:matrix_data_cc_proto",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",

View File

@ -17,6 +17,7 @@
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/integral_types.h"
@ -104,4 +105,7 @@ typedef ConcatenateVectorCalculator<mediapipe::RenderData>
ConcatenateRenderDataVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
typedef ConcatenateVectorCalculator<mediapipe::Image>
ConcatenateImageVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateImageVectorCalculator);
} // namespace mediapipe

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
@ -85,8 +86,12 @@ class ConstantSidePacketCalculator : public CalculatorBase {
packet.Set<LandmarkList>();
} else if (packet_options.has_double_value()) {
packet.Set<double>();
} else if (packet_options.has_matrix_data_value()) {
packet.Set<MatrixData>();
} else if (packet_options.has_time_series_header_value()) {
packet.Set<TimeSeriesHeader>();
} else if (packet_options.has_int64_value()) {
packet.Set<int64_t>();
} else {
return absl::InvalidArgumentError(
"None of supported values were specified in options.");
@ -121,9 +126,13 @@ class ConstantSidePacketCalculator : public CalculatorBase {
MakePacket<LandmarkList>(packet_options.landmark_list_value()));
} else if (packet_options.has_double_value()) {
packet.Set(MakePacket<double>(packet_options.double_value()));
} else if (packet_options.has_matrix_data_value()) {
packet.Set(MakePacket<MatrixData>(packet_options.matrix_data_value()));
} else if (packet_options.has_time_series_header_value()) {
packet.Set(MakePacket<TimeSeriesHeader>(
packet_options.time_series_header_value()));
} else if (packet_options.has_int64_value()) {
packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
} else {
return absl::InvalidArgumentError(
"None of supported values were specified in options.");

View File

@ -19,6 +19,7 @@ package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto";
import "mediapipe/framework/formats/matrix_data.proto";
import "mediapipe/framework/formats/time_series_header.proto";
message ConstantSidePacketCalculatorOptions {
@ -29,14 +30,16 @@ message ConstantSidePacketCalculatorOptions {
message ConstantSidePacket {
oneof value {
int32 int_value = 1;
uint64 uint64_value = 5;
int64 int64_value = 11;
float float_value = 2;
double double_value = 9;
bool bool_value = 3;
string string_value = 4;
uint64 uint64_value = 5;
ClassificationList classification_list_value = 6;
LandmarkList landmark_list_value = 7;
double double_value = 9;
TimeSeriesHeader time_series_header_value = 10;
MatrixData matrix_data_value = 12;
}
}
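With the new fields, an int64 (or MatrixData) constant can be supplied directly from the graph config. A minimal sketch, assuming the calculator's usual PACKET output side packet tag; the side packet name "max_count" is a placeholder.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"
mediapipe::CalculatorGraphConfig MakeInt64SidePacketConfig() {
  // Emits a single constant int64 side packet named "max_count".
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    node {
      calculator: "ConstantSidePacketCalculator"
      output_side_packet: "PACKET:max_count"
      options {
        [mediapipe.ConstantSidePacketCalculatorOptions.ext] {
          packet { int64_value: 63 }
        }
      }
    }
  )pb");
}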

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <string>
#include "absl/strings/string_view.h"
@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
DoTestSingleSidePacket("{ bool_value: true }", true);
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
}
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {

View File

@ -228,7 +228,6 @@ cc_library(
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
@ -280,7 +279,6 @@ cc_library(
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)

View File

@ -22,7 +22,6 @@
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
{Tensor::ElementType::kInt32,
Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
}
std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView()

View File

@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
// Read CPU input into tensors.
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
// If the input tensors have dynamic shape, then the tensors need to be
// resized and reallocated before we can copy the tensor values.
bool resized_tensor_shapes = false;
for (int i = 0; i < input_tensors.size(); ++i) {
if (input_tensors[i].shape().is_dynamic) {
interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
resized_tensor_shapes = true;
}
}
// Reallocate tensor memory after resizing so the buffers match the new shapes.
if (resized_tensor_shapes) interpreter_->AllocateTensors();
for (int i = 0; i < input_tensors.size(); ++i) {
const TfLiteType input_tensor_type =
interpreter_->tensor(interpreter_->inputs()[i])->type;
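The resize path above is triggered by input tensors whose Shape is marked dynamic (see the Tensor::Shape change later in this commit; the Bert preprocessor now sets this flag when its model has dynamic input tensors). A minimal sketch of constructing such a tensor; the function name is illustrative.
#include "mediapipe/framework/formats/tensor.h"
mediapipe::Tensor MakeDynamicInputTensor(int seq_len) {
  // Marking the shape dynamic makes the interpreter runner call
  // ResizeInputTensorStrict() and AllocateTensors() before copying values.
  return mediapipe::Tensor(
      mediapipe::Tensor::ElementType::kInt32,
      mediapipe::Tensor::Shape({1, seq_len}, /*is_dynamic=*/true));
}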

View File

@ -20,7 +20,6 @@
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
// not found in the tokenizer vocab.
std::vector<Tensor> result;
result.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
{Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
kTensorsOut(cc).Send(std::move(result));

View File

@ -1077,6 +1077,7 @@ cc_test(
linkstatic = 1,
deps = [
":tensor_to_image_frame_calculator",
":tensor_to_image_frame_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame",

View File

@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
private:
float scale_factor_;
bool scale_per_frame_min_max_;
};
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
scale_factor_ =
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
.scale_per_frame_min_max();
cc->SetOffset(TimestampDiff(0));
return absl::OkStatus();
}
@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
const int32_t total_size = height * width * depth;
if (scale_per_frame_min_max_) {
RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
<< "Setting scale_per_frame_min_max requires FLOAT input tensors.";
}
::std::unique_ptr<const ImageFrame> output;
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
// Allocate buffer with alignments.
std::unique_ptr<uint8_t[]> buffer(
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
auto data = input_tensor.flat<float>().data();
float min = 1e23;
float max = -1e23;
if (scale_per_frame_min_max_) {
for (int i = 0; i < total_size; ++i) {
float d = scale_factor_ * data[i];
if (d < min) {
min = d;
}
if (d > max) {
max = d;
}
}
}
for (int i = 0; i < total_size; ++i) {
float d = scale_factor_ * data[i];
if (d < 0) d = 0;
if (d > 255) d = 255;
float d = data[i];
if (scale_per_frame_min_max_) {
d = 255 * (d - min) / (max - min + 1e-9);
} else {
d = scale_factor_ * d;
if (d < 0) d = 0;
if (d > 255) d = 255;
}
buffer[i] = d;
}
output = ::absl::make_unique<ImageFrame>(

View File

@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
// Multiplies floating point tensor outputs by this value before converting to
// uint8. This is useful for converting from range [0, 1] to [0, 255].
optional float scale_factor = 1 [default = 1.0];
// If true, scales each FLOAT tensor input from its own per-frame [min, max]
// range to [0, 255]. This overrides any explicit scale_factor.
optional bool scale_per_frame_min_max = 2 [default = false];
}
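The per-frame scaling maps each frame's own minimum to 0 and maximum to 255. A minimal sketch of the same formula (names are illustrative, not the calculator's API):
#include <algorithm>
#include <cstdint>
#include <vector>
// Maps a non-empty frame of float values to [0, 255] using the frame's own
// min/max, mirroring the scale_per_frame_min_max branch in the calculator.
std::vector<uint8_t> ScalePerFrameMinMax(const std::vector<float>& frame) {
  const auto [min_it, max_it] = std::minmax_element(frame.begin(), frame.end());
  const float min = *min_it;
  const float max = *max_it;
  std::vector<uint8_t> out(frame.size());
  for (size_t i = 0; i < frame.size(); ++i) {
    out[i] = static_cast<uint8_t>(255 * (frame[i] - min) / (max - min + 1e-9f));
  }
  return out;
}
// For example, a frame of {255, 200, 200, ...} maps to {255, 0, 0, ...}, which
// is what the Converts3DTensorToImageFrame2DGrayWithScaling test below checks.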

View File

@ -11,7 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <type_traits>
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h"
@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
template <class TypeParam>
class TensorToImageFrameCalculatorTest : public ::testing::Test {
protected:
void SetUpRunner() {
void SetUpRunner(bool scale_per_frame_min_max = false) {
CalculatorGraphConfig::Node config;
config.set_calculator("TensorToImageFrameCalculator");
config.add_input_stream("TENSOR:input_tensor");
config.add_output_stream("IMAGE:output_image");
config.mutable_options()
->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
->set_scale_per_frame_min_max(scale_per_frame_min_max);
runner_ = absl::make_unique<CalculatorRunner>(config);
}
@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
}
}
TYPED_TEST(TensorToImageFrameCalculatorTest,
Converts3DTensorToImageFrame2DGrayWithScaling) {
this->SetUpRunner(true);
auto& runner = this->runner_;
constexpr int kWidth = 16;
constexpr int kHeight = 8;
const tf::TensorShape tensor_shape{kHeight, kWidth};
auto tensor = absl::make_unique<tf::Tensor>(
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
auto tensor_vec = tensor->template flat<TypeParam>().data();
// Write a sequence of integer values as floats which we want normalized.
tensor_vec[0] = 255;
for (int i = 1; i < kWidth * kHeight; ++i) {
tensor_vec[i] = 200;
}
const int64_t time = 1234;
runner->MutableInputs()->Tag(kTensor).packets.push_back(
Adopt(tensor.release()).At(Timestamp(time)));
if (!std::is_same<TypeParam, float>::value) {
EXPECT_FALSE(runner->Run().ok());
return; // Short-circuit because this does not apply to other types.
} else {
EXPECT_TRUE(runner->Run().ok());
const std::vector<Packet>& output_packets =
runner->Outputs().Tag(kImage).packets;
EXPECT_EQ(1, output_packets.size());
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
EXPECT_EQ(kWidth, output_image.Width());
EXPECT_EQ(kHeight, output_image.Height());
EXPECT_EQ(255, output_image.PixelData()[0]);
for (int i = 1; i < kWidth * kHeight; ++i) {
const uint8_t pixel_value = output_image.PixelData()[i];
ASSERT_EQ(0, pixel_value);
}
}
}
} // namespace mediapipe

View File

@ -1355,6 +1355,23 @@ cc_test(
],
)
cc_test(
name = "calculator_graph_summary_packet_test",
srcs = ["calculator_graph_summary_packet_test.cc"],
deps = [
":calculator_framework",
":packet",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/status",
],
)
cc_test(
name = "calculator_runner_test",
size = "medium",

View File

@ -32,7 +32,7 @@ template <class T>
struct dependent_false : std::false_type {};
template <typename T>
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) {
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, size_t index) {
auto& vec = *vecp;
if (vec.size() <= index) {
vec.resize(index + 1);

View File

@ -109,9 +109,20 @@ class CalculatorContext {
// use OutputStream::SetOffset() directly.
void SetOffset(TimestampDiff offset);
// Returns the status of the graph run.
// DEPRECATED: This was intended to get the graph run status during the
// `CalculatorBase::Close` call. However, `Close` can run simultaneously with
// other calculators' `CalculatorBase::Process`, hence the actual graph
// status may change at any time and the graph status returned here does not
// necessarily reflect the actual graph status.
//
// NOTE: This method should only be called during CalculatorBase::Close().
// As an alternative, instead of checking the graph status in `Close` and doing
// work for the "done" state, you can enable timestamp bound processing for
// your calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger
// `Process` on timestamp bound updates and handle the "done" state there.
// Check examples in:
// mediapipe/framework/calculator_graph_summary_packet_test.cc.
//
ABSL_DEPRECATED("Does not reflect the actual graph status.")
absl::Status GraphStatus() const { return graph_status_; }
ProfilingContext* GetProfilingContext() const {

View File

@ -839,6 +839,13 @@ absl::Status CalculatorGraph::PrepareForRun(
}
absl::Status CalculatorGraph::WaitUntilIdle() {
if (has_sources_) {
LOG_FIRST_N(WARNING, 1)
<< "WaitUntilIdle called on a graph with source nodes, which "
"is not fully supported at the moment. Source nodes: "
<< ListSourceNodes();
}
MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle());
VLOG(2) << "Scheduler idle.";
absl::Status status = absl::OkStatus();
@ -1368,6 +1375,16 @@ const OutputStreamManager* CalculatorGraph::FindOutputStreamManager(
.get()[validated_graph_->OutputStreamIndex(name)];
}
std::string CalculatorGraph::ListSourceNodes() const {
std::vector<std::string> sources;
for (auto& node : nodes_) {
if (node->IsSource()) {
sources.push_back(node->DebugName());
}
}
return absl::StrJoin(sources, ", ");
}
namespace {
void PrintTimingToInfo(const std::string& label, int64_t timer_value) {
const int64_t total_seconds = timer_value / 1000000ll;

View File

@ -229,8 +229,11 @@ class CalculatorGraph {
// Wait until the running graph is in the idle mode, which is when nothing can
// be scheduled and nothing is running in the worker threads. This function
// can be called only after StartRun().
//
// NOTE: The graph must not have any source nodes because source nodes prevent
// the running graph from becoming idle until the source nodes are done.
// Currently, `WaitUntilIdle` cannot be used reliably on graphs with any
// source nodes.
absl::Status WaitUntilIdle();
// Wait until a packet is emitted on one of the observed output streams.
@ -594,6 +597,9 @@ class CalculatorGraph {
// status before taking any action.
void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);
// Returns a comma-separated list of source nodes.
std::string ListSourceNodes() const;
#if !MEDIAPIPE_DISABLE_GPU
// Owns the legacy GpuSharedData if we need to create one for backwards
// compatibility.
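For a graph that does contain source nodes, waiting for completion is the reliable alternative to WaitUntilIdle() per the note above. A minimal sketch, assuming the IntGeneratorCalculator source from the summary-packet test below is linked into the binary; the stream name is a placeholder.
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"
absl::Status RunGraphWithSource() {
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        node {
          calculator: "IntGeneratorCalculator"
          output_stream: "INT:int_value"
        }
      )pb");
  mediapipe::CalculatorGraph graph;
  absl::Status status = graph.Initialize(config);
  if (!status.ok()) return status;
  status = graph.StartRun({});
  if (!status.ok()) return status;
  // WaitUntilIdle() would only warn here; WaitUntilDone() waits until the
  // source finishes and the graph shuts down.
  return graph.WaitUntilDone();
}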

View File

@ -0,0 +1,430 @@
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Value;
namespace {
MATCHER_P2(IntPacket, value, timestamp, "") {
*result_listener << "where object is (value: " << arg.template Get<int>()
<< ", timestamp: " << arg.Timestamp() << ")";
return Value(arg.template Get<int>(), Eq(value)) &&
Value(arg.Timestamp(), Eq(timestamp));
}
// Calculates and produces the sum of all received inputs when no more packets
// can be expected on the input stream.
class SummaryPacketCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"SUMMARY"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
static absl::Status UpdateContract(CalculatorContract* cc) {
// Makes sure there are no automatic timestamp bound updates when Process
// is called.
cc->SetTimestampOffset(TimestampDiff::Unset());
// Currently, only ImmediateInputStreamHandler supports the "done" timestamp
// bound update. (ImmediateInputStreamHandler handles multiple input
// streams differently, so, in that case, calculator adjustments may be
// required.)
// TODO: update all input stream handlers to support "done"
// timestamp bound update.
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
// Enables processing timestamp bound updates. For this use case we are
// specifically interested in "done" timestamp bound update. (E.g. when
// all input packet sources are closed.)
cc->SetProcessTimestampBounds(true);
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
if (!kIn(cc).IsEmpty()) {
value_ += kIn(cc).Get();
value_set_ = true;
}
if (kOut(cc).IsClosed()) {
// This can happen:
// 1. If, during the previous invocation, kIn(cc).IsDone() == true (e.g.
// the source calculator finished generating packets sent to kIn) and
// HasNextAllowedInStream() == true (which is often the case).
// 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
// invoke Process() with Timestamp::Max to indicate "Done" timestamp
// bound update.
return absl::OkStatus();
}
// TODO: an input stream holding a packet with a timestamp that has no
// next timestamp allowed in the stream should always result in
// InputStream::IsDone() == true.
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
// `Process` may or may not be invoked for "done" timestamp bound when
// upstream calculator fails in `Close`. Hence, extra care is needed to
// identify whether the calculator needs to send output.
// TODO: remove when "done" timestamp bound flakiness fixed.
if (value_set_) {
// kOut(cc).Send(value_) can be used here as well; however, in the case
// of a source calculator sending inputs into kIn, the resulting timestamp
// is not well defined (e.g. it can be the last packet timestamp or
// Timestamp::Max()).
// TODO: last packet from source should always result in
// InputStream::IsDone() == true.
kOut(cc).Send(value_, Timestamp::Max());
}
kOut(cc).Close();
}
return absl::OkStatus();
}
private:
int value_ = 0;
bool value_set_ = false;
};
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnClosingAllPacketSources) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp(10));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
send_packet(20, Timestamp(11));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
}
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp(10));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
send_packet(20, Timestamp::Max());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnPreStreamTimestamp) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp::PreStream());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnPostStreamTimestamp) {
std::vector<Packet> output_packets;
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input'
node {
calculator: "SummaryPacketCalculator"
input_stream: 'IN:input'
output_stream: 'SUMMARY:output'
}
)pb");
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp::PostStream());
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
class IntGeneratorCalculator : public Node {
public:
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kOut);
absl::Status Process(CalculatorContext* cc) final {
kOut(cc).Send(20, Timestamp(0));
kOut(cc).Send(10, Timestamp(1000));
return tool::StatusStop();
}
};
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnSourceCalculatorCompletion) {
std::vector<Packet> output_packets;
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "IntGeneratorCalculator"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
}
class EmitOnCloseCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Close(CalculatorContext* cc) final {
kOut(cc).Send(20, Timestamp(0));
kOut(cc).Send(10, Timestamp(1000));
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
ProducesSummaryPacketOnAnotherCalculatorClosure) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
node {
calculator: "EmitOnCloseCalculator"
input_stream: "IN:input"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
MP_ASSERT_OK(graph.CloseInputStream("input"));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
output_packets.clear();
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_THAT(output_packets, IsEmpty());
}
class FailureInCloseCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Close(CalculatorContext* cc) final {
return absl::InternalError("error");
}
};
MEDIAPIPE_REGISTER_NODE(FailureInCloseCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInClose) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
node {
calculator: "FailureInCloseCalculator"
input_stream: "IN:input"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
MP_ASSERT_OK(graph.CloseInputStream("input"));
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
EXPECT_THAT(output_packets, IsEmpty());
}
class FailureInProcessCalculator : public Node {
public:
static constexpr Input<int> kIn{"IN"};
static constexpr Output<int> kOut{"INT"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final {
return absl::InternalError("error");
}
};
MEDIAPIPE_REGISTER_NODE(FailureInProcessCalculator);
TEST(SummaryPacketCalculatorUseCaseTest,
DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInProcess) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input"
node {
calculator: "FailureInProcessCalculator"
input_stream: "IN:input"
output_stream: "INT:int_value"
}
node {
calculator: "SummaryPacketCalculator"
input_stream: "IN:int_value"
output_stream: "SUMMARY:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &graph_config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(output_packets, IsEmpty());
auto send_packet = [&graph](int value, Timestamp timestamp) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(value).At(timestamp)));
};
send_packet(10, Timestamp::PostStream());
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
EXPECT_THAT(output_packets, IsEmpty());
}
} // namespace
} // namespace mediapipe

View File

@ -117,11 +117,18 @@ class Tensor {
Shape() = default;
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
Shape(std::initializer_list<int> dimensions, bool is_dynamic)
: dims(dimensions), is_dynamic(is_dynamic) {}
Shape(const std::vector<int>& dimensions, bool is_dynamic)
: dims(dimensions), is_dynamic(is_dynamic) {}
int num_elements() const {
return std::accumulate(dims.begin(), dims.end(), 1,
std::multiplies<int>());
}
std::vector<int> dims;
// The Tensor has a dynamic rather than static shape, so the TFLite
// interpreter's input tensors need to be resized and reallocated. Only
// relevant for CPU.
bool is_dynamic = false;
};
// Quantization parameters corresponding to the zero_point and scale value
// made available by TfLite quantized (uint8/int8) tensors.

View File

@ -2,6 +2,7 @@
#include <cstring>
#include <string>
#include <vector>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -34,6 +35,17 @@ TEST(General, TestDataTypes) {
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
}
TEST(General, TestDynamic) {
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true));
EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4);
EXPECT_TRUE(t1.shape().is_dynamic);
std::vector<int> t2_dims = {4, 3, 2, 3};
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true));
EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3);
EXPECT_TRUE(t2.shape().is_dynamic);
}
TEST(Cpu, TestMemoryAllocation) {
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
auto v1 = t1.GetCpuWriteView();

View File

@ -273,8 +273,8 @@ absl::Status Scheduler::WaitForObservedOutput() {
// Idleness requires:
// 1. either the graph has no source nodes or all source nodes are closed, and
// 2. no packets are added to graph input streams.
// For simplicity, we only allow WaitUntilIdle() to be called on a graph with
// no source nodes. (This is enforced by CalculatorGraph::WaitUntilIdle().)
// For simplicity, WaitUntilIdle() is only fully supported on a graph with
// no source nodes.
// The application must ensure no other threads are adding packets to graph
// input streams while a WaitUntilIdle() call is in progress.
absl::Status Scheduler::WaitUntilIdle() {

View File

@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
return *this + 1;
}
bool Timestamp::HasNextAllowedInStream() const {
if (*this >= Max() || *this == PreStream()) {
return false;
}
return true;
}
Timestamp Timestamp::PreviousAllowedInStream() const {
if (*this <= Min() || *this == PostStream()) {
// Indicates that no previous timestamps may occur.

View File

@ -186,6 +186,10 @@ class Timestamp {
// CHECKs that this->IsAllowedInStream().
Timestamp NextAllowedInStream() const;
// Returns true if there's a next timestamp in the range [Min .. Max] after
// this one.
bool HasNextAllowedInStream() const;
// Returns the previous timestamp in the range [Min .. Max], or
// Unstarted() if no Packets may precede one with this timestamp.
Timestamp PreviousAllowedInStream() const;

View File

@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
Timestamp::PostStream().NextAllowedInStream());
}
TEST(TimestampTest, HasNextAllowedInStream) {
EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
}
TEST(TimestampTest, SpecialValueDifferences) {
{ // Lower range
const std::vector<Timestamp> timestamps = {

View File

@ -530,6 +530,7 @@ cc_library(
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],

View File

@ -14,7 +14,7 @@
"""MediaPipe Task Library Helper Rules for iOS"""
MPP_TASK_MINIMUM_OS_VERSION = "11.0"
MPP_TASK_MINIMUM_OS_VERSION = "12.0"
# When the static framework is built with bazel, all the header files are moved
# to the "Headers" directory with no header path prefixes. This auxiliary rule

View File

@ -50,6 +50,7 @@ def mediapipe_proto_library_impl(
def_cc_proto = True,
def_py_proto = True,
def_java_lite_proto = True,
def_kt_lite_proto = True,
def_objc_proto = True,
def_java_proto = True,
def_jspb_proto = True,
@ -72,6 +73,7 @@ def mediapipe_proto_library_impl(
def_cc_proto: define the cc_proto_library target
def_py_proto: define the py_proto_library target
def_java_lite_proto: define the java_lite_proto_library target
def_kt_lite_proto: define the kt_lite_proto_library target
def_objc_proto: define the objc_proto_library target
def_java_proto: define the java_proto_library target
def_jspb_proto: define the jspb_proto_library target
@ -255,6 +257,7 @@ def mediapipe_proto_library(
def_cc_proto = True,
def_py_proto = True,
def_java_lite_proto = True,
def_kt_lite_proto = True,
def_portable_proto = True, # @unused
def_objc_proto = True,
def_java_proto = True,
@ -281,6 +284,7 @@ def mediapipe_proto_library(
def_cc_proto: define the cc_proto_library target
def_py_proto: define the py_proto_library target
def_java_lite_proto: define the java_lite_proto_library target
def_kt_lite_proto: define the kt_lite_proto_library target
def_portable_proto: ignored since portable protos are gone
def_objc_proto: define the objc_proto_library target
def_java_proto: define the java_proto_library target
@ -304,6 +308,7 @@ def mediapipe_proto_library(
def_cc_proto = def_cc_proto,
def_py_proto = def_py_proto,
def_java_lite_proto = def_java_lite_proto,
def_kt_lite_proto = def_kt_lite_proto,
def_objc_proto = def_objc_proto,
def_java_proto = def_java_proto,
def_jspb_proto = def_jspb_proto,
@ -334,6 +339,7 @@ def mediapipe_proto_library(
def_cc_proto = def_cc_proto,
def_py_proto = def_py_proto,
def_java_lite_proto = def_java_lite_proto,
def_kt_lite_proto = def_kt_lite_proto,
def_objc_proto = def_objc_proto,
def_java_proto = def_java_proto,
def_jspb_proto = def_jspb_proto,

View File

@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
@ -1430,10 +1431,10 @@ std::vector<const FieldDescriptor*> GetFields(const Message* src) {
// Orders map entries in dst to match src.
void OrderMapEntries(const Message* src, Message* dst,
std::set<const Message*>* seen = nullptr) {
std::unique_ptr<std::set<const Message*>> seen_owner;
absl::flat_hash_set<const Message*>* seen = nullptr) {
std::unique_ptr<absl::flat_hash_set<const Message*>> seen_owner;
if (!seen) {
seen_owner = std::make_unique<std::set<const Message*>>();
seen_owner = std::make_unique<absl::flat_hash_set<const Message*>>();
seen = seen_owner.get();
}
if (seen->count(src) > 0) {

View File

@ -34,6 +34,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import javax.microedition.khronos.egl.EGLConfig;
import javax.microedition.khronos.opengles.GL10;
@ -303,7 +304,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer {
}
// Use this when the texture is not a SurfaceTexture.
public void setNextFrame(TextureFrame frame) {
public void setNextFrame(@Nullable TextureFrame frame) {
if (surfaceTexture != null) {
Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */);
}

View File

@ -50,7 +50,6 @@ android_library(
"MediaPipeRunner.java",
],
visibility = [
"//java/com/google/android/libraries/camera/effects:__subpackages__",
"//mediapipe/java/com/google/mediapipe:__subpackages__",
],
exports = [

View File

@ -67,6 +67,7 @@ public class ExternalTextureRenderer {
private float[] textureTransformMatrix = new float[16];
private boolean flipY;
private int rotation = Surface.ROTATION_0;
private boolean doExplicitCpuSync = true;
/** Call this to setup the shader program before rendering. */
public void setup() {
@ -101,6 +102,14 @@ public class ExternalTextureRenderer {
this.rotation = rotation;
}
/**
* Configures whether the renderer should do an explicit CPU synchronization using glFinish upon
* each {@link #render} call. Defaults to true.
*/
public void setDoExplicitCpuSync(boolean doExplicitCpuSync) {
this.doExplicitCpuSync = doExplicitCpuSync;
}
/**
* Renders the surfaceTexture to the framebuffer with optional vertical flip.
*
@ -150,8 +159,11 @@ public class ExternalTextureRenderer {
GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
ShaderUtil.checkGlError("glBindTexture");
// TODO: add sync and go back to glFlush()
GLES20.glFinish();
if (doExplicitCpuSync) {
// TODO: add sync and go back to glFlush()
GLES20.glFinish();
}
}
/**

View File

@ -14,7 +14,10 @@
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe:__subpackages__"])
package(default_visibility = [
"//cloud/ml/applications/vision/model_garden/model_oss/mediapipe:__subpackages__",
"//mediapipe:__subpackages__",
])
licenses(["notice"])

View File

@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel):
self._model: tf.keras.Model = None
self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
self._loss_function: Union[str, tf.keras.losses.Loss] = None
self._metric_function: Union[str, tf.keras.metrics.Metric] = None
self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None
self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None
@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel):
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=[self._metric_function])
metrics=self._metric_functions,
)
latest_checkpoint = (
tf.train.latest_checkpoint(checkpoint_path)

View File

@ -80,10 +80,30 @@ py_test(
deps = [":loss_functions"],
)
######################################################################
# Public target of the MediaPipe Model Maker Quantization Config.
# Quantization Config is used to export a quantized model. Please refer
# to the specific task documentation such as:
# https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize
# for usage information.
######################################################################
py_library(
name = "metrics",
srcs = ["metrics.py"],
)
py_test(
name = "metrics_test",
srcs = ["metrics_test.py"],
deps = [":metrics"],
)
py_library(
name = "quantization",
srcs = ["quantization.py"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = ["//mediapipe/model_maker/python/core/data:dataset"],
)
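
A brief, hedged usage sketch of the quantization target documented above (mirroring the text classifier demo later in this diff; the exporting `model` object is a placeholder):

from mediapipe.model_maker.python.core.utils import quantization

# Dynamic-range quantization config, consumed by a task's export_model().
quantization_config = quantization.QuantizationConfig.for_dynamic()
# model.export_model(quantization_config=quantization_config)  # `model` is a trained task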

View File

@ -0,0 +1,104 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metrics utility library."""
import tensorflow as tf
def _get_binary_sparse_metric(metric: tf.metrics.Metric):
"""Helper method to create a BinarySparse version of a tf.keras.Metric.
BinarySparse is an implementation where the update_state(y_true, y_pred) takes
in shapes y_true=(batch_size, 1) and y_pred=(batch_size, 2). Note that this only
supports the binary classification case, and that class_id=0 is the negative
class and class_id=1 is the positive class.
Currently supported tf.metrics.Metric classes:
1. BinarySparseRecallAtPrecision
2. BinarySparsePrecisionAtRecall
Args:
metric: A tf.metrics.Metric class for which we want to generate a
BinarySparse version of this metric.
Returns:
A class for the BinarySparse version of the specified tf.metrics.Metric.
"""
class BinarySparseMetric(metric):
"""A BinarySparse wrapper class for a tf.keras.Metric.
This class has the same parameters and functions as the underlying
metric class. For example, the parameters for BinarySparseRecallAtPrecision
is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint
is that class_id must be set to 1 (or not specified) for the Binary metric.
"""
def __init__(self, *args, **kwargs):
if 'class_id' in kwargs and kwargs['class_id'] != 1:
raise ValueError(
f'Custom BinarySparseMetric for class:{metric.__name__} is '
'only supported for class_id=1, got class_id='
f'{kwargs["class_id"]} instead'
)
else:
kwargs['class_id'] = 1
super().__init__(*args, **kwargs)
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
y_true_one_hot = tf.one_hot(y_true, 2)
return super().update_state(
y_true_one_hot, y_pred, sample_weight=sample_weight
)
return BinarySparseMetric
def _get_sparse_metric(metric: tf.metrics.Metric):
"""Helper method to create a Sparse version of a tf.keras.Metric.
Sparse is an implementation where the update_state(y_true, y_pred) takes in
shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes).
Currently supported tf.metrics.Metric classes:
1. tf.metrics.Recall
2. tf.metrics.Precision
Args:
metric: A tf.metrics.Metric class for which we want to generate a Sparse
version of this metric.
Returns:
A class for the Sparse version of the specified tf.keras.Metric.
"""
class SparseMetric(metric):
"""A Sparse wrapper class for a tf.keras.Metric."""
def update_state(self, y_true, y_pred, sample_weight=None):
y_pred = tf.math.argmax(y_pred, axis=-1)
return super().update_state(y_true, y_pred, sample_weight=sample_weight)
return SparseMetric
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
tf.metrics.RecallAtPrecision
)
BinarySparsePrecisionAtRecall = _get_binary_sparse_metric(
tf.metrics.PrecisionAtRecall
)
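
The wrappers above only change how update_state() interprets its inputs; a minimal sketch of the intended call pattern (toy values, not taken from this change) is:

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import metrics

y_true = [0, 1, 1]                             # sparse labels, one per example
y_pred = [[0.9, 0.1], [0.4, 0.6], [0.8, 0.2]]  # per-class probabilities

recall = metrics.SparseRecall(name="recall")
recall.update_state(y_true, y_pred)            # argmax over y_pred happens inside
print(float(recall.result()))                  # 0.5: one of the two positives is found

# Binary-only wrapper: class_id is pinned to 1 (the positive class) and the
# sparse labels are one-hot encoded before the underlying metric sees them.
recall_at_precision = metrics.BinarySparseRecallAtPrecision(0.5)
recall_at_precision.update_state(y_true, y_pred)
print(float(recall_at_precision.result()))     # best recall achievable at precision >= 0.5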

View File

@ -0,0 +1,74 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import parameterized
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import metrics
class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.y_true = [0, 0, 1, 1, 0, 1]
self.y_pred = [
[0.9, 0.1], # 0, 0 y
[0.8, 0.2], # 0, 0 y
[0.7, 0.3], # 0, 1 n
[0.6, 0.4], # 0, 1 n
[0.3, 0.7], # 1, 0 y
[0.3, 0.7], # 1, 1 y
]
self.num_classes = 3
def _assert_metric_equals(self, metric, value):
metric.update_state(self.y_true, self.y_pred)
self.assertEqual(metric.result(), value)
def test_sparse_recall(self):
metric = metrics.SparseRecall()
self._assert_metric_equals(metric, 1 / 3)
def test_sparse_precision(self):
metric = metrics.SparsePrecision()
self._assert_metric_equals(metric, 1 / 2)
def test_binary_sparse_recall_at_precision(self):
metric = metrics.BinarySparseRecallAtPrecision(1.0)
self._assert_metric_equals(metric, 0.0) # impossible to achieve precision=1
metric = metrics.BinarySparseRecallAtPrecision(0.4)
self._assert_metric_equals(metric, 1.0)
def test_binary_sparse_precision_at_recall(self):
metric = metrics.BinarySparsePrecisionAtRecall(1.0)
self._assert_metric_equals(metric, 3 / 4)
metric = metrics.BinarySparsePrecisionAtRecall(0.7)
self._assert_metric_equals(metric, 3 / 4)
def test_binary_sparse_precision_at_recall_class_id_error(self):
# class_id=1 case should not error
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1)
# class_id=2 case should error
with self.assertRaisesRegex(
ValueError,
'Custom BinarySparseMetric for class:PrecisionAtRecall is only'
' supported for class_id=1, got class_id=2 instead',
):
_ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
if __name__ == '__main__':
tf.test.main()

View File

@ -31,11 +31,11 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":dataset",
":hyperparameters",
":model_options",
":model_spec",
":text_classifier",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
@ -45,12 +45,18 @@ py_library(
deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
)
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
)
py_library(
name = "model_spec",
srcs = ["model_spec.py"],
deps = [
":hyperparameters",
":model_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/utils:file_util",
"//mediapipe/model_maker/python/text/core:bert_model_spec",
],
@ -61,9 +67,9 @@ py_test(
srcs = ["model_spec_test.py"],
tags = ["requires-net:external"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
@ -100,9 +106,9 @@ py_library(
name = "text_classifier_options",
srcs = ["text_classifier_options.py"],
deps = [
":hyperparameters",
":model_options",
":model_spec",
"//mediapipe/model_maker/python/core:hyperparameters",
],
)
@ -111,13 +117,14 @@ py_library(
srcs = ["text_classifier.py"],
deps = [
":dataset",
":hyperparameters",
":model_options",
":model_spec",
":preprocessor",
":text_classifier_options",
"//mediapipe/model_maker/python/core:hyperparameters",
"//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/tasks:classifier",
"//mediapipe/model_maker/python/core/utils:metrics",
"//mediapipe/model_maker/python/core/utils:model_util",
"//mediapipe/model_maker/python/core/utils:quantization",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",

View File

@ -13,19 +13,23 @@
# limitations under the License.
"""MediaPipe Public Python API for Text Classifier."""
from mediapipe.model_maker.python.core import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import dataset
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
from mediapipe.model_maker.python.text.text_classifier import model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import text_classifier
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
HParams = hyperparameters.BaseHParams
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions
)
BertOptimizer = hyperparameters.BertOptimizer
BertHParams = hyperparameters.BertHParams
BertModelOptions = model_options.BertModelOptions
CSVParams = dataset.CSVParameters
Dataset = dataset.Dataset
AverageWordEmbeddingModelOptions = (
model_options.AverageWordEmbeddingModelOptions)
BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions

View File

@ -0,0 +1,54 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training object detection models."""
import dataclasses
import enum
from typing import Union
from mediapipe.model_maker.python.core import hyperparameters as hp
@dataclasses.dataclass
class AverageWordEmbeddingHParams(hp.BaseHParams):
"""The hyperparameters for an AverageWordEmbeddingClassifier."""
@enum.unique
class BertOptimizer(enum.Enum):
"""Supported Optimizers for Bert Text Classifier."""
ADAMW = "adamw"
LAMB = "lamb"
@dataclasses.dataclass
class BertHParams(hp.BaseHParams):
"""The hyperparameters for a Bert Classifier.
Attributes:
learning_rate: Learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
optimizer: Optimizer to use for training. Only supported values are "adamw"
and "lamb".
"""
learning_rate: float = 3e-5
batch_size: int = 48
epochs: int = 2
optimizer: BertOptimizer = BertOptimizer.ADAMW
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
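
As a hedged illustration (the values are placeholders, not recommended settings), the new hyperparameter types can be constructed as:

from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp

# BertHParams now also carries the optimizer choice; ADAMW stays the default.
bert_hparams = hp.BertHParams(
    learning_rate=3e-5,
    batch_size=48,
    epochs=2,
    optimizer=hp.BertOptimizer.LAMB,
)

# The average word embedding model keeps the plain BaseHParams fields.
awe_hparams = hp.AverageWordEmbeddingHParams(epochs=10, batch_size=32, learning_rate=0)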

View File

@ -17,13 +17,11 @@ import dataclasses
import enum
import functools
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.text.core import bert_model_spec
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
# BERT-based text classifier spec inherited from BertModelSpec
BertClassifierSpec = bert_model_spec.BertModelSpec
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
'text_classifier/mobilebert_tiny',
@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
is_folder=True,
)
EXBERT_FILES = file_util.DownloadedFiles(
'text_classifier/exbert',
'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
is_folder=True,
)
@dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec:
@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
"""
# `learning_rate` is unused for the average word embedding model
hparams: hp.BaseHParams = hp.BaseHParams(
epochs=10, batch_size=32, learning_rate=0)
hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0
)
model_options: mo.AverageWordEmbeddingModelOptions = (
mo.AverageWordEmbeddingModelOptions())
name: str = 'AverageWordEmbedding'
average_word_embedding_classifier_spec = functools.partial(
AverageWordEmbeddingClassifierSpec)
@dataclasses.dataclass
class BertClassifierSpec(bert_model_spec.BertModelSpec):
"""Specification for a Bert classifier model.
Only overrides the hparams attribute since the rest of the attributes are
inherited from the BertModelSpec.
"""
hparams: hp.BertHParams = hp.BertHParams()
mobilebert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=MOBILEBERT_TINY_FILES,
hparams=hp.BaseHParams(
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='MobileBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'mask': 'serving_default_input_3:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)
exbert_classifier_spec = functools.partial(
BertClassifierSpec,
downloaded_files=EXBERT_FILES,
hparams=hp.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
),
name='ExBert',
tflite_input_name={
'ids': 'serving_default_input_1:0',
'segment_ids': 'serving_default_input_2:0',
'mask': 'serving_default_input_3:0',
},
)
@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
"""Predefined text classifier model specs supported by Model Maker."""
AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
EXBERT_CLASSIFIER = exbert_classifier_spec
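
A usage sketch under the assumption that the dataset variables are prepared elsewhere; the new ExBERT spec is selected the same way as the existing ones:

from mediapipe.model_maker.python.text import text_classifier

options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
    hparams=text_classifier.BertHParams(epochs=3, batch_size=48, learning_rate=3e-5),
)
# model = text_classifier.TextClassifier.create(train_data, validation_data, options)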

View File

@ -19,7 +19,7 @@ from unittest import mock as unittest_mock
import tensorflow as tf
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
hp.BertHParams(
epochs=3,
batch_size=48,
learning_rate=3e-5,
distribution_strategy='off'))
distribution_strategy='off',
),
)
def test_predefined_average_word_embedding_spec(self):
model_spec_obj = (
@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
dropout_rate=0.2))
self.assertEqual(
model_spec_obj.hparams,
hp.BaseHParams(
hp.AverageWordEmbeddingHParams(
epochs=10,
batch_size=32,
learning_rate=0,
@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
custom_bert_classifier_options)
def test_custom_average_word_embedding_spec(self):
custom_hparams = hp.BaseHParams(
custom_hparams = hp.AverageWordEmbeddingHParams(
learning_rate=0.4,
batch_size=64,
epochs=10,
@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
export_dir='foo/bar',
distribution_strategy='mirrored',
num_gpus=3,
tpu='tpu/address')
tpu='tpu/address',
)
custom_average_word_embedding_model_options = (
classifier_model_options.AverageWordEmbeddingModelOptions(
seq_len=512,

View File

@ -19,14 +19,16 @@ import tempfile
from typing import Any, Optional, Sequence, Tuple
import tensorflow as tf
from tensorflow_addons import optimizers as tfa_optimizers
import tensorflow_hub as hub
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import metrics
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
from mediapipe.model_maker.python.text.text_classifier import preprocessor
@ -54,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
f" got {options.supported_model}")
if (isinstance(options.model_options, mo.BertModelOptions) and
(options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
if isinstance(options.model_options, mo.BertModelOptions) and (
options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
):
raise ValueError(
f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
"Expected a Bert Classifier(MobileBERT or EXBERT), got "
f"{options.supported_model}"
)
class TextClassifier(classifier.Classifier):
"""API for creating and training a text classification model."""
def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
label_names: Sequence[str]):
def __init__(
self, model_spec: Any, label_names: Sequence[str], shuffle: bool
):
super().__init__(
model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
model_spec=model_spec, label_names=label_names, shuffle=shuffle
)
self._model_spec = model_spec
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
@classmethod
@ -106,7 +112,10 @@ class TextClassifier(classifier.Classifier):
if options.hparams is None:
options.hparams = options.supported_model.value().hparams
if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
if (
options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
):
text_classifier = (
_BertClassifier.create_bert_classifier(train_data, validation_data,
options,
@ -123,12 +132,24 @@ class TextClassifier(classifier.Classifier):
return text_classifier
def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
def evaluate(
self,
data: ds.Dataset,
batch_size: int = 32,
desired_precisions: Optional[Sequence[float]] = None,
desired_recalls: Optional[Sequence[float]] = None,
) -> Any:
"""Overrides Classifier.evaluate().
Args:
data: Evaluation dataset. Must be a TextClassifier Dataset.
batch_size: Number of samples per evaluation step.
desired_precisions: If specified, adds a RecallAtPrecision metric per
desired_precisions[i] entry which tracks the recall given the constraint
on precision. Only supported for binary classification.
desired_recalls: If specified, adds a PrecisionAtRecall metric per
desired_recalls[i] entry which tracks the precision given the constraint
on recall. Only supported for binary classification.
Returns:
The loss value and accuracy.
@ -144,6 +165,28 @@ class TextClassifier(classifier.Classifier):
processed_data = self._text_preprocessor.preprocess(data)
dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
additional_metrics = []
if desired_precisions and len(data.label_names) == 2:
for precision in desired_precisions:
additional_metrics.append(
metrics.BinarySparseRecallAtPrecision(
precision, name=f"recall_at_precision_{precision}"
)
)
if desired_recalls and len(data.label_names) == 2:
for recall in desired_recalls:
additional_metrics.append(
metrics.BinarySparsePrecisionAtRecall(
recall, name=f"precision_at_recall_{recall}"
)
)
metric_functions = self._metric_functions + additional_metrics
self._model.compile(
optimizer=self._optimizer,
loss=self._loss_function,
metrics=metric_functions,
)
return self._model.evaluate(dataset)
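
A hedged call-pattern sketch for the new evaluation knobs on a binary (two-label) classifier; `clf` and `test_data` are placeholders for a trained TextClassifier and its evaluation Dataset:

results = clf.evaluate(
    test_data,
    batch_size=32,
    desired_precisions=[0.9],  # adds a recall_at_precision_0.9 metric
    desired_recalls=[0.8],     # adds a precision_at_recall_0.8 metric
)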
def export_model(
@ -161,9 +204,8 @@ class TextClassifier(classifier.Classifier):
path is {self._hparams.export_dir}/{model_name}.
quantization_config: The configuration for model quantization.
"""
if not tf.io.gfile.exists(self._hparams.export_dir):
tf.io.gfile.makedirs(self._hparams.export_dir)
tflite_file = os.path.join(self._hparams.export_dir, model_name)
tf.io.gfile.makedirs(os.path.dirname(tflite_file))
metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
tflite_model = model_util.convert_to_tflite(
@ -174,7 +216,7 @@ class TextClassifier(classifier.Classifier):
writer = self._get_metadata_writer(tflite_model, vocab_filepath)
tflite_model_with_metadata, metadata_json = writer.populate()
model_util.save_tflite(tflite_model_with_metadata, tflite_file)
with open(metadata_file, "w") as f:
with tf.io.gfile.GFile(metadata_file, "w") as f:
f.write(metadata_json)
@abc.abstractmethod
@ -191,13 +233,23 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
_DELIM_REGEX_PATTERN = r"[^\w\']+"
def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.BaseHParams, label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
def __init__(
self,
model_spec: ms.AverageWordEmbeddingClassifierSpec,
model_options: mo.AverageWordEmbeddingModelOptions,
hparams: hp.AverageWordEmbeddingHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._model_options = model_options
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = "sparse_categorical_crossentropy"
self._metric_function = "accuracy"
self._metric_functions = [
"accuracy",
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: (
preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
@ -306,16 +358,26 @@ class _BertClassifier(TextClassifier):
_INITIALIZER_RANGE = 0.02
def __init__(self, model_spec: ms.BertClassifierSpec,
model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
label_names: Sequence[str]):
super().__init__(model_spec, hparams, label_names)
def __init__(
self,
model_spec: ms.BertClassifierSpec,
model_options: mo.BertModelOptions,
hparams: hp.BertHParams,
label_names: Sequence[str],
):
super().__init__(model_spec, label_names, hparams.shuffle)
self._hparams = hparams
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._model_options = model_options
with self._hparams.get_strategy().scope():
self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32
)
self._metric_functions = [
tf.keras.metrics.SparseCategoricalAccuracy(
"test_accuracy", dtype=tf.float32
),
metrics.SparsePrecision(name="precision", dtype=tf.float32),
metrics.SparseRecall(name="recall", dtype=tf.float32),
]
self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
@classmethod
@ -438,11 +500,26 @@ class _BertClassifier(TextClassifier):
initial_learning_rate=initial_lr,
decay_schedule_fn=lr_schedule,
warmup_steps=warmup_steps)
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"])
if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
self._optimizer = tf.keras.optimizers.experimental.AdamW(
lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
)
self._optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"]
)
elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
self._optimizer = tfa_optimizers.LAMB(
lr_schedule,
weight_decay_rate=0.01,
epsilon=1e-6,
exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
global_clipnorm=1.0,
)
else:
raise ValueError(
"BertHParams.optimizer must be set to ADAM or "
f"LAMB. Got {self._hparams.optimizer}."
)
def _save_vocab(self, vocab_filepath: str):
tf.io.gfile.copy(

View File

@ -66,14 +66,16 @@ def run(data_dir,
quantization_config = None
if (supported_model ==
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
hparams = text_classifier.HParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
hparams = text_classifier.AverageWordEmbeddingHParams(
epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
)
# Warning: This takes extremely long to run on CPU
elif (
supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
quantization_config = quantization.QuantizationConfig.for_dynamic()
hparams = text_classifier.HParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
hparams = text_classifier.BertHParams(
epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
)
# Fine-tunes the model.
options = text_classifier.TextClassifierOptions(

View File

@ -16,7 +16,7 @@
import dataclasses
from typing import Optional
from mediapipe.model_maker.python.core import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
@ -34,5 +34,5 @@ class TextClassifierOptions:
architecture of the `supported_model`.
"""
supported_model: ms.SupportedModels
hparams: Optional[hp.BaseHParams] = None
hparams: Optional[hp.HParams] = None
model_options: Optional[mo.TextClassifierModelOptions] = None

View File

@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
def test_create_and_train_average_word_embedding_model(self):
train_data, validation_data = self._get_data()
options = (
text_classifier.TextClassifierOptions(
supported_model=(text_classifier.SupportedModels
.AVERAGE_WORD_EMBEDDING_CLASSIFIER),
hparams=text_classifier.HParams(
epochs=1, batch_size=1, learning_rate=0)))
options = text_classifier.TextClassifierOptions(
supported_model=(
text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
),
hparams=text_classifier.AverageWordEmbeddingHParams(
epochs=1, batch_size=1, learning_rate=0
),
)
average_word_embedding_classifier = (
text_classifier.TextClassifier.create(train_data, validation_data,
options))
@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
options = text_classifier.TextClassifierOptions(
supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
model_options=text_classifier.BertModelOptions(
do_fine_tuning=False, seq_len=2),
hparams=text_classifier.HParams(
do_fine_tuning=False, seq_len=2
),
hparams=text_classifier.BertHParams(
epochs=1,
batch_size=1,
learning_rate=3e-5,
distribution_strategy='off'))
distribution_strategy='off',
),
)
bert_classifier = text_classifier.TextClassifier.create(
train_data, validation_data, options)

View File

@ -20,13 +20,6 @@ licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "constants",
srcs = ["constants.py"],
@ -72,18 +65,11 @@ py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
":constants",
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [":testdata"],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/vision:face_aligner",
],
)

View File

@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
)
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
'face_stylizer/face_landmarker_v2.task',
'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
is_folder=False,
)
# Dimension of the input style vector to the decoder
STYLE_DIM = 512

View File

@ -13,13 +13,37 @@
# limitations under the License.
"""Face stylizer dataset library."""
from typing import Sequence
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.face_stylizer import constants
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import face_aligner
def _preprocess_face_dataset(
all_image_paths: Sequence[str],
) -> Sequence[tf.Tensor]:
"""Preprocess face image dataset by aligning the face."""
path = constants.FACE_ALIGNER_TASK_FILES.get_path()
base_options = base_options_module.BaseOptions(model_asset_path=path)
options = face_aligner.FaceAlignerOptions(base_options=base_options)
aligner = face_aligner.FaceAligner.create_from_options(options)
preprocessed_images = []
for path in all_image_paths:
tf.compat.v1.logging.info('Preprocess image %s', path)
image = image_module.Image.create_from_file(path)
aligned_image = aligner.align(image)
aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
preprocessed_images.append(aligned_image_tensor)
return preprocessed_images
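
As a usage sketch (the folder path is a placeholder), alignment now happens transparently when the dataset is built:

from mediapipe.model_maker.python.vision.face_stylizer import dataset

# Every image under the folder is run through the face aligner before the
# tf.data pipeline is assembled.
data = dataset.Dataset.from_folder(dirname='/path/to/input/style')
print(data.num_classes, data.label_names)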
# TODO: Change to an unlabeled dataset if it makes sense.
@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
):
raise ValueError('No images found under given directory')
image_data = _preprocess_face_dataset(all_image_paths)
label_names = sorted(
name
for name in os.listdir(data_root)
@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
image_ds = tf.data.Dataset.from_tensor_slices(image_data)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(

View File

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.vision.core import image_utils
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._test_data_dirname = 'input/style'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
test_data_dirname = 'input/style'
input_data_dir = test_utils.get_test_data_path(test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])

View File

@ -14,7 +14,7 @@
"""APIs to train face stylization model."""
import os
from typing import Callable, Optional
from typing import Any, Callable, Optional
import numpy as np
import tensorflow as tf
@ -54,7 +54,6 @@ class FaceStylizer(object):
self._model_spec = model_spec
self._model_options = model_options
self._hparams = hparams
# TODO: Support face alignment in image preprocessor.
self._preprocessor = image_preprocessing.Preprocessor(
input_shape=self._model_spec.input_image_shape,
num_classes=1,
@ -128,7 +127,7 @@ class FaceStylizer(object):
def _train_model(
self,
train_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None,
preprocessor: Optional[Callable[..., Any]] = None,
):
"""Trains the face stylizer model.

View File

@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier):
self._model_options = model_options
self._hparams = hparams
self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
self._metric_function = 'categorical_accuracy'
self._metric_functions = ['categorical_accuracy']
self._optimizer = 'adam'
self._callbacks = self._get_callbacks()
self._history = None

View File

@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier):
self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
self._loss_function = tf.keras.losses.CategoricalCrossentropy(
label_smoothing=self._hparams.label_smoothing)
self._metric_function = 'accuracy'
self._metric_functions = ['accuracy']
self._history = None # Training history returned from `keras_model.fit`.
@classmethod

View File

@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
)
return model_config
def _build_model(self) -> tf.keras.Model:
def _build_model(self, omit_l2=False) -> tf.keras.Model:
"""Builds a RetinaNet object detector model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._model_spec.input_image_shape
)
l2_regularizer = tf.keras.regularizers.l2(
self._model_options.l2_weight_decay / 2.0
)
if omit_l2:
l2_regularizer = None
else:
l2_regularizer = tf.keras.regularizers.l2(
self._model_options.l2_weight_decay / 2.0
)
model_config = self._get_model_config()
return factory.build_retinanet(input_specs, model_config, l2_regularizer)
@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
def convert_to_qat(self) -> None:
"""Converts the model to a QAT RetinaNet model."""
model = self._build_model()
model = self._build_model(omit_l2=True)
dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
model(dummy_input, training=True)
model.set_weights(self._model.get_weights())

View File

@ -43,6 +43,7 @@ cc_library(
":base_audio_task_api",
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/cc/core:task_api_factory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
@ -60,13 +61,8 @@ class AudioTaskApiFactory {
"Task graph config should only contain one task subgraph node.",
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
} else {
if (!node.options().HasExtension(Options::ext)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat(node.calculator(),
" is missing the required task options field."),
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
MP_RETURN_IF_ERROR(
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
found_task_subgraph = true;
}
}

View File

@ -29,6 +29,7 @@ cc_library(
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"@com_google_absl//absl/log",
"@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",

View File

@ -17,15 +17,56 @@ limitations under the License.
#include <memory>
#include <string>
#include <variant>
#include "absl/log/log.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
namespace mediapipe {
namespace tasks {
namespace core {
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
const BaseOptions::CpuOptions& options) {
proto::Acceleration acceleration_proto = proto::Acceleration();
acceleration_proto.mutable_tflite();
return acceleration_proto;
}
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
const BaseOptions::GpuOptions& options) {
proto::Acceleration acceleration_proto = proto::Acceleration();
auto* gpu = acceleration_proto.mutable_gpu();
gpu->set_use_advanced_gpu_api(true);
gpu->set_cached_kernel_path(options.cached_kernel_path);
gpu->set_serialized_model_dir(options.serialized_model_dir);
gpu->set_model_token(options.model_token);
return acceleration_proto;
}
template <typename T>
void SetDelegateOptionsOrDie(const BaseOptions* base_options,
proto::BaseOptions& base_options_proto) {
if (base_options->delegate_options.has_value()) {
if (!std::holds_alternative<T>(*base_options->delegate_options)) {
LOG(FATAL) << "Specified Delegate type does not match the provided "
"delegate options.";
} else {
std::visit(
[&base_options_proto](const auto& delegate_options) {
proto::Acceleration acceleration_proto =
ConvertDelegateOptionsToAccelerationProto(delegate_options);
base_options_proto.mutable_acceleration()->Swap(
&acceleration_proto);
},
*base_options->delegate_options);
}
}
}
proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
proto::BaseOptions base_options_proto;
if (!base_options->model_asset_path.empty()) {
@ -53,11 +94,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
switch (base_options->delegate) {
case BaseOptions::Delegate::CPU:
base_options_proto.mutable_acceleration()->mutable_tflite();
SetDelegateOptionsOrDie<BaseOptions::CpuOptions>(base_options,
base_options_proto);
break;
case BaseOptions::Delegate::GPU:
base_options_proto.mutable_acceleration()
->mutable_gpu()
->set_use_advanced_gpu_api(true);
SetDelegateOptionsOrDie<BaseOptions::GpuOptions>(base_options,
base_options_proto);
break;
case BaseOptions::Delegate::EDGETPU_NNAPI:
base_options_proto.mutable_acceleration()
@ -65,7 +110,6 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
->set_accelerator_name("google-edgetpu");
break;
}
return base_options_proto;
}
} // namespace core

View File

@ -17,7 +17,9 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include "absl/memory/memory.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
@ -38,7 +40,8 @@ struct BaseOptions {
std::string model_asset_path = "";
// The delegate to run MediaPipe. If the delegate is not set, the default
// delegate CPU is used.
// delegate CPU is used. Use `delegate_options` to configure advanced
// features of the selected delegate.
enum Delegate {
CPU = 0,
GPU = 1,
@ -48,6 +51,30 @@ struct BaseOptions {
Delegate delegate = CPU;
// Options for CPU.
struct CpuOptions {};
// Options for GPU.
struct GpuOptions {
// Load pre-compiled serialized binary cache to accelerate init process.
// Only available on Android. Kernel caching will only be enabled if this
// path is set. NOTE: binary cache usage may be skipped if valid serialized
// model, specified by "serialized_model_dir", exists.
std::string cached_kernel_path;
// A dir to load from and save to a pre-compiled serialized model used to
// accelerate init process.
// NOTE: serialized model takes precedence over binary cache
// specified by "cached_kernel_path", which still can be used if
// serialized model is invalid or missing.
std::string serialized_model_dir;
// Unique token identifying the model. Used in conjunction with
// "serialized_model_dir". It is the caller's responsibility to ensure
// there is no clash of the tokens.
std::string model_token;
};
// The file descriptor to a file opened with open(2), with optional additional
// offset and length information.
struct FileDescriptorMeta {
@ -67,6 +94,10 @@ struct BaseOptions {
// built-in Ops.
std::unique_ptr<tflite::OpResolver> op_resolver =
absl::make_unique<MediaPipeBuiltinOpResolver>();
// Options for the chosen delegate. If not set, the default delegate options
// are used.
std::optional<std::variant<CpuOptions, GpuOptions>> delegate_options;
};
// Converts a BaseOptions to a BaseOptionsProto.

View File

@ -1,6 +1,9 @@
#include "mediapipe/tasks/cc/core/base_options.h"
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/port/gmock.h"
@ -11,6 +14,8 @@
constexpr char kTestModelBundlePath[] =
"mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
constexpr char kCachedModelDir[] = "/data/local/tmp";
constexpr char kModelToken[] = "dummy_model_token";
namespace mediapipe {
namespace tasks {
@ -40,6 +45,44 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) {
EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu");
}
TEST(DelegateOptionsTest, SucceedCpuOptions) {
BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::CPU;
BaseOptions::CpuOptions cpu_options;
base_options.delegate_options = cpu_options;
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
EXPECT_TRUE(proto.acceleration().has_tflite());
ASSERT_FALSE(proto.acceleration().has_gpu());
}
TEST(DelegateOptionsTest, SucceedGpuOptions) {
BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::GPU;
BaseOptions::GpuOptions gpu_options;
gpu_options.cached_kernel_path = kCachedModelDir;
gpu_options.model_token = kModelToken;
base_options.delegate_options = gpu_options;
proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
ASSERT_TRUE(proto.acceleration().has_gpu());
ASSERT_FALSE(proto.acceleration().has_tflite());
EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir);
EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
}
TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) {
BaseOptions base_options;
base_options.delegate = BaseOptions::Delegate::CPU;
BaseOptions::GpuOptions gpu_options;
gpu_options.cached_kernel_path = kCachedModelDir;
gpu_options.model_token = kModelToken;
base_options.delegate_options = gpu_options;
ASSERT_DEATH(
{ proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); },
"Specified Delegate type does not match the provided "
"delegate options.");
}
} // namespace
} // namespace core
} // namespace tasks

View File

@ -81,7 +81,6 @@ class TaskApiFactory {
return std::make_unique<T>(std::move(runner));
}
private:
template <typename Options>
static absl::Status CheckHasValidOptions(
const CalculatorGraphConfig::Node& node) {

View File

@ -86,10 +86,9 @@ cc_test(
"//mediapipe/tasks/cc/components/containers:classification_result",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@com_google_sentencepiece//src:sentencepiece_processor", # fixdeps: keep
"@org_tensorflow//tensorflow/lite:test_util",
],
)

View File

@ -15,8 +15,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
#include <cmath>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
@ -24,7 +22,6 @@ limitations under the License.
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"

View File

@ -45,7 +45,7 @@ constexpr char kUniversalSentenceEncoderModel[] =
// Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4;
// Tolerance for cosine similarity evaluation.
constexpr double kSimilarityTolerancy = 1e-6;
constexpr double kSimilarityTolerancy = 2e-2;
using ::mediapipe::file::JoinPath;
using ::testing::HasSubstr;
@ -79,6 +79,8 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
#ifdef _WIN32
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
#elif defined(__FMA__)
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.3605f, kEpsilon);
#else
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
#endif // _WIN32
@ -87,7 +89,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
#ifdef __FMA__
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 21.254150f, kEpsilon);
#else
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
#endif
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(

View File

@ -43,6 +43,7 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:rect",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status",
@ -58,6 +59,7 @@ cc_library(
":base_vision_task_api",
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/tasks/cc/core:task_api_factory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -60,13 +61,8 @@ class VisionTaskApiFactory {
"Task graph config should only contain one task subgraph node.",
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
} else {
if (!node.options().HasExtension(Options::ext)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat(node.calculator(),
" is missing the required task options field."),
MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}
MP_RETURN_IF_ERROR(
tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
found_task_subgraph = true;
}
}

View File

@ -153,6 +153,8 @@ cc_library(
alwayslink = 1,
)
# TODO: open source hand joints graph
cc_library(
name = "hand_landmarker_result",
srcs = ["hand_landmarker_result.cc"],

View File

@ -41,3 +41,5 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
],
)
# TODO: open source hand joints graph

View File

@ -52,6 +52,7 @@ cc_library(
name = "interactive_segmenter_graph",
srcs = ["interactive_segmenter_graph.cc"],
deps = [
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/image:set_alpha_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:flat_color_image_calculator",
@ -60,6 +61,7 @@ cc_library(
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",

View File

@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
@ -35,6 +37,51 @@ namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace internal {
// A calculator to add thickness to the render data according to the image size,
// so that the render data is scale invariant to the image size. If the render
// data already has thickness, it will be kept as is.
class AddThicknessToRenderDataCalculator : public api2::Node {
public:
static constexpr api2::Input<Image> kImageIn{"IMAGE"};
static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{
"RENDER_DATA"};
static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{
"RENDER_DATA"};
static constexpr int kModelInputTensorWidth = 512;
static constexpr int kModelInputTensorHeight = 512;
MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut);
absl::Status Process(CalculatorContext* cc) final {
mediapipe::RenderData render_data = kRenderDataIn(cc).Get();
Image image = kImageIn(cc).Get();
double thickness = std::max(
std::max(image.width() / static_cast<double>(kModelInputTensorWidth),
image.height() / static_cast<double>(kModelInputTensorHeight)),
1.0);
for (auto& annotation : *render_data.mutable_render_annotations()) {
if (!annotation.has_thickness()) {
annotation.set_thickness(thickness);
}
}
kRenderDataOut(cc).Send(render_data);
return absl::OkStatus();
}
};
// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
// moved to next line.
// clang-format off
MEDIAPIPE_REGISTER_NODE(
::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator);
// clang-format on
// NOLINTEND
} // namespace internal
namespace {
@ -59,6 +106,7 @@ constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
constexpr absl::string_view kRenderDataTag{"RENDER_DATA"};
// Updates the graph to return an `roi` stream which has the same dimension as
// `image` and is rendered with `roi`. If `use_gpu` is true, the returned `Source` is
@ -69,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
const absl::string_view image_tag_with_suffix =
use_gpu ? kImageGpuTag : kImageCpuTag;
// Adds thickness to the render data so that the render data is scale
// invariant to the input image size.
auto& add_thickness = graph.AddNode(
"mediapipe::tasks::vision::interactive_segmenter::internal::"
"AddThicknessToRenderDataCalculator");
image >> add_thickness.In(kImageTag);
roi >> add_thickness.In(kRenderDataTag);
auto roi_with_thickness = add_thickness.Out(kRenderDataTag);
// Generates a blank canvas with same size as input image.
auto& flat_color = graph.AddNode("FlatColorImageCalculator");
auto& flat_color_options =
flat_color.GetOptions<FlatColorImageCalculatorOptions>();
// SetAlphaCalculator only takes 1st channel.
flat_color_options.mutable_color()->set_r(0);
image >> flat_color.In(kImageTag)[0];
auto blank_canvas = flat_color.Out(kImageTag)[0];
image >> flat_color.In(kImageTag);
auto blank_canvas = flat_color.Out(kImageTag);
auto& from_mp_image = graph.AddNode("FromImageCalculator");
blank_canvas >> from_mp_image.In(kImageTag);
@ -85,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
blank_canvas_in_cpu_or_gpu >>
roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
roi >> roi_to_alpha.In(0);
roi_with_thickness >> roi_to_alpha.In(0);
auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
return alpha;
@ -163,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
image >> from_mp_image.In(kImageTag);
auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
// Creates an RGBA image with model input tensor size.
auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
auto& set_alpha = graph.AddNode("SetAlphaCalculator");

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -70,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
// Golden mask for the dogs in cats_and_dogs.jpg.
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"};
constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"};
constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"};
constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"};
constexpr float kGoldenMaskSimilarity = 0.97;
@ -183,6 +188,7 @@ struct InteractiveSegmenterTestParams {
std::string test_name;
RegionOfInterest::Format format;
std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
absl::string_view input_image_file;
absl::string_view golden_mask_file;
float similarity_threshold;
};
@ -220,8 +226,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
params.input_image_file)));
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
@ -244,6 +250,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
EXPECT_THAT(actual_mask,
SimilarToUint8Mask(expected_mask, params.similarity_threshold,
kGoldenMaskMagnificationFactor));
cv::Mat visualized_mask;
actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
visualized_mask.cols, visualized_mask.rows,
visualized_mask.step, visualized_mask.data,
[visualized_mask](uint8_t[]) {});
MP_EXPECT_OK(SavePngTestOutput(
visualized_image, absl::StrFormat("%s_category_mask", params.test_name)));
}
TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
@ -252,8 +267,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
const InteractiveSegmenterTestParams& params = GetParam();
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
params.input_image_file)));
auto options = std::make_unique<InteractiveSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kPtmModel);
@ -275,6 +290,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
params.similarity_threshold));
cv::Mat visualized_mask;
actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
visualized_mask.cols, visualized_mask.rows,
visualized_mask.step, visualized_mask.data,
[visualized_mask](uint8_t[]) {});
MP_EXPECT_OK(SavePngTestOutput(
visualized_image,
absl::StrFormat("%s_confidence_mask", params.test_name)));
}
INSTANTIATE_TEST_SUITE_P(
@ -282,21 +306,28 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn<InteractiveSegmenterTestParams>(
{// Keypoint input.
{"PointToDog1", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1,
0.84f},
{"PointToDog2", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2,
kGoldenMaskSimilarity},
{"PenguinsSmall", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask,
0.9f},
{"PenguinsLarge", RegionOfInterest::Format::kKeyPoint,
NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask,
0.9f},
// Scribble input.
{"ScribbleToDog1", RegionOfInterest::Format::kScribble,
std::vector{NormalizedKeypoint{0.44, 0.70},
NormalizedKeypoint{0.44, 0.71},
NormalizedKeypoint{0.44, 0.72}},
kCatsAndDogsMaskDog1, 0.84f},
kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f},
{"ScribbleToDog2", RegionOfInterest::Format::kScribble,
std::vector{NormalizedKeypoint{0.66, 0.66},
NormalizedKeypoint{0.66, 0.67},
NormalizedKeypoint{0.66, 0.68}},
kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
[](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
info) { return info.param.test_name; });

View File

@ -60,3 +60,9 @@ objc_library(
srcs = ["sources/MPPLandmark.m"],
hdrs = ["sources/MPPLandmark.h"],
)
objc_library(
name = "MPPConnection",
srcs = ["sources/MPPConnection.m"],
hdrs = ["sources/MPPConnection.h"],
)

View File

@ -0,0 +1,44 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/** The value class representing a landmark connection. */
NS_SWIFT_NAME(Connection)
@interface MPPConnection : NSObject
/** The index of the landmark at which the connection starts. */
@property(nonatomic, readonly) NSUInteger start;
/** The index of the landmark at which the connection ends. */
@property(nonatomic, readonly) NSUInteger end;
/**
 * Initializes a new `MPPConnection` with the given start and end landmark indices.
 *
 * @param start The integer representing the starting landmark of the connection.
 * @param end The integer representing the ending landmark of the connection.
 *
 * @return An instance of `MPPConnection` initialized with the given start and end landmark
 * indices.
 */
- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end NS_DESIGNATED_INITIALIZER;
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
@implementation MPPConnection
- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end {
self = [super init];
if (self) {
_start = start;
_end = end;
}
return self;
}
@end
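As a quick illustration of the new value class, a minimal usage sketch (the indices 0 and 1 are placeholders, not taken from any real connection list):

#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"

static void LogExampleConnection(void) {
  // Placeholder indices; real connections come from a task's predefined connection lists.
  MPPConnection *connection = [[MPPConnection alloc] initWithStart:0 end:1];
  NSLog(@"connection %lu -> %lu", (unsigned long)connection.start, (unsigned long)connection.end);
}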

View File

@ -54,3 +54,20 @@ ios_unit_test(
":MPPImageObjcTestLibrary",
],
)
objc_library(
name = "MPPMaskObjcTestLibrary",
testonly = 1,
srcs = ["MPPMaskTests.m"],
deps = ["//mediapipe/tasks/ios/vision/core:MPPMask"],
)
ios_unit_test(
name = "MPPMaskObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPMaskObjcTestLibrary",
],
)

View File

@ -0,0 +1,127 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
#import <XCTest/XCTest.h>
/** Unit tests for `MPPMask`. */
@interface MPPMaskTests : XCTestCase
@end
@implementation MPPMaskTests
#pragma mark - Tests
- (void)testInitWithUInt8ArrayNoCopySucceeds {
NSInteger width = 2;
NSInteger height = 3;
UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};
MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:NO];
XCTAssertEqual(mask.width, width);
XCTAssertEqual(mask.height, height);
// Test if UInt8 mask is not copied.
XCTAssertEqual(mask.uint8Data, (const UInt8*)uint8Data);
XCTAssertNotEqual(mask.float32Data, NULL);
  for (int i = 0; i < width * height; i++) {
XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f, @"index i = %d", i);
}
// Test if repeated Float32 mask accesses return the same array in memory.
XCTAssertEqual(mask.float32Data, mask.float32Data);
}
- (void)testInitWithUInt8ArrayCopySucceeds {
NSInteger width = 2;
NSInteger height = 3;
UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};
MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:YES];
XCTAssertEqual(mask.width, width);
XCTAssertEqual(mask.height, height);
// Test if UInt8 mask is copied.
XCTAssertNotEqual(mask.uint8Data, (const UInt8*)uint8Data);
XCTAssertNotEqual(mask.float32Data, NULL);
  for (int i = 0; i < width * height; i++) {
XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f);
}
// Test if repeated Float32 mask accesses return the same array in memory.
XCTAssertEqual(mask.float32Data, mask.float32Data);
}
- (void)testInitWithFloat32ArrayNoCopySucceeds {
NSInteger width = 2;
NSInteger height = 3;
UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:NO];
XCTAssertEqual(mask.width, width);
XCTAssertEqual(mask.height, height);
// Test if Float32 mask is not copied.
XCTAssertEqual(mask.float32Data, (const float*)float32Data);
XCTAssertNotEqual(mask.uint8Data, NULL);
  for (int i = 0; i < width * height; i++) {
XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
}
// Test if repeated UInt8 mask accesses return the same array in memory.
XCTAssertEqual(mask.uint8Data, mask.uint8Data);
}
- (void)testInitWithFloat32ArrayCopySucceeds {
NSInteger width = 2;
NSInteger height = 3;
UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:YES];
XCTAssertEqual(mask.width, width);
XCTAssertEqual(mask.height, height);
// Test if Float32 mask is copied.
XCTAssertNotEqual(mask.float32Data, (const float*)float32Data);
XCTAssertNotEqual(mask.uint8Data, NULL);
  for (int i = 0; i < width * height; i++) {
XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
}
// Test if repeated UInt8 mask accesses return the same array in memory.
XCTAssertEqual(mask.uint8Data, mask.uint8Data);
}
@end
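The golden values in these tests follow from a simple scaling assumption: a UInt8 pixel value v maps to roughly v / 255 as a float (128 / 255 ≈ 0.50196, within the 1e-3 tolerance of the 0.501 expectations above, and 0.52 * 255 truncates to 132). A sketch of that assumption only, not of `MPPMask`'s actual conversion code:

// Scaling that the test expectations assume; MPPMask's internal conversion may differ in
// rounding details.
static float ExpectedFloatFromUInt8(UInt8 value) {
  return (float)value / 255.0f;  // 128 -> ~0.50196
}
static UInt8 ExpectedUInt8FromFloat(float value) {
  return (UInt8)(value * 255.0f);  // 0.52 -> 132 (truncation toward zero)
}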

View File

@ -155,12 +155,12 @@ static const float kKeypointErrorThreshold = 1e-2;
NSInteger iterationCount = 100;
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
// normal expectation will fail if expectation.fullfill() is not called
// normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
// only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since it is not possible to predict how many times the expectation is supposed to be
// fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
// fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
// `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
@ -385,13 +385,13 @@ static const float kKeypointErrorThreshold = 1e-2;
NSInteger iterationCount = 100;
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
// normal expectation will fail if expectation.fullfill() is not called times. An normal
// expectation will fail if expectation.fullfill() is not called
  // normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
// only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since it it not possible to determine how many times the expectation is supposed to be
// fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
// fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
// `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];

View File

@ -174,12 +174,12 @@ constexpr float kFacialTransformationMatrixErrorThreshold = 0.2f;
NSInteger iterationCount = 100;
// Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
// normal expectation will fail if expectation.fullfill() is not called
// normal expectation will fail if expectation.fulfill() is not called
// `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
// only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
// only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since it is not possible to predict how many times the expectation is supposed to be
// fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
// fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
// `iterationCount` times.
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];

View File

@ -0,0 +1,62 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load(
"//mediapipe/framework/tool:ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]
# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]
objc_library(
name = "MPPGestureRecognizerObjcTestLibrary",
testonly = 1,
srcs = ["MPPGestureRecognizerTests.m"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
data = [
"//mediapipe/tasks/testdata/vision:gesture_recognizer.task",
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_protos",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/test/vision/gesture_recognizer/utils:MPPGestureRecognizerResultProtobufHelpers",
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
"//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer",
] + select({
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//conditions:default": ["@ios_opencv//:OpencvFramework"],
}),
)
ios_unit_test(
name = "MPPGestureRecognizerObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPGestureRecognizerObjcTestLibrary",
],
)

View File

@ -0,0 +1,706 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizer.h"
static NSString *const kPbFileExtension = @"pbtxt";
typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;
static ResourceFileInfo *const kGestureRecognizerBundleAssetFile =
@{@"name" : @"gesture_recognizer", @"type" : @"task"};
static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
static ResourceFileInfo *const kFistImage = @{@"name" : @"fist", @"type" : @"jpg"};
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
static ResourceFileInfo *const kPointingUpRotatedImage =
@{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};
static ResourceFileInfo *const kExpectedFistLandmarksFile =
@{@"name" : @"fist_landmarks", @"type" : kPbFileExtension};
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
@{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
static NSString *const kFistLabel = @"Closed_Fist";
static NSString *const kExpectedThumbUpLabel = @"Thumb_Up";
static NSString *const kExpectedPointingUpLabel = @"Pointing_Up";
static NSString *const kRockLabel = @"Rock";
static const NSInteger kGestureExpectedIndex = -1;
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static const float kLandmarksErrorTolerance = 0.03f;
static NSString *const kLiveStreamTestsDictGestureRecognizerKey = @"gesture_recognizer";
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertEqualGestures(gesture, expectedGesture, handIndex, gestureIndex) \
XCTAssertEqual(gesture.index, kGestureExpectedIndex, @"hand index = %d gesture index j = %d", \
handIndex, gestureIndex); \
XCTAssertEqualObjects(gesture.categoryName, expectedGesture.categoryName, \
@"hand index = %d gesture index j = %d", handIndex, gestureIndex);
#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex) \
XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance, \
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance, \
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
#define AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult) \
XCTAssertTrue(gestureRecognizerResult.gestures.count == 0); \
XCTAssertTrue(gestureRecognizerResult.handedness.count == 0); \
XCTAssertTrue(gestureRecognizerResult.landmarks.count == 0); \
XCTAssertTrue(gestureRecognizerResult.worldLandmarks.count == 0);
@interface MPPGestureRecognizerTests : XCTestCase <MPPGestureRecognizerLiveStreamDelegate> {
NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
}
@end
@implementation MPPGestureRecognizerTests
#pragma mark Expected Results
+ (MPPGestureRecognizerResult *)emptyGestureRecognizerResult {
return [[MPPGestureRecognizerResult alloc] initWithGestures:@[]
handedness:@[]
landmarks:@[]
worldLandmarks:@[]
timestampInMilliseconds:0];
}
+ (MPPGestureRecognizerResult *)thumbUpGestureRecognizerResult {
NSString *filePath =
[MPPGestureRecognizerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
return [MPPGestureRecognizerResult
gestureRecognizerResultsFromProtobufFileWithName:filePath
gestureLabel:kExpectedThumbUpLabel
shouldRemoveZPosition:YES];
}
+ (MPPGestureRecognizerResult *)fistGestureRecognizerResultWithLabel:(NSString *)gestureLabel {
NSString *filePath = [MPPGestureRecognizerTests filePathWithFileInfo:kExpectedFistLandmarksFile];
return [MPPGestureRecognizerResult gestureRecognizerResultsFromProtobufFileWithName:filePath
gestureLabel:gestureLabel
shouldRemoveZPosition:YES];
}
#pragma mark Assert Gesture Recognizer Results
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
areApproximatelyEqualToExpectedMultiHandLandmarks:
(NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
if (multiHandLandmarks.count == 0) {
return;
}
NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];
XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
MPPNormalizedLandmark *landmark = topHandLandmarks[i];
XCTAssertNotNil(landmark);
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
}
}
- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
(NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
if (expectedMultiHandWorldLandmarks.count == 0) {
return;
}
NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];
XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
MPPLandmark *landmark = topHandWorldLandmarks[i];
XCTAssertNotNil(landmark);
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
}
}
- (void)assertMultiHandGestures:(NSArray<NSArray<MPPCategory *> *> *)multiHandGestures
areApproximatelyEqualToExpectedMultiHandGestures:
(NSArray<NSArray<MPPCategory *> *> *)expectedMultiHandGestures {
XCTAssertEqual(multiHandGestures.count, expectedMultiHandGestures.count);
if (multiHandGestures.count == 0) {
return;
}
NSArray<MPPCategory *> *topHandGestures = multiHandGestures[0];
NSArray<MPPCategory *> *expectedTopHandGestures = expectedMultiHandGestures[0];
XCTAssertEqual(topHandGestures.count, expectedTopHandGestures.count);
for (int i = 0; i < expectedTopHandGestures.count; i++) {
MPPCategory *gesture = topHandGestures[i];
XCTAssertNotNil(gesture);
AssertEqualGestures(gesture, expectedTopHandGestures[i], 0, i);
}
}
- (void)assertGestureRecognizerResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
isApproximatelyEqualToExpectedResult:
(MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
[self assertMultiHandLandmarks:gestureRecognizerResult.landmarks
areApproximatelyEqualToExpectedMultiHandLandmarks:expectedGestureRecognizerResult.landmarks];
[self assertMultiHandWorldLandmarks:gestureRecognizerResult.worldLandmarks
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedGestureRecognizerResult
.worldLandmarks];
[self assertMultiHandGestures:gestureRecognizerResult.gestures
areApproximatelyEqualToExpectedMultiHandGestures:expectedGestureRecognizerResult.gestures];
}
- (void)assertResultsOfRecognizeImageWithFileInfo:(ResourceFileInfo *)fileInfo
usingGestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
approximatelyEqualsGestureRecognizerResult:
(MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
MPPGestureRecognizerResult *gestureRecognizerResult =
[self recognizeImageWithFileInfo:fileInfo usingGestureRecognizer:gestureRecognizer];
[self assertGestureRecognizerResult:gestureRecognizerResult
isApproximatelyEqualToExpectedResult:expectedGestureRecognizerResult];
}
#pragma mark File
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
NSString *filePath = [MPPGestureRecognizerTests filePathWithName:fileInfo[@"name"]
extension:fileInfo[@"type"]];
return filePath;
}
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
return filePath;
}
#pragma mark Gesture Recognizer Initializers
- (MPPGestureRecognizerOptions *)gestureRecognizerOptionsWithModelFileInfo:
(ResourceFileInfo *)modelFileInfo {
NSString *modelPath = [MPPGestureRecognizerTests filePathWithFileInfo:modelFileInfo];
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[[MPPGestureRecognizerOptions alloc] init];
gestureRecognizerOptions.baseOptions.modelAssetPath = modelPath;
return gestureRecognizerOptions;
}
- (MPPGestureRecognizer *)createGestureRecognizerWithOptionsSucceeds:
(MPPGestureRecognizerOptions *)gestureRecognizerOptions {
MPPGestureRecognizer *gestureRecognizer =
[[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:nil];
XCTAssertNotNil(gestureRecognizer);
return gestureRecognizer;
}
- (void)assertCreateGestureRecognizerWithOptions:
(MPPGestureRecognizerOptions *)gestureRecognizerOptions
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPGestureRecognizer *gestureRecognizer =
[[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:&error];
XCTAssertNil(gestureRecognizer);
AssertEqualErrors(error, expectedError);
}
#pragma mark Recognize Helpers
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
fileName:fileInfo[@"name"]
ofType:fileInfo[@"type"]];
XCTAssertNotNil(image);
return image;
}
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
orientation:(UIImageOrientation)orientation {
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
fileName:fileInfo[@"name"]
ofType:fileInfo[@"type"]
orientation:orientation];
XCTAssertNotNil(image);
return image;
}
- (MPPGestureRecognizerResult *)recognizeImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
usingGestureRecognizer:
(MPPGestureRecognizer *)gestureRecognizer {
MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
error:nil];
XCTAssertNotNil(gestureRecognizerResult);
return gestureRecognizerResult;
}
#pragma mark General Tests
- (void)testRecognizeWithModelPathSucceeds {
NSString *modelPath =
[MPPGestureRecognizerTests filePathWithFileInfo:kGestureRecognizerBundleAssetFile];
MPPGestureRecognizer *gestureRecognizer =
[[MPPGestureRecognizer alloc] initWithModelPath:modelPath error:nil];
XCTAssertNotNil(gestureRecognizer);
[self assertResultsOfRecognizeImageWithFileInfo:kThumbUpImage
usingGestureRecognizer:gestureRecognizer
approximatelyEqualsGestureRecognizerResult:[MPPGestureRecognizerTests
thumbUpGestureRecognizerResult]];
}
- (void)testRecognizeWithEmptyResultsSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
MPPGestureRecognizerResult *gestureRecognizerResult =
[self recognizeImageWithFileInfo:kNoHandsImage usingGestureRecognizer:gestureRecognizer];
AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
}
- (void)testRecognizeWithScoreThresholdSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
MPPGestureRecognizerResult *gestureRecognizerResult =
[self recognizeImageWithFileInfo:kThumbUpImage usingGestureRecognizer:gestureRecognizer];
MPPGestureRecognizerResult *expectedGestureRecognizerResult =
[MPPGestureRecognizerTests thumbUpGestureRecognizerResult];
XCTAssertTrue(gestureRecognizerResult.gestures.count == 1);
AssertEqualGestures(gestureRecognizerResult.gestures[0][0],
expectedGestureRecognizerResult.gestures[0][0], 0, 0);
}
- (void)testRecognizeWithNumHandsSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
const NSInteger numHands = 2;
gestureRecognizerOptions.numHands = numHands;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
MPPGestureRecognizerResult *gestureRecognizerResult =
[self recognizeImageWithFileInfo:kTwoHandsImage usingGestureRecognizer:gestureRecognizer];
XCTAssertTrue(gestureRecognizerResult.handedness.count == numHands);
}
- (void)testRecognizeWithRotationSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
gestureRecognizerOptions.numHands = 1;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
orientation:UIImageOrientationRight];
MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
error:nil];
XCTAssertNotNil(gestureRecognizerResult);
XCTAssertEqual(gestureRecognizerResult.gestures.count, 1);
XCTAssertEqualObjects(gestureRecognizerResult.gestures[0][0].categoryName,
kExpectedPointingUpLabel);
}
- (void)testRecognizeWithCannedGestureFistSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
gestureRecognizerOptions.numHands = 1;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
[self assertResultsOfRecognizeImageWithFileInfo:kFistImage
usingGestureRecognizer:gestureRecognizer
approximatelyEqualsGestureRecognizerResult:
[MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
}
- (void)testRecognizeWithAllowGestureFistSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
gestureRecognizerOptions.numHands = 1;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
[self assertResultsOfRecognizeImageWithFileInfo:kFistImage
usingGestureRecognizer:gestureRecognizer
approximatelyEqualsGestureRecognizerResult:
[MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
}
- (void)testRecognizeWithDenyGestureFistSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];
gestureRecognizerOptions.numHands = 1;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
MPPGestureRecognizerResult *gestureRecognizerResult =
[self recognizeImageWithFileInfo:kFistImage usingGestureRecognizer:gestureRecognizer];
AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
}
- (void)testRecognizeWithPreferAllowlistOverDenylistSucceeds {
MPPGestureRecognizerOptions *gestureRecognizerOptions =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];
gestureRecognizerOptions.numHands = 1;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
[self assertResultsOfRecognizeImageWithFileInfo:kFistImage
usingGestureRecognizer:gestureRecognizer
approximatelyEqualsGestureRecognizerResult:
[MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
}
#pragma mark Running Mode Tests
- (void)testCreateGestureRecognizerFailsWithDelegateInNonLiveStreamMode {
MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = runningModesToTest[i];
options.gestureRecognizerLiveStreamDelegate = self;
[self assertCreateGestureRecognizerWithOptions:options
failsWithExpectedError:
[NSError
errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"The vision task is in image or video mode. The "
@"delegate must not be set in the task's options."
}]];
}
}
- (void)testCreateGestureRecognizerFailsWithMissingDelegateInLiveStreamMode {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
[self
assertCreateGestureRecognizerWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"The vision task is in live stream mode. An "
@"object must be set as the delegate of the task "
@"in its options to ensure asynchronous delivery "
@"of results."
}]];
}
- (void)testRecognizeFailsWithCallingWrongApiInImageMode {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kFistImage];
NSError *liveStreamApiCallError;
XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
@"stream mode. Current Running Mode: Image"
}];
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"video mode. Current Running Mode: Image"
}];
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}
- (void)testRecognizeFailsWithCallingWrongApiInVideoMode {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = MPPRunningModeVideo;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kFistImage];
NSError *liveStreamApiCallError;
XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
@"stream mode. Current Running Mode: Video"
}];
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *imageApiCallError;
XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);
NSError *expectedImageApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"image mode. Current Running Mode: Video"
}];
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
}
- (void)testRecognizeFailsWithCallingWrongApiInLiveStreamMode {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
options.gestureRecognizerLiveStreamDelegate = self;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kFistImage];
NSError *imageApiCallError;
XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);
NSError *expectedImageApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"image mode. Current Running Mode: Live Stream"
}];
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"video mode. Current Running Mode: Live Stream"
}];
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}
- (void)testRecognizeWithVideoModeSucceeds {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = MPPRunningModeVideo;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
for (int i = 0; i < 3; i++) {
MPPGestureRecognizerResult *gestureRecognizerResult =
[gestureRecognizer recognizeVideoFrame:image timestampInMilliseconds:i error:nil];
[self assertGestureRecognizerResult:gestureRecognizerResult
isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
thumbUpGestureRecognizerResult]];
}
}
- (void)testRecognizeWithOutOfOrderTimestampsAndLiveStreamModeFails {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
options.gestureRecognizerLiveStreamDelegate = self;
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"recognizeWithOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = 1;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:options];
_outOfOrderTimestampTestDict = @{
kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
kLiveStreamTestsDictExpectationKey : expectation
};
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image timestampInMilliseconds:1 error:nil]);
NSError *error;
XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
timestampInMilliseconds:0
error:&error]);
NSError *expectedError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
}];
AssertEqualErrors(error, expectedError);
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
- (void)testRecognizeWithLiveStreamModeSucceeds {
MPPGestureRecognizerOptions *options =
[self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
options.gestureRecognizerLiveStreamDelegate = self;
NSInteger iterationCount = 100;
  // Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
  // times. A normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since we cannot predict how many times the expectation is supposed to be fulfilled, setting
  // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that the test succeeds if the expectation is fulfilled
  // at most `iterationCount` times.
XCTestExpectation *expectation =
[[XCTestExpectation alloc] initWithDescription:@"recognizeWithLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
expectation.inverted = YES;
MPPGestureRecognizer *gestureRecognizer =
[self createGestureRecognizerWithOptionsSucceeds:options];
_liveStreamSucceedsTestDict = @{
kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
kLiveStreamTestsDictExpectationKey : expectation
};
// TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
// with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
// `CMSampleBuffer`.
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
for (int i = 0; i < iterationCount; i++) {
XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image
timestampInMilliseconds:i
error:nil]);
}
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
- (void)gestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
didFinishRecognitionWithResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error {
[self assertGestureRecognizerResult:gestureRecognizerResult
isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
thumbUpGestureRecognizerResult]];
if (gestureRecognizer == _outOfOrderTimestampTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
[_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
} else if (gestureRecognizer ==
_liveStreamSucceedsTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
[_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
}
}
@end
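The inverted-expectation pattern used in `testRecognizeWithLiveStreamModeSucceeds` is easy to misread, so here is a stripped-down sketch of just that pattern; the loop body is a placeholder for the `recognizeAsyncImage:timestampInMilliseconds:error:` calls whose results fulfill the expectation from the live stream delegate callback:

- (void)exampleInvertedExpectationPattern {
  NSInteger iterationCount = 100;
  XCTestExpectation *expectation =
      [[XCTestExpectation alloc] initWithDescription:@"flowLimitedCallbacks"];
  // Inverted: the test fails only if the expectation is fulfilled `iterationCount` + 1 times,
  // i.e. it passes for any fulfillment count <= `iterationCount`.
  expectation.expectedFulfillmentCount = iterationCount + 1;
  expectation.inverted = YES;
  for (NSInteger i = 0; i < iterationCount; i++) {
    // Submit an async request here; each delivered result calls [expectation fulfill]
    // in the delegate method.
  }
  [self waitForExpectations:@[ expectation ] timeout:0.5];
}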

View File

@ -0,0 +1,22 @@
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPGestureRecognizerResultProtobufHelpers",
srcs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.mm"],
hdrs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.h"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
"//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizerResult",
"//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
],
)

View File

@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPGestureRecognizerResult (ProtobufHelpers)
+ (MPPGestureRecognizerResult *)
gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
gestureLabel:(NSString *)gestureLabel
shouldRemoveZPosition:(BOOL)removeZPosition;
@end
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,65 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+Helpers.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
namespace {
using ClassificationListProto = ::mediapipe::ClassificationList;
using ClassificationProto = ::mediapipe::Classification;
using LandmarksDetectionResultProto =
::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
} // anonymous namespace
@implementation MPPGestureRecognizerResult (ProtobufHelpers)
+ (MPPGestureRecognizerResult *)
gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
gestureLabel:(NSString *)gestureLabel
shouldRemoveZPosition:(BOOL)removeZPosition {
LandmarksDetectionResultProto landmarkDetectionResultProto;
if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
return nil;
}
if (removeZPosition) {
    // Remove the z position of landmarks, since it is not used in correctness testing. For video
    // or live stream mode, the z positions vary a lot during tracking from frame to frame.
for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
landmark.clear_z();
}
}
ClassificationListProto gesturesProto;
ClassificationProto *classificationProto = gesturesProto.add_classification();
classificationProto->set_label([gestureLabel UTF8String]);
return [MPPGestureRecognizerResult
gestureRecognizerResultWithHandGesturesProto:{gesturesProto}
handednessProto:{landmarkDetectionResultProto.classifications()}
handLandmarksProto:{landmarkDetectionResultProto.landmarks()}
worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
timestampInMilliseconds:0];
}
@end

View File

@ -0,0 +1,62 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load(
"//mediapipe/framework/tool:ios.bzl",
"MPP_TASK_MINIMUM_OS_VERSION",
)
load(
"@org_tensorflow//tensorflow/lite:special_rules.bzl",
"tflite_ios_lab_runner",
)
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
"apple",
]
# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
"noasan",
"nomsan",
"notsan",
]
objc_library(
name = "MPPHandLandmarkerObjcTestLibrary",
testonly = 1,
srcs = ["MPPHandLandmarkerTests.m"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
data = [
"//mediapipe/tasks/testdata/vision:hand_landmarker.task",
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_protos",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/test/vision/hand_landmarker/utils:MPPHandLandmarkerResultProtobufHelpers",
"//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
"//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker",
] + select({
"//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
"//conditions:default": ["@ios_opencv//:OpencvFramework"],
}),
)
ios_unit_test(
name = "MPPHandLandmarkerObjcTest",
minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":MPPHandLandmarkerObjcTestLibrary",
],
)

View File

@ -0,0 +1,557 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <XCTest/XCTest.h>
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarker.h"
static NSString *const kPbFileExtension = @"pbtxt";
typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;
static ResourceFileInfo *const kHandLandmarkerBundleAssetFile =
@{@"name" : @"hand_landmarker", @"type" : @"task"};
static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
static ResourceFileInfo *const kPointingUpRotatedImage =
@{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
@{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
@{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static const float kLandmarksErrorTolerance = 0.03f;
static NSString *const kLiveStreamTestsDictHandLandmarkerKey = @"hand_landmarker";
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
#define AssertEqualErrors(error, expectedError) \
XCTAssertNotNil(error); \
XCTAssertEqualObjects(error.domain, expectedError.domain); \
XCTAssertEqual(error.code, expectedError.code); \
XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex) \
XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance, \
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance, \
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
#define AssertHandLandmarkerResultIsEmpty(handLandmarkerResult) \
XCTAssertTrue(handLandmarkerResult.handedness.count == 0); \
XCTAssertTrue(handLandmarkerResult.landmarks.count == 0); \
XCTAssertTrue(handLandmarkerResult.worldLandmarks.count == 0);
@interface MPPHandLandmarkerTests : XCTestCase <MPPHandLandmarkerLiveStreamDelegate> {
NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
}
@end
@implementation MPPHandLandmarkerTests
#pragma mark Results
+ (MPPHandLandmarkerResult *)emptyHandLandmarkerResult {
return [[MPPHandLandmarkerResult alloc] initWithLandmarks:@[]
worldLandmarks:@[]
handedness:@[]
timestampInMilliseconds:0];
}
+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
NSString *filePath = [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
shouldRemoveZPosition:YES];
}
+ (MPPHandLandmarkerResult *)pointingUpRotatedHandLandmarkerResult {
NSString *filePath =
[MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];
return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
shouldRemoveZPosition:YES];
}
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
areApproximatelyEqualToExpectedMultiHandLandmarks:
(NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
if (multiHandLandmarks.count == 0) {
return;
}
NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];
XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
MPPNormalizedLandmark *landmark = topHandLandmarks[i];
XCTAssertNotNil(landmark);
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
}
}
- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
(NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
if (expectedMultiHandWorldLandmarks.count == 0) {
return;
}
NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];
XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
MPPLandmark *landmark = topHandWorldLandmarks[i];
XCTAssertNotNil(landmark);
AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
}
}
- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
isApproximatelyEqualToExpectedResult:(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
[self assertMultiHandLandmarks:handLandmarkerResult.landmarks
areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
[self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedHandLandmarkerResult
.worldLandmarks];
}
#pragma mark File
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
NSString *filePath = [MPPHandLandmarkerTests filePathWithName:fileInfo[@"name"]
extension:fileInfo[@"type"]];
return filePath;
}
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
ofType:extension];
return filePath;
}
#pragma mark Hand Landmarker Initializers
- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
(ResourceFileInfo *)modelFileInfo {
NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
MPPHandLandmarkerOptions *handLandmarkerOptions = [[MPPHandLandmarkerOptions alloc] init];
handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;
return handLandmarkerOptions;
}
- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
(MPPHandLandmarkerOptions *)handLandmarkerOptions {
NSError *error;
MPPHandLandmarker *handLandmarker =
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
XCTAssertNotNil(handLandmarker);
XCTAssertNil(error);
return handLandmarker;
}
- (void)assertCreateHandLandmarkerWithOptions:(MPPHandLandmarkerOptions *)handLandmarkerOptions
failsWithExpectedError:(NSError *)expectedError {
NSError *error = nil;
MPPHandLandmarker *handLandmarker =
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
XCTAssertNil(handLandmarker);
AssertEqualErrors(error, expectedError);
}
#pragma mark Assert Hand Landmarker Results
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
fileName:fileInfo[@"name"]
ofType:fileInfo[@"type"]];
XCTAssertNotNil(image);
return image;
}
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
orientation:(UIImageOrientation)orientation {
MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
fileName:fileInfo[@"name"]
ofType:fileInfo[@"type"]
orientation:orientation];
XCTAssertNotNil(image);
return image;
}
- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker {
MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
XCTAssertNotNil(handLandmarkerResult);
return handLandmarkerResult;
}
- (void)assertResultsOfDetectInImageWithFileInfo:(ResourceFileInfo *)fileInfo
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
approximatelyEqualsHandLandmarkerResult:
(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:fileInfo
usingHandLandmarker:handLandmarker];
[self assertHandLandmarkerResult:handLandmarkerResult
isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
}
#pragma mark General Tests
- (void)testDetectWithModelPathSucceeds {
NSString *modelPath =
[MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
MPPHandLandmarker *handLandmarker = [[MPPHandLandmarker alloc] initWithModelPath:modelPath
error:nil];
XCTAssertNotNil(handLandmarker);
[self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
usingHandLandmarker:handLandmarker
approximatelyEqualsHandLandmarkerResult:[MPPHandLandmarkerTests
thumbUpHandLandmarkerResult]];
}
- (void)testDetectWithEmptyResultsSucceeds {
MPPHandLandmarkerOptions *handLandmarkerOptions =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
MPPHandLandmarker *handLandmarker =
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kNoHandsImage
usingHandLandmarker:handLandmarker];
AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
}
- (void)testDetectWithNumHandsSucceeds {
MPPHandLandmarkerOptions *handLandmarkerOptions =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
const NSInteger numHands = 2;
handLandmarkerOptions.numHands = numHands;
MPPHandLandmarker *handLandmarker =
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kTwoHandsImage
usingHandLandmarker:handLandmarker];
XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
}
- (void)testDetectWithRotationSucceeds {
MPPHandLandmarkerOptions *handLandmarkerOptions =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
MPPHandLandmarker *handLandmarker =
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
orientation:UIImageOrientationRight];
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
[self assertHandLandmarkerResult:handLandmarkerResult
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests
pointingUpRotatedHandLandmarkerResult]];
}
#pragma mark Running Mode Tests
- (void)testCreateHandLandmarkerFailsWithDelegateInNonLiveStreamMode {
MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = runningModesToTest[i];
options.handLandmarkerLiveStreamDelegate = self;
[self
assertCreateHandLandmarkerWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"The vision task is in image or video mode. The "
@"delegate must not be set in the task's options."
}]];
}
}
- (void)testCreateHandLandmarkerFailsWithMissingDelegateInLiveStreamMode {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
[self assertCreateHandLandmarkerWithOptions:options
failsWithExpectedError:
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"The vision task is in live stream mode. An "
@"object must be set as the delegate of the task "
@"in its options to ensure asynchronous delivery "
@"of results."
}]];
}
- (void)testDetectFailsWithCallingWrongApiInImageMode {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
NSError *liveStreamApiCallError;
XCTAssertFalse([handLandmarker detectAsyncInImage:image
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
@"stream mode. Current Running Mode: Image"
}];
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([handLandmarker detectInVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"video mode. Current Running Mode: Image"
}];
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}
- (void)testDetectFailsWithCallingWrongApiInVideoMode {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = MPPRunningModeVideo;
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
NSError *liveStreamApiCallError;
XCTAssertFalse([handLandmarker detectAsyncInImage:image
timestampInMilliseconds:0
error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
@"stream mode. Current Running Mode: Video"
}];
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *imageApiCallError;
XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);
NSError *expectedImageApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"image mode. Current Running Mode: Video"
}];
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
}
- (void)testDetectFailsWithCallingWrongApiInLiveStreamMode {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
options.handLandmarkerLiveStreamDelegate = self;
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
NSError *imageApiCallError;
XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);
NSError *expectedImageApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"image mode. Current Running Mode: Live Stream"
}];
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
NSError *videoApiCallError;
XCTAssertFalse([handLandmarker detectInVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey : @"The vision task is not initialized with "
@"video mode. Current Running Mode: Live Stream"
}];
AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}
- (void)testDetectWithVideoModeSucceeds {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = MPPRunningModeVideo;
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
for (int i = 0; i < 3; i++) {
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInVideoFrame:image
timestampInMilliseconds:i
error:nil];
[self assertHandLandmarkerResult:handLandmarkerResult
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
}
}
- (void)testDetectWithOutOfOrderTimestampsAndLiveStreamModeFails {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
options.handLandmarkerLiveStreamDelegate = self;
XCTestExpectation *expectation = [[XCTestExpectation alloc]
initWithDescription:@"detectWiththOutOfOrderTimestampsAndLiveStream"];
expectation.expectedFulfillmentCount = 1;
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
_outOfOrderTimestampTestDict = @{
kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
kLiveStreamTestsDictExpectationKey : expectation
};
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:1 error:nil]);
NSError *error;
XCTAssertFalse([handLandmarker detectAsyncInImage:image timestampInMilliseconds:0 error:&error]);
NSError *expectedError =
[NSError errorWithDomain:kExpectedErrorDomain
code:MPPTasksErrorCodeInvalidArgumentError
userInfo:@{
NSLocalizedDescriptionKey :
@"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
}];
AssertEqualErrors(error, expectedError);
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
- (void)testDetectWithLiveStreamModeSucceeds {
MPPHandLandmarkerOptions *options =
[self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
options.runningMode = MPPRunningModeLiveStream;
options.handLandmarkerLiveStreamDelegate = self;
NSInteger iterationCount = 100;
  // Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
  // times. A normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since in our case we cannot predict how many times the expectation is supposed to be
  // fulfilled, setting `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that the test succeeds if the expectation is
  // fulfilled <= `iterationCount` times.
XCTestExpectation *expectation =
[[XCTestExpectation alloc] initWithDescription:@"detectWithLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
expectation.inverted = YES;
MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
_liveStreamSucceedsTestDict = @{
kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
kLiveStreamTestsDictExpectationKey : expectation
};
// TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
// with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
// `CMSampleBuffer`.
MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
for (int i = 0; i < iterationCount; i++) {
XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:i error:nil]);
}
NSTimeInterval timeout = 0.5f;
[self waitForExpectations:@[ expectation ] timeout:timeout];
}
- (void)handLandmarker:(MPPHandLandmarker *)handLandmarker
didFinishDetectionWithResult:(MPPHandLandmarkerResult *)handLandmarkerResult
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError *)error {
[self assertHandLandmarkerResult:handLandmarkerResult
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
if (handLandmarker == _outOfOrderTimestampTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
[_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
} else if (handLandmarker == _liveStreamSucceedsTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
[_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
}
}
@end

View File

@ -0,0 +1,22 @@
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
objc_library(
name = "MPPHandLandmarkerResultProtobufHelpers",
srcs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.mm"],
hdrs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.h"],
copts = [
"-ObjC++",
"-std=c++17",
"-x objective-c++",
],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
"//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarkerResult",
"//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
],
)

View File

@ -0,0 +1,26 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarkerResult.h"
NS_ASSUME_NONNULL_BEGIN
@interface MPPHandLandmarkerResult (ProtobufHelpers)
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
shouldRemoveZPosition:(BOOL)removeZPosition;
@end
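// A minimal usage sketch for tests. The file name below is a hypothetical path to a text-format
// proto (.pbtxt) fixture bundled with the test target; any hand landmarks detection result proto
// would work the same way:
//
//   MPPHandLandmarkerResult *expectedResult = [MPPHandLandmarkerResult
//       handLandmarkerResultFromProtobufFileWithName:@"thumb_up_landmarks.pbtxt"
//                              shouldRemoveZPosition:YES];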
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,58 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+Helpers.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
namespace {
using ClassificationListProto = ::mediapipe::ClassificationList;
using ClassificationProto = ::mediapipe::Classification;
using LandmarksDetectionResultProto =
::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
} // anonymous namespace
@implementation MPPHandLandmarkerResult (ProtobufHelpers)
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
shouldRemoveZPosition:(BOOL)removeZPosition {
LandmarksDetectionResultProto landmarkDetectionResultProto;
if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
return nil;
}
if (removeZPosition) {
    // Remove the z position of landmarks, because they are not used in correctness testing. For
    // video or live stream mode, the z positions vary a lot during tracking from frame to frame.
for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
landmark.clear_z();
}
}
return [MPPHandLandmarkerResult
handLandmarkerResultWithLandmarksProto:{landmarkDetectionResultProto.landmarks()}
worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
handednessProto:{landmarkDetectionResultProto.classifications()}
timestampInMilliSeconds:0];
}
@end

View File

@ -673,10 +673,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
// If `expectation.isInverted = true`, the test will only succeed if
// expectation is not fulfilled for the specified `expectedFulfillmentCount`.
// Since in our case we cannot predict how many times the expectation is
// supposed to be fullfilled setting,
// supposed to be fulfilled setting,
// `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
// `expectation.isInverted = true` ensures that test succeeds if
// expectation is fullfilled <= `iterationCount` times.
// expectation is fulfilled <= `iterationCount` times.
XCTestExpectation *expectation =
[[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];

View File

@ -64,3 +64,17 @@ objc_library(
"@com_google_absl//absl/status:statusor",
],
)
objc_library(
name = "MPPMask",
srcs = ["sources/MPPMask.mm"],
hdrs = ["sources/MPPMask.h"],
copts = [
"-ObjC++",
"-std=c++17",
],
deps = [
"//mediapipe/tasks/ios/common:MPPCommon",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
],
)

View File

@ -0,0 +1,118 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <Foundation/Foundation.h>
NS_ASSUME_NONNULL_BEGIN
/** The underlying type of the segmentation mask. */
typedef NS_ENUM(NSUInteger, MPPMaskDataType) {
/** Represents the native `UInt8 *` type. */
MPPMaskDataTypeUInt8,
/** Represents the native `float *` type. */
MPPMaskDataTypeFloat32,
} NS_SWIFT_NAME(MaskDataType);
/**
* The wrapper class for MediaPipe segmentation masks.
*
 * Masks are stored as `UInt8 *` or `float *` arrays.
 * Every mask has an underlying type which can be accessed using `dataType`. You can access the
 * mask as any other type using the appropriate properties. For example, if the underlying type is
 * `MPPMaskDataTypeUInt8`, in addition to accessing the mask using `uint8Data`, you can access
 * `float32Data` to get the 32 bit float data (with values ranging from 0.0 to 1.0). The first
 * time you access the data as a type different from the underlying type, an expensive type
 * conversion is performed. Subsequent accesses return a pointer to the memory location of the
 * type-converted array. As type conversions can be expensive, it is recommended to limit
 * accesses to data of types different from the underlying type.
 *
 * Masks that are returned from a MediaPipe task are owned by the underlying C++ task. If you
 * need to extend the lifetime of these objects, you can invoke the `copy` method.
*/
NS_SWIFT_NAME(Mask)
@interface MPPMask : NSObject <NSCopying>
/** The width of the mask. */
@property(nonatomic, readonly) NSInteger width;
/** The height of the mask. */
@property(nonatomic, readonly) NSInteger height;
/** The data type of the mask. */
@property(nonatomic, readonly) MPPMaskDataType dataType;
/**
 * The pointer to the memory location where the underlying mask is stored as a single-channel
 * `UInt8` array. `UInt8` values range from 0 to 255.
*/
@property(nonatomic, readonly, assign) const UInt8 *uint8Data;
/**
 * The pointer to the memory location where the underlying mask is stored as a single-channel
 * 32 bit float array. Float values range from 0.0 to 1.0.
*/
@property(nonatomic, readonly, assign) const float *float32Data;
/**
* Initializes an `MPPMask` object of type `MPPMaskDataTypeUInt8` with the given `UInt8*` data,
* width and height.
*
* If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
* `uint8Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
* the `MPPMask` must outlive the passed in `uint8Data`.
*
* @param uint8Data A pointer to the memory location of the `UInt8` data array.
* @param width The width of the mask.
* @param height The height of the mask.
 * @param shouldCopy Indicates whether the newly created `MPPMask` should deep copy `uint8Data`.
*
* @return A new `MPPMask` instance with the given `UInt8*` data, width and height.
*/
- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
width:(NSInteger)width
height:(NSInteger)height
shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
/**
* Initializes an `MPPMask` object of type `MPPMaskDataTypeFloat32` with the given `float*` data,
* width and height.
*
* If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
* `float32Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
* the `MPPMask` must outlive the passed in `float32Data`.
*
* @param float32Data A pointer to the memory location of the `float` data array.
* @param width The width of the mask.
 * @param height The height of the mask.
 * @param shouldCopy Indicates whether the newly created `MPPMask` should deep copy `float32Data`.
*
* @return A new `MPPMask` instance with the given `float*` data, width and height.
*/
- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
width:(NSInteger)width
height:(NSInteger)height
shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
// TODO: Add methods for CVPixelBuffer conversion.
/** Unavailable. */
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;
@end
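// A minimal usage sketch. The 4x4 `confidence` buffer is hypothetical; in practice masks are
// returned by a MediaPipe task. Accessing `uint8Data` on a float mask triggers the lazy, cached
// conversion described above:
//
//   float confidence[16] = {0.0f, 0.25f, 0.5f, 1.0f};  // Remaining values default to 0.0f.
//   MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:confidence
//                                                  width:4
//                                                 height:4
//                                             shouldCopy:YES];
//   const UInt8 *quantized = mask.uint8Data;  // First access converts 0.0-1.0 floats to 0-255.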
NS_ASSUME_NONNULL_END

View File

@ -0,0 +1,135 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
@interface MPPMask () {
const UInt8 *_uint8Data;
const float *_float32Data;
std::unique_ptr<UInt8[]> _uint8DataPtr;
std::unique_ptr<float[]> _float32DataPtr;
}
@end
@implementation MPPMask
- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
width:(NSInteger)width
height:(NSInteger)height
shouldCopy:(BOOL)shouldCopy {
self = [super init];
if (self) {
_width = width;
_height = height;
_dataType = MPPMaskDataTypeUInt8;
if (shouldCopy) {
size_t length = _width * _height;
_uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
_uint8Data = _uint8DataPtr.get();
memcpy((UInt8 *)_uint8Data, uint8Data, length * sizeof(UInt8));
} else {
_uint8Data = uint8Data;
}
}
return self;
}
- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
width:(NSInteger)width
height:(NSInteger)height
shouldCopy:(BOOL)shouldCopy {
self = [super init];
if (self) {
_width = width;
_height = height;
_dataType = MPPMaskDataTypeFloat32;
if (shouldCopy) {
size_t length = _width * _height;
_float32DataPtr = std::unique_ptr<float[]>(new float[length]);
_float32Data = _float32DataPtr.get();
memcpy((float *)_float32Data, float32Data, length * sizeof(float));
} else {
_float32Data = float32Data;
}
}
return self;
}
- (const UInt8 *)uint8Data {
switch (_dataType) {
case MPPMaskDataTypeUInt8: {
return _uint8Data;
}
case MPPMaskDataTypeFloat32: {
if (_uint8DataPtr) {
return _uint8DataPtr.get();
}
size_t length = _width * _height;
_uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
UInt8 *data = _uint8DataPtr.get();
for (int i = 0; i < length; i++) {
data[i] = _float32Data[i] * 255;
}
return data;
}
default:
return NULL;
}
}
- (const float *)float32Data {
switch (_dataType) {
case MPPMaskDataTypeUInt8: {
if (_float32DataPtr) {
return _float32DataPtr.get();
}
size_t length = _width * _height;
_float32DataPtr = std::unique_ptr<float[]>(new float[length]);
float *data = _float32DataPtr.get();
for (int i = 0; i < length; i++) {
data[i] = (float)_uint8Data[i] / 255;
}
return data;
}
case MPPMaskDataTypeFloat32: {
return _float32Data;
}
default:
return NULL;
}
}
- (id)copyWithZone:(NSZone *)zone {
switch (_dataType) {
case MPPMaskDataTypeUInt8:
return [[MPPMask alloc] initWithUInt8Data:self.uint8Data
width:self.width
height:self.height
shouldCopy:YES];
case MPPMaskDataTypeFloat32:
return [[MPPMask alloc] initWithFloat32Data:self.float32Data
width:self.width
height:self.height
shouldCopy:YES];
}
}
@end

View File

@ -165,7 +165,7 @@ static NSString *const kTaskPrefix = @"com.mediapipe.tasks.vision";
// For 90° and 270° rotations, we need to swap width and height.
// This is due to the internal behavior of ImageToTensorCalculator, which:
// - first denormalizes the provided rect by multiplying the rect width or height by the image
// width or height, repectively.
// width or height, respectively.
  // - then rotates this denormalized rect by the provided rotation, and uses this for cropping,
// - then finally rotates this back.
if (rotationDegrees % 180 == 0) {

View File

@ -28,6 +28,7 @@
- (id)copyWithZone:(NSZone *)zone {
MPPFaceDetectorOptions *faceDetectorOptions = [super copyWithZone:zone];
faceDetectorOptions.runningMode = self.runningMode;
faceDetectorOptions.minDetectionConfidence = self.minDetectionConfidence;
faceDetectorOptions.minSuppressionThreshold = self.minSuppressionThreshold;
faceDetectorOptions.faceDetectorLiveStreamDelegate = self.faceDetectorLiveStreamDelegate;

View File

@ -43,9 +43,16 @@ objc_library(
],
)
objc_library(
name = "MPPFaceLandmarksConnections",
hdrs = ["sources/MPPFaceLandmarksConnections.h"],
module_name = "MPPFaceLandmarksConnections",
deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"],
)
objc_library(
name = "MPPFaceLandmarker",
srcs = ["sources/MPPFaceLandmarker.m"],
srcs = ["sources/MPPFaceLandmarker.mm"],
hdrs = ["sources/MPPFaceLandmarker.h"],
copts = [
"-ObjC++",
@ -55,9 +62,11 @@ objc_library(
deps = [
":MPPFaceLandmarkerOptions",
":MPPFaceLandmarkerResult",
":MPPFaceLandmarksConnections",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
"//mediapipe/tasks/ios/components/containers:MPPConnection",
"//mediapipe/tasks/ios/core:MPPTaskInfo",
"//mediapipe/tasks/ios/vision/core:MPPImage",
"//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",

View File

@ -14,6 +14,7 @@
#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h"
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerResult.h"
@ -147,6 +148,83 @@ NS_SWIFT_NAME(FaceLandmarker)
error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
/**
* Returns the connections between all the landmarks in the lips.
*
* @return An array of connections between all the landmarks in the lips.
*/
+ (NSArray<MPPConnection *> *)lipsConnections;
/**
* Returns the connections between all the landmarks in the left eye.
*
* @return An array of connections between all the landmarks in the left eye.
*/
+ (NSArray<MPPConnection *> *)leftEyeConnections;
/**
* Returns the connections between all the landmarks in the left eyebrow.
*
* @return An array of connections between all the landmarks in the left eyebrow.
*/
+ (NSArray<MPPConnection *> *)leftEyebrowConnections;
/**
* Returns the connections between all the landmarks in the left iris.
*
* @return An array of connections between all the landmarks in the left iris.
*/
+ (NSArray<MPPConnection *> *)leftIrisConnections;
/**
* Returns the connections between all the landmarks in the right eye.
*
 * @return An array of connections between all the landmarks in the right eye.
*/
+ (NSArray<MPPConnection *> *)rightEyeConnections;
/**
* Returns the connections between all the landmarks in the right eyebrow.
*
* @return An array of connections between all the landmarks in the right eyebrow.
*/
+ (NSArray<MPPConnection *> *)rightEyebrowConnections;
/**
* Returns the connections between all the landmarks in the right iris.
*
* @return An array of connections between all the landmarks in the right iris.
*/
+ (NSArray<MPPConnection *> *)rightIrisConnections;
/**
* Returns the connections between all the landmarks of the face oval.
*
* @return An array of connections between all the landmarks of the face oval.
*/
+ (NSArray<MPPConnection *> *)faceOvalConnections;
/**
 * Returns the connections between all the landmarks making up the contours of the face.
 *
 * @return An array of connections between all the landmarks making up the contours of the face.
*/
+ (NSArray<MPPConnection *> *)contoursConnections;
/**
* Returns the connections between all the landmarks making up the tesselation of the face.
*
* @return An array of connections between all the landmarks making up the tesselation of the face.
*/
+ (NSArray<MPPConnection *> *)tesselationConnections;
/**
* Returns the connections between all the landmarks in the face.
*
* @return An array of connections between all the landmarks in the face.
*/
+ (NSArray<MPPConnection *> *)faceConnections;
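// A minimal iteration sketch. It assumes `MPPConnection` exposes `start` and `end` landmark
// indices, which is not shown in this diff:
//
//   for (MPPConnection *connection in [MPPFaceLandmarker lipsConnections]) {
//     NSLog(@"Connect landmark %lu to landmark %lu", (unsigned long)connection.start,
//           (unsigned long)connection.end);
//   }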
- (instancetype)init NS_UNAVAILABLE;
+ (instancetype)new NS_UNAVAILABLE;

Some files were not shown because too many files have changed in this diff