Merge remote-tracking branch 'origin/master' into nguyencse/facemeshioslib
Commit 5093b7a2a9

WORKSPACE

@@ -46,11 +46,12 @@ http_archive(
http_archive(
    name = "rules_foreign_cc",
    strip_prefix = "rules_foreign_cc-0.1.0",
    url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip",
    sha256 = "2a4d07cd64b0719b39a7c12218a3e507672b82a97b98c6a89d38565894cf7c51",
    strip_prefix = "rules_foreign_cc-0.9.0",
    url = "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.9.0.tar.gz",
)

load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies")

rules_foreign_cc_dependencies()

@@ -492,9 +493,10 @@ http_archive(
)

# TensorFlow repo should always go after the other external dependencies.
# TF on 2023-05-26.
_TENSORFLOW_GIT_COMMIT = "67d5c561981edc45daf3f9d73ddd1a77963733ca"
_TENSORFLOW_SHA256 = "0c8326285e9cb695313e194b97d388eea70bf8bf5b13e8f0962ca8eed5179ece"
# TF on 2023-06-13.
_TENSORFLOW_GIT_COMMIT = "491681a5620e41bf079a582ac39c585cc86878b9"
# curl -L https://github.com/tensorflow/tensorflow/archive/<TENSORFLOW_GIT_COMMIT>.tar.gz | shasum -a 256
_TENSORFLOW_SHA256 = "9f76389af7a2835e68413322c1eaabfadc912f02a76d71dc16be507f9ca3d3ac"
http_archive(
    name = "org_tensorflow",
    urls = [
@@ -219,12 +219,10 @@ cc_library(
    deps = [
        ":time_series_framer_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:time_series_header_cc_proto",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/util:time_series_util",
        "@com_google_audio_tools//audio/dsp:window_functions",
        "@eigen_archive//:eigen3",

@@ -319,6 +317,20 @@ cc_test(
    ],
)

cc_binary(
    name = "time_series_framer_calculator_benchmark",
    srcs = ["time_series_framer_calculator_benchmark.cc"],
    deps = [
        ":time_series_framer_calculator",
        ":time_series_framer_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:time_series_header_cc_proto",
        "@com_google_benchmark//:benchmark",
    ],
)

cc_test(
    name = "time_series_framer_calculator_test",
    srcs = ["time_series_framer_calculator_test.cc"],
@@ -15,9 +15,7 @@
// Defines TimeSeriesFramerCalculator.
#include <math.h>

#include <deque>
#include <memory>
#include <string>
#include <vector>

#include "Eigen/Core"
#include "audio/dsp/window_functions.h"

@@ -25,9 +23,8 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/util/time_series_util.h"

namespace mediapipe {

@@ -88,11 +85,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
  absl::Status Close(CalculatorContext* cc) override;

 private:
  // Adds input data to the internal buffer.
  void EnqueueInput(CalculatorContext* cc);
  // Constructs and emits framed output packets.
  void FrameOutput(CalculatorContext* cc);

  Timestamp CurrentOutputTimestamp() {
    if (use_local_timestamp_) {
      return current_timestamp_;

@@ -106,14 +98,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
                 Timestamp::kTimestampUnitsPerSecond);
  }

  // Returns the timestamp of a sample on a base, which is usually the time
  // stamp of a packet.
  Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
                                   int64_t number_of_samples) {
    return timestamp_base + round(number_of_samples / sample_rate_ *
                                  Timestamp::kTimestampUnitsPerSecond);
  }

  // The number of input samples to advance after the current output frame is
  // emitted.
  int next_frame_step_samples() const {
@@ -142,61 +126,174 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
  Timestamp initial_input_timestamp_;
  // The current timestamp is updated along with the incoming packets.
  Timestamp current_timestamp_;
  int num_channels_;

  // Each entry in this deque consists of a single sample, i.e. a
  // single column vector, and its timestamp.
  std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
  // Samples are buffered in a vector of sample blocks.
  class SampleBlockBuffer {
   public:
    // Initializes the buffer.
    void Init(double sample_rate, int num_channels) {
      ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate;
      num_channels_ = num_channels;
      num_samples_ = 0;
      first_block_offset_ = 0;
    }

    // Number of channels, equal to the number of rows in each Matrix.
    int num_channels() const { return num_channels_; }
    // Total number of available samples over all blocks.
    int num_samples() const { return num_samples_; }

    // Pushes a new block of samples on the back of the buffer with `timestamp`
    // being the input timestamp of the packet containing the Matrix.
    void Push(const Matrix& samples, Timestamp timestamp);
    // Copies `count` samples from the front of the buffer. If there are fewer
    // samples than this, the result is zero padded to have `count` samples.
    // The timestamp of the last copied sample is written to *last_timestamp.
    // This output is used below to update `current_timestamp_`, which is only
    // used when `use_local_timestamp` is true.
    Matrix CopySamples(int count, Timestamp* last_timestamp) const;
    // Drops `count` samples from the front of the buffer. If `count` exceeds
    // `num_samples()`, the buffer is emptied.  Returns how many samples were
    // dropped.
    int DropSamples(int count);

   private:
    struct Block {
      // Matrix of num_channels rows by num_samples columns, a block of possibly
      // multiple samples.
      Matrix samples;
      // Timestamp of the first sample in the Block. This comes from the input
      // packet's timestamp that contains this Matrix.
      Timestamp timestamp;

      Block() : timestamp(Timestamp::Unstarted()) {}
      Block(const Matrix& samples, Timestamp timestamp)
          : samples(samples), timestamp(timestamp) {}
      int num_samples() const { return samples.cols(); }
    };
    std::vector<Block> blocks_;
    // Number of timestamp units per sample. Used to compute timestamps as
    // nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n).
    double ts_units_per_sample_;
    // Number of rows in each Matrix.
    int num_channels_;
    // The total number of samples over all blocks, equal to
    // (sum_i blocks_[i].num_samples()) - first_block_offset_.
    int num_samples_;
    // The number of samples in the first block that have been discarded. This
    // way we can cheaply represent "partially discarding" a block.
    int first_block_offset_;
  } sample_buffer_;

  bool use_window_;
  Matrix window_;
  Eigen::RowVectorXf window_;

  bool use_local_timestamp_;
};
REGISTER_CALCULATOR(TimeSeriesFramerCalculator);

void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
  const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();

  for (int i = 0; i < input_frame.cols(); ++i) {
    sample_buffer_.emplace_back(std::make_pair(
        input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
  }
void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples,
                                                         Timestamp timestamp) {
  num_samples_ += samples.cols();
  blocks_.emplace_back(samples, timestamp);
}

void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
  while (sample_buffer_.size() >=
Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples(
    int count, Timestamp* last_timestamp) const {
  Matrix copied(num_channels_, count);

  if (!blocks_.empty()) {
    int num_copied = 0;
    // First block has an offset for samples that have been discarded.
    int offset = first_block_offset_;
    int n;
    Timestamp last_block_ts;
    int last_sample_index;

    for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) {
      n = std::min(it->num_samples() - offset, count);
      // Copy `n` samples from the next block.
      copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n);
      count -= n;
      num_copied += n;
      last_block_ts = it->timestamp;
      last_sample_index = offset + n - 1;
      offset = 0;  // No samples have been discarded in subsequent blocks.
    }

    // Compute the timestamp of the last copied sample.
    *last_timestamp =
        last_block_ts + std::round(ts_units_per_sample_ * last_sample_index);
  }

  if (count > 0) {
    copied.rightCols(count).setZero();  // Zero pad if needed.
  }

  return copied;
}

int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) {
  if (blocks_.empty()) {
    return 0;
  }

  auto block_it = blocks_.begin();
  if (first_block_offset_ + count < block_it->num_samples()) {
    // `count` is less than the remaining samples in the first block.
    first_block_offset_ += count;
    num_samples_ -= count;
    return count;
  }

  int num_samples_dropped = block_it->num_samples() - first_block_offset_;
  count -= num_samples_dropped;
  first_block_offset_ = 0;

  for (++block_it; block_it != blocks_.end(); ++block_it) {
    if (block_it->num_samples() > count) {
      break;
    }
    num_samples_dropped += block_it->num_samples();
    count -= block_it->num_samples();
  }

  blocks_.erase(blocks_.begin(), block_it);  // Drop whole blocks.
  if (!blocks_.empty()) {
    first_block_offset_ = count;  // Drop part of the next block.
    num_samples_dropped += count;
  }

  num_samples_ -= num_samples_dropped;
  return num_samples_dropped;
}

absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
    initial_input_timestamp_ = cc->InputTimestamp();
    current_timestamp_ = initial_input_timestamp_;
  }

  // Add input data to the internal buffer.
  sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(),
                      cc->InputTimestamp());

  // Construct and emit framed output packets.
  while (sample_buffer_.num_samples() >=
         frame_duration_samples_ + samples_still_to_drop_) {
    while (samples_still_to_drop_ > 0) {
      sample_buffer_.pop_front();
      --samples_still_to_drop_;
    }
    sample_buffer_.DropSamples(samples_still_to_drop_);
    Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
                                                     &current_timestamp_);
    const int frame_step_samples = next_frame_step_samples();
    std::unique_ptr<Matrix> output_frame(
        new Matrix(num_channels_, frame_duration_samples_));
    for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
         ++i) {
      output_frame->col(i) = sample_buffer_.front().first;
      current_timestamp_ = sample_buffer_.front().second;
      sample_buffer_.pop_front();
    }
    const int frame_overlap_samples =
        frame_duration_samples_ - frame_step_samples;
    if (frame_overlap_samples > 0) {
      for (int i = 0; i < frame_overlap_samples; ++i) {
        output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
        current_timestamp_ = sample_buffer_[i].second;
      }
    } else {
      samples_still_to_drop_ = -frame_overlap_samples;
    }
    samples_still_to_drop_ = frame_step_samples;

    if (use_window_) {
      *output_frame = (output_frame->array() * window_.array()).matrix();
      // Apply the window to each row of output_frame.
      output_frame.array().rowwise() *= window_.array();
    }

    cc->Outputs().Index(0).Add(output_frame.release(),
                               CurrentOutputTimestamp());
    cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
                                         .At(CurrentOutputTimestamp()));
    ++cumulative_output_frames_;
    cumulative_completed_samples_ += frame_step_samples;
  }
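The hunk above replaces the old per-sample std::deque with a buffer of whole sample blocks. As a reading aid only, here is a minimal standalone C++ sketch of that buffering idea (push a block, copy a frame with zero padding, drop from the front via an offset). It is not code from this commit; the class name and the use of Eigen::MatrixXf are illustrative assumptions.

// Illustrative sketch, not part of the patch. Assumes every pushed block has
// the same number of rows (channels).
#include <algorithm>
#include <vector>
#include "Eigen/Core"

class MiniBlockBuffer {
 public:
  // Appends a block of samples (one column per sample).
  void Push(const Eigen::MatrixXf& block) { blocks_.push_back(block); }

  // Copies `count` samples from the front; zero-pads when the buffer is short.
  Eigen::MatrixXf Copy(int rows, int count) const {
    Eigen::MatrixXf out = Eigen::MatrixXf::Zero(rows, count);
    int copied = 0;
    int offset = offset_;  // Only the first block carries a discard offset.
    for (const auto& b : blocks_) {
      if (copied >= count) break;
      const int n = std::min<int>(b.cols() - offset, count - copied);
      out.middleCols(copied, n) = b.middleCols(offset, n);
      copied += n;
      offset = 0;
    }
    return out;
  }

  // Drops `count` samples from the front: whole blocks first, then a partial
  // block represented cheaply by offset_.
  void Drop(int count) {
    count += offset_;
    while (!blocks_.empty() && blocks_.front().cols() <= count) {
      count -= blocks_.front().cols();
      blocks_.erase(blocks_.begin());
    }
    offset_ = blocks_.empty() ? 0 : count;
  }

 private:
  std::vector<Eigen::MatrixXf> blocks_;
  int offset_ = 0;  // Samples already discarded from the first block.
};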
@@ -206,35 +303,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
    // fact to enable packet queueing optimizations.
    cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
  }
}

absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
    initial_input_timestamp_ = cc->InputTimestamp();
    current_timestamp_ = initial_input_timestamp_;
  }

  EnqueueInput(cc);
  FrameOutput(cc);

  return absl::OkStatus();
}

absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
  while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
    sample_buffer_.pop_front();
    --samples_still_to_drop_;
  }
  if (!sample_buffer_.empty() && pad_final_packet_) {
    std::unique_ptr<Matrix> output_frame(new Matrix);
    output_frame->setZero(num_channels_, frame_duration_samples_);
    for (int i = 0; i < sample_buffer_.size(); ++i) {
      output_frame->col(i) = sample_buffer_[i].first;
      current_timestamp_ = sample_buffer_[i].second;
    }
  sample_buffer_.DropSamples(samples_still_to_drop_);

    cc->Outputs().Index(0).Add(output_frame.release(),
                               CurrentOutputTimestamp());
  if (sample_buffer_.num_samples() > 0 && pad_final_packet_) {
    Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
                                                     &current_timestamp_);
    cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
                                         .At(CurrentOutputTimestamp()));
  }

  return absl::OkStatus();

@@ -258,7 +338,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
      cc->Inputs().Index(0).Header(), &input_header));

  sample_rate_ = input_header.sample_rate();
  num_channels_ = input_header.num_channels();
  sample_buffer_.Init(sample_rate_, input_header.num_channels());
  frame_duration_samples_ = time_series_util::SecondsToSamples(
      framer_options.frame_duration_seconds(), sample_rate_);
  RET_CHECK_GT(frame_duration_samples_, 0)

@@ -312,8 +392,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
  }

  if (use_window_) {
    window_ = Matrix::Ones(num_channels_, 1) *
              Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
    window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(),
                                             frame_duration_samples_)
                  .cast<float>();
  }
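In the hunk above the window is now stored as a single Eigen::RowVectorXf and broadcast across channels with a rowwise array product, instead of materializing a num_channels-by-frame_duration matrix. A small self-contained illustration of that Eigen pattern (assumed setup, not code from this commit):

// Illustrative snippet: broadcasting a row-vector window over all channels.
#include <iostream>
#include "Eigen/Core"

int main() {
  Eigen::MatrixXf frame = Eigen::MatrixXf::Ones(2, 4);  // 2 channels, 4 samples.
  Eigen::RowVectorXf window(4);
  window << 0.0f, 0.5f, 1.0f, 0.5f;
  frame.array().rowwise() *= window.array();  // Scales every row element-wise.
  std::cout << frame << "\n";
  return 0;
}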
@@ -0,0 +1,92 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Benchmark for TimeSeriesFramerCalculator.
#include <memory>
#include <random>
#include <vector>

#include "benchmark/benchmark.h"
#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/packet.h"

using ::mediapipe::Matrix;

void BM_TimeSeriesFramerCalculator(benchmark::State& state) {
  constexpr float kSampleRate = 32000.0;
  constexpr int kNumChannels = 2;
  constexpr int kFrameDurationSeconds = 5.0;
  std::mt19937 rng(0 /*seed*/);
  // Input around a half second's worth of samples at a time.
  std::uniform_int_distribution<int> input_size_dist(15000, 17000);
  // Generate a pool of random blocks of samples up front.
  std::vector<Matrix> sample_pool;
  sample_pool.reserve(20);
  for (int i = 0; i < 20; ++i) {
    sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng)));
  }
  std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1);

  mediapipe::CalculatorGraphConfig config;
  config.add_input_stream("input");
  config.add_output_stream("output");
  auto* node = config.add_node();
  node->set_calculator("TimeSeriesFramerCalculator");
  node->add_input_stream("input");
  node->add_output_stream("output");
  mediapipe::TimeSeriesFramerCalculatorOptions* options =
      node->mutable_options()->MutableExtension(
          mediapipe::TimeSeriesFramerCalculatorOptions::ext);
  options->set_frame_duration_seconds(kFrameDurationSeconds);

  for (auto _ : state) {
    state.PauseTiming();  // Pause benchmark timing.

    // Prepare input packets of random blocks of samples.
    std::vector<mediapipe::Packet> input_packets;
    input_packets.reserve(32);
    float t = 0;
    for (int i = 0; i < 32; ++i) {
      auto samples =
          std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]);
      const int num_samples = samples->cols();
      input_packets.push_back(mediapipe::Adopt(samples.release())
                                  .At(mediapipe::Timestamp::FromSeconds(t)));
      t += num_samples / kSampleRate;
    }
    // Initialize graph.
    mediapipe::CalculatorGraph graph;
    CHECK_OK(graph.Initialize(config));
    // Prepare input header.
    auto header = std::make_unique<mediapipe::TimeSeriesHeader>();
    header->set_sample_rate(kSampleRate);
    header->set_num_channels(kNumChannels);

    state.ResumeTiming();  // Resume benchmark timing.

    CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}}));
    for (auto& packet : input_packets) {
      CHECK_OK(graph.AddPacketToInputStream("input", packet));
    }
    CHECK(!graph.HasError());
    CHECK_OK(graph.CloseAllInputStreams());
    CHECK_OK(graph.WaitUntilIdle());
  }
}
BENCHMARK(BM_TimeSeriesFramerCalculator);

BENCHMARK_MAIN();
@@ -117,6 +117,7 @@ mediapipe_proto_library(
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/framework/formats:classification_proto",
        "//mediapipe/framework/formats:landmark_proto",
        "//mediapipe/framework/formats:matrix_data_proto",
        "//mediapipe/framework/formats:time_series_header_proto",
    ],
)

@@ -289,6 +290,7 @@ cc_library(
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:integral_types",

@@ -1167,6 +1169,7 @@ cc_library(
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix_data_cc_proto",
        "//mediapipe/framework/formats:time_series_header_cc_proto",
        "//mediapipe/framework/port:integral_types",
        "//mediapipe/framework/port:ret_check",

@@ -17,6 +17,7 @@
#include <vector>

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/integral_types.h"

@@ -104,4 +105,7 @@ typedef ConcatenateVectorCalculator<mediapipe::RenderData>
    ConcatenateRenderDataVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);

typedef ConcatenateVectorCalculator<mediapipe::Image>
    ConcatenateImageVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateImageVectorCalculator);
}  // namespace mediapipe
@@ -19,6 +19,7 @@
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix_data.pb.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"

@@ -85,8 +86,12 @@ class ConstantSidePacketCalculator : public CalculatorBase {
        packet.Set<LandmarkList>();
      } else if (packet_options.has_double_value()) {
        packet.Set<double>();
      } else if (packet_options.has_matrix_data_value()) {
        packet.Set<MatrixData>();
      } else if (packet_options.has_time_series_header_value()) {
        packet.Set<TimeSeriesHeader>();
      } else if (packet_options.has_int64_value()) {
        packet.Set<int64_t>();
      } else {
        return absl::InvalidArgumentError(
            "None of supported values were specified in options.");

@@ -121,9 +126,13 @@ class ConstantSidePacketCalculator : public CalculatorBase {
            MakePacket<LandmarkList>(packet_options.landmark_list_value()));
      } else if (packet_options.has_double_value()) {
        packet.Set(MakePacket<double>(packet_options.double_value()));
      } else if (packet_options.has_matrix_data_value()) {
        packet.Set(MakePacket<MatrixData>(packet_options.matrix_data_value()));
      } else if (packet_options.has_time_series_header_value()) {
        packet.Set(MakePacket<TimeSeriesHeader>(
            packet_options.time_series_header_value()));
      } else if (packet_options.has_int64_value()) {
        packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
      } else {
        return absl::InvalidArgumentError(
            "None of supported values were specified in options.");
@@ -19,6 +19,7 @@ package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/framework/formats/landmark.proto";
import "mediapipe/framework/formats/matrix_data.proto";
import "mediapipe/framework/formats/time_series_header.proto";

message ConstantSidePacketCalculatorOptions {

@@ -29,14 +30,16 @@ message ConstantSidePacketCalculatorOptions {
  message ConstantSidePacket {
    oneof value {
      int32 int_value = 1;
      uint64 uint64_value = 5;
      int64 int64_value = 11;
      float float_value = 2;
      double double_value = 9;
      bool bool_value = 3;
      string string_value = 4;
      uint64 uint64_value = 5;
      ClassificationList classification_list_value = 6;
      LandmarkList landmark_list_value = 7;
      double double_value = 9;
      TimeSeriesHeader time_series_header_value = 10;
      MatrixData matrix_data_value = 12;
    }
  }

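With the new oneof fields, the calculator can emit MatrixData and TimeSeriesHeader constants as side packets. A hypothetical usage sketch (not from this commit; the `packet` field name and the "PACKET" output tag follow the calculator's existing conventions but are assumptions here):

// Illustrative only: a graph config that emits a TimeSeriesHeader side packet
// via ConstantSidePacketCalculator.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig MakeHeaderSidePacketConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    node {
      calculator: "ConstantSidePacketCalculator"
      output_side_packet: "PACKET:time_series_header"
      options {
        [mediapipe.ConstantSidePacketCalculatorOptions.ext] {
          packet {
            time_series_header_value { sample_rate: 16000 num_channels: 1 }
          }
        }
      }
    }
  )pb");
}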
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <string>

#include "absl/strings/string_view.h"

@@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
  DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
  DoTestSingleSidePacket("{ bool_value: true }", true);
  DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
  DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
}

TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
@@ -228,7 +228,6 @@ cc_library(
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,

@@ -280,7 +279,6 @@ cc_library(
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

@@ -22,7 +22,6 @@

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"

@@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
  input_tensors.reserve(kNumInputTensorsForBert);
  for (int i = 0; i < kNumInputTensorsForBert; ++i) {
    input_tensors.push_back(
        {Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
        {Tensor::ElementType::kInt32,
         Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
  }
  std::memcpy(input_tensors[input_ids_tensor_index_]
                  .GetCpuWriteView()
@@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
  // Read CPU input into tensors.
  RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());

  // If the input tensors have dynamic shape, then the tensors need to be
  // resized and reallocated before we can copy the tensor values.
  bool resized_tensor_shapes = false;
  for (int i = 0; i < input_tensors.size(); ++i) {
    if (input_tensors[i].shape().is_dynamic) {
      interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
      resized_tensor_shapes = true;
    }
  }
  // Reallocation is needed for memory sanity.
  if (resized_tensor_shapes) interpreter_->AllocateTensors();

  for (int i = 0; i < input_tensors.size(); ++i) {
    const TfLiteType input_tensor_type =
        interpreter_->tensor(interpreter_->inputs()[i])->type;
@@ -20,7 +20,6 @@
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"

@@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
  // not found in the tokenizer vocab.
  std::vector<Tensor> result;
  result.push_back(
      {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
      {Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
  std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
              input_tokens.data(), input_tokens.size() * sizeof(int32_t));
  kTensorsOut(cc).Send(std::move(result));
@@ -1077,6 +1077,7 @@ cc_test(
    linkstatic = 1,
    deps = [
        ":tensor_to_image_frame_calculator",
        ":tensor_to_image_frame_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/formats:image_frame",

@@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {

 private:
  float scale_factor_;
  bool scale_per_frame_min_max_;
};

REGISTER_CALCULATOR(TensorToImageFrameCalculator);

@@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
  scale_factor_ =
      cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
  scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
                                 .scale_per_frame_min_max();
  cc->SetOffset(TimestampDiff(0));
  return absl::OkStatus();
}

@@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
  auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
  const int32_t total_size = height * width * depth;

  if (scale_per_frame_min_max_) {
    RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
        << "Setting scale_per_frame_min_max requires FLOAT input tensors.";
  }
  ::std::unique_ptr<const ImageFrame> output;
  if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
    // Allocate buffer with alignments.
    std::unique_ptr<uint8_t[]> buffer(
        new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
    auto data = input_tensor.flat<float>().data();
    float min = 1e23;
    float max = -1e23;
    if (scale_per_frame_min_max_) {
      for (int i = 0; i < total_size; ++i) {
        float d = scale_factor_ * data[i];
        if (d < min) {
          min = d;
        }
        if (d > max) {
          max = d;
        }
      }
    }
    for (int i = 0; i < total_size; ++i) {
      float d = data[i];
      if (scale_per_frame_min_max_) {
        d = 255 * (d - min) / (max - min + 1e-9);
      } else {
        d = scale_factor_ * d;
        if (d < 0) d = 0;
        if (d > 255) d = 255;
      }
      buffer[i] = d;
    }
    output = ::absl::make_unique<ImageFrame>(

@@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
  // Multiples floating point tensor outputs by this value before converting to
  // uint8. This is useful for converting from range [0, 1] to [0, 255]
  optional float scale_factor = 1 [default = 1.0];

  // If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
  // per frame. This overrides any explicit scale_factor.
  optional bool scale_per_frame_min_max = 2 [default = false];
}
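The new scale_per_frame_min_max option makes the calculator rescale each float frame by that frame's own minimum and maximum instead of applying the fixed scale_factor. A minimal sketch of the same mapping on a flat buffer (illustrative only; it mirrors the formula used in Process above):

// Illustrative helper, not part of the patch: map one frame's [min, max]
// range onto [0, 255].
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<uint8_t> ScalePerFrameMinMax(const std::vector<float>& data) {
  if (data.empty()) return {};
  const float lo = *std::min_element(data.begin(), data.end());
  const float hi = *std::max_element(data.begin(), data.end());
  std::vector<uint8_t> out(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    // Same normalization as the calculator: 255 * (d - min) / (max - min + 1e-9).
    out[i] = static_cast<uint8_t>(255.0f * (data[i] - lo) / (hi - lo + 1e-9f));
  }
  return out;
}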
@@ -11,7 +11,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <type_traits>

#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame.h"

@@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
template <class TypeParam>
class TensorToImageFrameCalculatorTest : public ::testing::Test {
 protected:
  void SetUpRunner() {
  void SetUpRunner(bool scale_per_frame_min_max = false) {
    CalculatorGraphConfig::Node config;
    config.set_calculator("TensorToImageFrameCalculator");
    config.add_input_stream("TENSOR:input_tensor");
    config.add_output_stream("IMAGE:output_image");
    config.mutable_options()
        ->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
        ->set_scale_per_frame_min_max(scale_per_frame_min_max);
    runner_ = absl::make_unique<CalculatorRunner>(config);
  }

@@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
  }
}

TYPED_TEST(TensorToImageFrameCalculatorTest,
           Converts3DTensorToImageFrame2DGrayWithScaling) {
  this->SetUpRunner(true);
  auto& runner = this->runner_;
  constexpr int kWidth = 16;
  constexpr int kHeight = 8;
  const tf::TensorShape tensor_shape{kHeight, kWidth};
  auto tensor = absl::make_unique<tf::Tensor>(
      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
  auto tensor_vec = tensor->template flat<TypeParam>().data();

  // Writing sequence of integers as floats which we want normalized.
  tensor_vec[0] = 255;
  for (int i = 1; i < kWidth * kHeight; ++i) {
    tensor_vec[i] = 200;
  }

  const int64_t time = 1234;
  runner->MutableInputs()->Tag(kTensor).packets.push_back(
      Adopt(tensor.release()).At(Timestamp(time)));

  if (!std::is_same<TypeParam, float>::value) {
    EXPECT_FALSE(runner->Run().ok());
    return;  // Short circuit because does not apply to other types.
  } else {
    EXPECT_TRUE(runner->Run().ok());
    const std::vector<Packet>& output_packets =
        runner->Outputs().Tag(kImage).packets;
    EXPECT_EQ(1, output_packets.size());
    EXPECT_EQ(time, output_packets[0].Timestamp().Value());
    const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
    EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
    EXPECT_EQ(kWidth, output_image.Width());
    EXPECT_EQ(kHeight, output_image.Height());

    EXPECT_EQ(255, output_image.PixelData()[0]);
    for (int i = 1; i < kWidth * kHeight; ++i) {
      const uint8_t pixel_value = output_image.PixelData()[i];
      ASSERT_EQ(0, pixel_value);
    }
  }
}

}  // namespace mediapipe
@@ -1355,6 +1355,23 @@ cc_test(
    ],
)

cc_test(
    name = "calculator_graph_summary_packet_test",
    srcs = ["calculator_graph_summary_packet_test.cc"],
    deps = [
        ":calculator_framework",
        ":packet",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:packet",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
        "//mediapipe/framework/tool:sink",
        "@com_google_absl//absl/status",
    ],
)

cc_test(
    name = "calculator_runner_test",
    size = "medium",

@@ -32,7 +32,7 @@ template <class T>
struct dependent_false : std::false_type {};

template <typename T>
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) {
T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, size_t index) {
  auto& vec = *vecp;
  if (vec.size() <= index) {
    vec.resize(index + 1);
@@ -109,9 +109,20 @@ class CalculatorContext {
  // use OutputStream::SetOffset() directly.
  void SetOffset(TimestampDiff offset);

  // Returns the status of the graph run.
  // DEPRECATED: This was intended to get graph run status during
  // `CalculatorBase::Close` call. However, `Close` can run simultaneously with
  // other calculators `CalculatorBase::Process`, hence the actual graph
  // status may change any time and returned graph status here does not
  // necessarily reflect the actual graph status.
  //
  // NOTE: This method should only be called during CalculatorBase::Close().
  // As an alternative, instead of checking graph status in `Close` and doing
  // work for "done" state, you can enable timestamp bound processing for your
  // calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger
  // `Process` on timestamp bound updates and handle "done" state there.
  // Check examples in:
  // mediapipe/framework/calculator_graph_summary_packet_test.cc.
  //
  ABSL_DEPRECATED("Does not reflect the actual graph status.")
  absl::Status GraphStatus() const { return graph_status_; }

  ProfilingContext* GetProfilingContext() const {
@@ -839,6 +839,13 @@ absl::Status CalculatorGraph::PrepareForRun(
}

absl::Status CalculatorGraph::WaitUntilIdle() {
  if (has_sources_) {
    LOG_FIRST_N(WARNING, 1)
        << "WaitUntilIdle called on a graph with source nodes, which "
           "is not fully supported at the moment. Source nodes: "
        << ListSourceNodes();
  }

  MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle());
  VLOG(2) << "Scheduler idle.";
  absl::Status status = absl::OkStatus();

@@ -1368,6 +1375,16 @@ const OutputStreamManager* CalculatorGraph::FindOutputStreamManager(
              .get()[validated_graph_->OutputStreamIndex(name)];
}

std::string CalculatorGraph::ListSourceNodes() const {
  std::vector<std::string> sources;
  for (auto& node : nodes_) {
    if (node->IsSource()) {
      sources.push_back(node->DebugName());
    }
  }
  return absl::StrJoin(sources, ", ");
}

namespace {
void PrintTimingToInfo(const std::string& label, int64_t timer_value) {
  const int64_t total_seconds = timer_value / 1000000ll;
@@ -229,8 +229,11 @@ class CalculatorGraph {
  // Wait until the running graph is in the idle mode, which is when nothing can
  // be scheduled and nothing is running in the worker threads. This function
  // can be called only after StartRun().
  //
  // NOTE: The graph must not have any source nodes because source nodes prevent
  // the running graph from becoming idle until the source nodes are done.
  // Currently, `WaitUntilIdle` cannot be used reliably on graphs with any
  // source nodes.
  absl::Status WaitUntilIdle();

  // Wait until a packet is emitted on one of the observed output streams.

@@ -594,6 +597,9 @@ class CalculatorGraph {
  // status before taking any action.
  void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);

  // Returns a comma-separated list of source nodes.
  std::string ListSourceNodes() const;

#if !MEDIAPIPE_DISABLE_GPU
  // Owns the legacy GpuSharedData if we need to create one for backwards
  // compatibility.

mediapipe/framework/calculator_graph_summary_packet_test.cc (new file, 430 lines)
			@ -0,0 +1,430 @@
 | 
			
		|||
#include "absl/status/status.h"
 | 
			
		||||
#include "mediapipe/framework/api2/node.h"
 | 
			
		||||
#include "mediapipe/framework/api2/packet.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/packet.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_matchers.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::api2::Input;
 | 
			
		||||
using ::mediapipe::api2::Node;
 | 
			
		||||
using ::mediapipe::api2::Output;
 | 
			
		||||
using ::testing::ElementsAre;
 | 
			
		||||
using ::testing::Eq;
 | 
			
		||||
using ::testing::HasSubstr;
 | 
			
		||||
using ::testing::IsEmpty;
 | 
			
		||||
using ::testing::Value;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
MATCHER_P2(IntPacket, value, timestamp, "") {
 | 
			
		||||
  *result_listener << "where object is (value: " << arg.template Get<int>()
 | 
			
		||||
                   << ", timestamp: " << arg.Timestamp() << ")";
 | 
			
		||||
  return Value(arg.template Get<int>(), Eq(value)) &&
 | 
			
		||||
         Value(arg.Timestamp(), Eq(timestamp));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Calculates and produces sum of all passed inputs when no more packets can be
 | 
			
		||||
// expected on the input stream.
 | 
			
		||||
class SummaryPacketCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<int> kIn{"IN"};
 | 
			
		||||
  static constexpr Output<int> kOut{"SUMMARY"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 | 
			
		||||
 | 
			
		||||
  static absl::Status UpdateContract(CalculatorContract* cc) {
 | 
			
		||||
    // Makes sure there are no automatic timestamp bound updates when Process
 | 
			
		||||
    // is called.
 | 
			
		||||
    cc->SetTimestampOffset(TimestampDiff::Unset());
 | 
			
		||||
    // Currently, only ImmediateInputStreamHandler supports "done" timestamp
 | 
			
		||||
    // bound update. (ImmediateInputStreamhandler handles multiple input
 | 
			
		||||
    // streams differently, so, in that case, calculator adjustments may be
 | 
			
		||||
    // required.)
 | 
			
		||||
    // TODO: update all input stream handlers to support "done"
 | 
			
		||||
    // timestamp bound update.
 | 
			
		||||
    cc->SetInputStreamHandler("ImmediateInputStreamHandler");
 | 
			
		||||
    // Enables processing timestamp bound updates. For this use case we are
 | 
			
		||||
    // specifically interested in "done" timestamp bound update. (E.g. when
 | 
			
		||||
    // all input packet sources are closed.)
 | 
			
		||||
    cc->SetProcessTimestampBounds(true);
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final {
 | 
			
		||||
    if (!kIn(cc).IsEmpty()) {
 | 
			
		||||
      value_ += kIn(cc).Get();
 | 
			
		||||
      value_set_ = true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (kOut(cc).IsClosed()) {
 | 
			
		||||
      // This can happen:
 | 
			
		||||
      // 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
 | 
			
		||||
      //    source calculator finished generating packets sent to kIn) and
 | 
			
		||||
      //    HasNextAllowedInStream() == true (which is often the case).
 | 
			
		||||
      // 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
 | 
			
		||||
      //    invoke Process() with Timestamp::Max to indicate "Done" timestamp
 | 
			
		||||
      //    bound update.
 | 
			
		||||
      return absl::OkStatus();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // TODO: an input stream holding a packet with a timestamp that has
 | 
			
		||||
    // no next timestamp allowed in stream should always result in
 | 
			
		||||
    // InputStream::IsDone() == true.
 | 
			
		||||
    if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
 | 
			
		||||
      // `Process` may or may not be invoked for "done" timestamp bound when
 | 
			
		||||
      // upstream calculator fails in `Close`. Hence, extra care is needed to
 | 
			
		||||
      // identify whether the calculator needs to send output.
 | 
			
		||||
      // TODO: remove when "done" timestamp bound flakiness fixed.
 | 
			
		||||
      if (value_set_) {
 | 
			
		||||
        // kOut(cc).Send(value_) can be used here as well; however, in the case
 | 
			
		||||
        // of a source calculator sending inputs into kIn, the resulting timestamp
 | 
			
		||||
        // is not well defined (e.g. it can be the last packet timestamp or
 | 
			
		||||
        // Timestamp::Max())
 | 
			
		||||
        // TODO: last packet from source should always result in
 | 
			
		||||
        // InputStream::IsDone() == true.
 | 
			
		||||
        kOut(cc).Send(value_, Timestamp::Max());
 | 
			
		||||
      }
 | 
			
		||||
      kOut(cc).Close();
 | 
			
		||||
    }
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  int value_ = 0;
 | 
			
		||||
  bool value_set_ = false;
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnClosingAllPacketSources) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: 'input'
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: 'IN:input'
 | 
			
		||||
      output_stream: 'SUMMARY:output'
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp(10));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  send_packet(20, Timestamp(11));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: 'input'
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: 'IN:input'
 | 
			
		||||
      output_stream: 'SUMMARY:output'
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp(10));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  send_packet(20, Timestamp::Max());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnPreStreamTimestamp) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: 'input'
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: 'IN:input'
 | 
			
		||||
      output_stream: 'SUMMARY:output'
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp::PreStream());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnPostStreamTimestamp) {
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  CalculatorGraphConfig graph_config =
 | 
			
		||||
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
        input_stream: 'input'
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "SummaryPacketCalculator"
 | 
			
		||||
          input_stream: 'IN:input'
 | 
			
		||||
          output_stream: 'SUMMARY:output'
 | 
			
		||||
        }
 | 
			
		||||
      )pb");
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp::PostStream());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class IntGeneratorCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Output<int> kOut{"INT"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final {
 | 
			
		||||
    kOut(cc).Send(20, Timestamp(0));
 | 
			
		||||
    kOut(cc).Send(10, Timestamp(1000));
 | 
			
		||||
    return tool::StatusStop();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnSourceCalculatorCompletion) {
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  CalculatorGraphConfig graph_config =
 | 
			
		||||
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "IntGeneratorCalculator"
 | 
			
		||||
          output_stream: "INT:int_value"
 | 
			
		||||
        }
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "SummaryPacketCalculator"
 | 
			
		||||
          input_stream: "IN:int_value"
 | 
			
		||||
          output_stream: "SUMMARY:output"
 | 
			
		||||
        }
 | 
			
		||||
      )pb");
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_EXPECT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class EmitOnCloseCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<int> kIn{"IN"};
 | 
			
		||||
  static constexpr Output<int> kOut{"INT"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
 | 
			
		||||
 | 
			
		||||
  absl::Status Close(CalculatorContext* cc) final {
 | 
			
		||||
    kOut(cc).Send(20, Timestamp(0));
 | 
			
		||||
    kOut(cc).Send(10, Timestamp(1000));
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnAnotherCalculatorClosure) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: "input"
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "EmitOnCloseCalculator"
 | 
			
		||||
      input_stream: "IN:input"
 | 
			
		||||
      output_stream: "INT:int_value"
 | 
			
		||||
    }
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: "IN:int_value"
 | 
			
		||||
      output_stream: "SUMMARY:output"
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseInputStream("input"));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class FailureInCloseCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<int> kIn{"IN"};
 | 
			
		||||
  static constexpr Output<int> kOut{"INT"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
 | 
			
		||||
 | 
			
		||||
  absl::Status Close(CalculatorContext* cc) final {
 | 
			
		||||
    return absl::InternalError("error");
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(FailureInCloseCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInClose) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: "input"
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "FailureInCloseCalculator"
 | 
			
		||||
      input_stream: "IN:input"
 | 
			
		||||
      output_stream: "INT:int_value"
 | 
			
		||||
    }
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: "IN:int_value"
 | 
			
		||||
      output_stream: "SUMMARY:output"
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseInputStream("input"));
 | 
			
		||||
  EXPECT_THAT(graph.WaitUntilIdle(),
 | 
			
		||||
              StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class FailureInProcessCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<int> kIn{"IN"};
 | 
			
		||||
  static constexpr Output<int> kOut{"INT"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final {
 | 
			
		||||
    return absl::InternalError("error");
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(FailureInProcessCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInProcess) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: "input"
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "FailureInProcessCalculator"
 | 
			
		||||
      input_stream: "IN:input"
 | 
			
		||||
      output_stream: "INT:int_value"
 | 
			
		||||
    }
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: "IN:int_value"
 | 
			
		||||
      output_stream: "SUMMARY:output"
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp::PostStream());
 | 
			
		||||
  EXPECT_THAT(graph.WaitUntilIdle(),
 | 
			
		||||
              StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -117,11 +117,18 @@ class Tensor {
 | 
			
		|||
    Shape() = default;
 | 
			
		||||
    Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
 | 
			
		||||
    Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
 | 
			
		||||
    Shape(std::initializer_list<int> dimensions, bool is_dynamic)
 | 
			
		||||
        : dims(dimensions), is_dynamic(is_dynamic) {}
 | 
			
		||||
    Shape(const std::vector<int>& dimensions, bool is_dynamic)
 | 
			
		||||
        : dims(dimensions), is_dynamic(is_dynamic) {}
 | 
			
		||||
    int num_elements() const {
 | 
			
		||||
      return std::accumulate(dims.begin(), dims.end(), 1,
 | 
			
		||||
                             std::multiplies<int>());
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<int> dims;
 | 
			
		||||
    // The Tensor has a dynamic rather than a static shape, so the TFLite interpreter
 | 
			
		||||
    // needs to be reallocated. Only relevant for CPU.
 | 
			
		||||
    bool is_dynamic = false;
 | 
			
		||||
  };
 | 
			
		||||
  // Quantization parameters corresponding to the zero_point and scale value
 | 
			
		||||
  // made available by TfLite quantized (uint8/int8) tensors.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@
 | 
			
		|||
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +35,17 @@ TEST(General, TestDataTypes) {
 | 
			
		|||
  EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(General, TestDynamic) {
 | 
			
		||||
  Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true));
 | 
			
		||||
  EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4);
 | 
			
		||||
  EXPECT_TRUE(t1.shape().is_dynamic);
 | 
			
		||||
 | 
			
		||||
  std::vector<int> t2_dims = {4, 3, 2, 3};
 | 
			
		||||
  Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true));
 | 
			
		||||
  EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3);
 | 
			
		||||
  EXPECT_TRUE(t2.shape().is_dynamic);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(Cpu, TestMemoryAllocation) {
 | 
			
		||||
  Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
 | 
			
		||||
  auto v1 = t1.GetCpuWriteView();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -273,8 +273,8 @@ absl::Status Scheduler::WaitForObservedOutput() {
 | 
			
		|||
// Idleness requires:
 | 
			
		||||
// 1. either the graph has no source nodes or all source nodes are closed, and
 | 
			
		||||
// 2. no packets are added to graph input streams.
 | 
			
		||||
// For simplicity, we only allow WaitUntilIdle() to be called on a graph with
 | 
			
		||||
// no source nodes. (This is enforced by CalculatorGraph::WaitUntilIdle().)
 | 
			
		||||
// For simplicity, we only fully support calling WaitUntilIdle() on a graph
 | 
			
		||||
// with no source nodes.
 | 
			
		||||
// The application must ensure no other threads are adding packets to graph
 | 
			
		||||
// input streams while a WaitUntilIdle() call is in progress.
 | 
			
		||||
absl::Status Scheduler::WaitUntilIdle() {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
 | 
			
		|||
  return *this + 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Timestamp::HasNextAllowedInStream() const {
 | 
			
		||||
  if (*this >= Max() || *this == PreStream()) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Timestamp Timestamp::PreviousAllowedInStream() const {
 | 
			
		||||
  if (*this <= Min() || *this == PostStream()) {
 | 
			
		||||
    // Indicates that no previous timestamps may occur.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,6 +186,10 @@ class Timestamp {
 | 
			
		|||
  // CHECKs that this->IsAllowedInStream().
 | 
			
		||||
  Timestamp NextAllowedInStream() const;
 | 
			
		||||
 | 
			
		||||
  // Returns true if there's a next timestamp in the range [Min .. Max] after
 | 
			
		||||
  // this one.
 | 
			
		||||
  bool HasNextAllowedInStream() const;
 | 
			
		||||
 | 
			
		||||
  // Returns the previous timestamp in the range [Min .. Max], or
 | 
			
		||||
  // Unstarted() if no Packets may precede one with this timestamp.
 | 
			
		||||
  Timestamp PreviousAllowedInStream() const;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
 | 
			
		|||
            Timestamp::PostStream().NextAllowedInStream());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TimestampTest, HasNextAllowedInStream) {
 | 
			
		||||
  EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
 | 
			
		||||
 | 
			
		||||
  EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TimestampTest, SpecialValueDifferences) {
 | 
			
		||||
  {  // Lower range
 | 
			
		||||
    const std::vector<Timestamp> timestamps = {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -530,6 +530,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/port:ret_check",
 | 
			
		||||
        "//mediapipe/framework/port:status",
 | 
			
		||||
        "@com_google_absl//absl/base:core_headers",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_set",
 | 
			
		||||
        "@com_google_absl//absl/memory",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,7 @@
 | 
			
		|||
 | 
			
		||||
"""MediaPipe Task Library Helper Rules for iOS"""
 | 
			
		||||
 | 
			
		||||
MPP_TASK_MINIMUM_OS_VERSION = "11.0"
 | 
			
		||||
MPP_TASK_MINIMUM_OS_VERSION = "12.0"
 | 
			
		||||
 | 
			
		||||
# When the static framework is built with bazel, all the header files are moved
 | 
			
		||||
# to the "Headers" directory with no header path prefixes. This auxiliary rule
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,6 +50,7 @@ def mediapipe_proto_library_impl(
 | 
			
		|||
        def_cc_proto = True,
 | 
			
		||||
        def_py_proto = True,
 | 
			
		||||
        def_java_lite_proto = True,
 | 
			
		||||
        def_kt_lite_proto = True,
 | 
			
		||||
        def_objc_proto = True,
 | 
			
		||||
        def_java_proto = True,
 | 
			
		||||
        def_jspb_proto = True,
 | 
			
		||||
| 
						 | 
				
			
			@ -72,6 +73,7 @@ def mediapipe_proto_library_impl(
 | 
			
		|||
      def_cc_proto: define the cc_proto_library target
 | 
			
		||||
      def_py_proto: define the py_proto_library target
 | 
			
		||||
      def_java_lite_proto: define the java_lite_proto_library target
 | 
			
		||||
      def_kt_lite_proto: define the kt_lite_proto_library target
 | 
			
		||||
      def_objc_proto: define the objc_proto_library target
 | 
			
		||||
      def_java_proto: define the java_proto_library target
 | 
			
		||||
      def_jspb_proto: define the jspb_proto_library target
 | 
			
		||||
| 
						 | 
				
			
			@ -255,6 +257,7 @@ def mediapipe_proto_library(
 | 
			
		|||
        def_cc_proto = True,
 | 
			
		||||
        def_py_proto = True,
 | 
			
		||||
        def_java_lite_proto = True,
 | 
			
		||||
        def_kt_lite_proto = True,
 | 
			
		||||
        def_portable_proto = True,  # @unused
 | 
			
		||||
        def_objc_proto = True,
 | 
			
		||||
        def_java_proto = True,
 | 
			
		||||
| 
						 | 
				
			
			@ -281,6 +284,7 @@ def mediapipe_proto_library(
 | 
			
		|||
      def_cc_proto: define the cc_proto_library target
 | 
			
		||||
      def_py_proto: define the py_proto_library target
 | 
			
		||||
      def_java_lite_proto: define the java_lite_proto_library target
 | 
			
		||||
      def_kt_lite_proto: define the kt_lite_proto_library target
 | 
			
		||||
      def_portable_proto: ignored since portable protos are gone
 | 
			
		||||
      def_objc_proto: define the objc_proto_library target
 | 
			
		||||
      def_java_proto: define the java_proto_library target
 | 
			
		||||
| 
						 | 
				
			
			@ -304,6 +308,7 @@ def mediapipe_proto_library(
 | 
			
		|||
        def_cc_proto = def_cc_proto,
 | 
			
		||||
        def_py_proto = def_py_proto,
 | 
			
		||||
        def_java_lite_proto = def_java_lite_proto,
 | 
			
		||||
        def_kt_lite_proto = def_kt_lite_proto,
 | 
			
		||||
        def_objc_proto = def_objc_proto,
 | 
			
		||||
        def_java_proto = def_java_proto,
 | 
			
		||||
        def_jspb_proto = def_jspb_proto,
 | 
			
		||||
| 
						 | 
				
			
			@ -334,6 +339,7 @@ def mediapipe_proto_library(
 | 
			
		|||
            def_cc_proto = def_cc_proto,
 | 
			
		||||
            def_py_proto = def_py_proto,
 | 
			
		||||
            def_java_lite_proto = def_java_lite_proto,
 | 
			
		||||
            def_kt_lite_proto = def_kt_lite_proto,
 | 
			
		||||
            def_objc_proto = def_objc_proto,
 | 
			
		||||
            def_java_proto = def_java_proto,
 | 
			
		||||
            def_jspb_proto = def_jspb_proto,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,7 @@
 | 
			
		|||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/container/flat_hash_set.h"
 | 
			
		||||
#include "absl/memory/memory.h"
 | 
			
		||||
#include "absl/strings/ascii.h"
 | 
			
		||||
#include "absl/strings/numbers.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -1430,10 +1431,10 @@ std::vector<const FieldDescriptor*> GetFields(const Message* src) {
 | 
			
		|||
 | 
			
		||||
// Orders map entries in dst to match src.
 | 
			
		||||
void OrderMapEntries(const Message* src, Message* dst,
 | 
			
		||||
                     std::set<const Message*>* seen = nullptr) {
 | 
			
		||||
  std::unique_ptr<std::set<const Message*>> seen_owner;
 | 
			
		||||
                     absl::flat_hash_set<const Message*>* seen = nullptr) {
 | 
			
		||||
  std::unique_ptr<absl::flat_hash_set<const Message*>> seen_owner;
 | 
			
		||||
  if (!seen) {
 | 
			
		||||
    seen_owner = std::make_unique<std::set<const Message*>>();
 | 
			
		||||
    seen_owner = std::make_unique<absl::flat_hash_set<const Message*>>();
 | 
			
		||||
    seen = seen_owner.get();
 | 
			
		||||
  }
 | 
			
		||||
  if (seen->count(src) > 0) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,6 +34,7 @@ import java.util.HashMap;
 | 
			
		|||
import java.util.Map;
 | 
			
		||||
import java.util.concurrent.atomic.AtomicBoolean;
 | 
			
		||||
import java.util.concurrent.atomic.AtomicReference;
 | 
			
		||||
import javax.annotation.Nullable;
 | 
			
		||||
import javax.microedition.khronos.egl.EGLConfig;
 | 
			
		||||
import javax.microedition.khronos.opengles.GL10;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -303,7 +304,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer {
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
  // Use this when the texture is not a SurfaceTexture.
 | 
			
		||||
  public void setNextFrame(TextureFrame frame) {
 | 
			
		||||
  public void setNextFrame(@Nullable TextureFrame frame) {
 | 
			
		||||
    if (surfaceTexture != null) {
 | 
			
		||||
      Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */);
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,7 +50,6 @@ android_library(
 | 
			
		|||
        "MediaPipeRunner.java",
 | 
			
		||||
    ],
 | 
			
		||||
    visibility = [
 | 
			
		||||
        "//java/com/google/android/libraries/camera/effects:__subpackages__",
 | 
			
		||||
        "//mediapipe/java/com/google/mediapipe:__subpackages__",
 | 
			
		||||
    ],
 | 
			
		||||
    exports = [
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -67,6 +67,7 @@ public class ExternalTextureRenderer {
 | 
			
		|||
  private float[] textureTransformMatrix = new float[16];
 | 
			
		||||
  private boolean flipY;
 | 
			
		||||
  private int rotation = Surface.ROTATION_0;
 | 
			
		||||
  private boolean doExplicitCpuSync = true;
 | 
			
		||||
 | 
			
		||||
  /** Call this to setup the shader program before rendering. */
 | 
			
		||||
  public void setup() {
 | 
			
		||||
| 
						 | 
				
			
			@ -101,6 +102,14 @@ public class ExternalTextureRenderer {
 | 
			
		|||
    this.rotation = rotation;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Configures whether the renderer should do an explicit CPU synchronization using glFinish upon
 | 
			
		||||
   * each {@link #render} call. Defaults to true.
 | 
			
		||||
   */
 | 
			
		||||
  public void setDoExplicitCpuSync(boolean doExplicitCpuSync) {
 | 
			
		||||
    this.doExplicitCpuSync = doExplicitCpuSync;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Renders the surfaceTexture to the framebuffer with optional vertical flip.
 | 
			
		||||
   *
 | 
			
		||||
| 
						 | 
				
			
			@ -150,9 +159,12 @@ public class ExternalTextureRenderer {
 | 
			
		|||
    GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
 | 
			
		||||
    ShaderUtil.checkGlError("glBindTexture");
 | 
			
		||||
 | 
			
		||||
    if (doExplicitCpuSync) {
 | 
			
		||||
 | 
			
		||||
      // TODO: add sync and go back to glFlush()
 | 
			
		||||
      GLES20.glFinish();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Call this to delete the shader program.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,10 @@
 | 
			
		|||
 | 
			
		||||
# Placeholder for internal Python strict library and test compatibility macro.
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe:__subpackages__"])
 | 
			
		||||
package(default_visibility = [
 | 
			
		||||
    "//cloud/ml/applications/vision/model_garden/model_oss/mediapipe:__subpackages__",
 | 
			
		||||
    "//mediapipe:__subpackages__",
 | 
			
		||||
])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel):
 | 
			
		|||
    self._model: tf.keras.Model = None
 | 
			
		||||
    self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
 | 
			
		||||
    self._loss_function: Union[str, tf.keras.losses.Loss] = None
 | 
			
		||||
    self._metric_function: Union[str, tf.keras.metrics.Metric] = None
 | 
			
		||||
    self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None
 | 
			
		||||
    self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
 | 
			
		||||
    self._hparams: hp.BaseHParams = None
 | 
			
		||||
    self._history: tf.keras.callbacks.History = None
 | 
			
		||||
| 
						 | 
				
			
			@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel):
 | 
			
		|||
    self._model.compile(
 | 
			
		||||
        optimizer=self._optimizer,
 | 
			
		||||
        loss=self._loss_function,
 | 
			
		||||
        metrics=[self._metric_function])
 | 
			
		||||
        metrics=self._metric_functions,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    latest_checkpoint = (
 | 
			
		||||
        tf.train.latest_checkpoint(checkpoint_path)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -80,10 +80,30 @@ py_test(
 | 
			
		|||
    deps = [":loss_functions"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
######################################################################
 | 
			
		||||
# Public target of the MediaPipe Model Maker Quantization Config.
 | 
			
		||||
 | 
			
		||||
# Quantization Config is used to export a quantized model. Please refer
 | 
			
		||||
# to the specific task documentation such as:
 | 
			
		||||
# https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize
 | 
			
		||||
# for usage information.
 | 
			
		||||
######################################################################
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "metrics",
 | 
			
		||||
    srcs = ["metrics.py"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "metrics_test",
 | 
			
		||||
    srcs = ["metrics_test.py"],
 | 
			
		||||
    deps = [":metrics"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "quantization",
 | 
			
		||||
    srcs = ["quantization.py"],
 | 
			
		||||
    srcs_version = "PY3",
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = ["//mediapipe/model_maker/python/core/data:dataset"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
104  mediapipe/model_maker/python/core/utils/metrics.py  Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,104 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Metrics utility library."""
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_binary_sparse_metric(metric: tf.metrics.Metric):
 | 
			
		||||
  """Helper method to create a BinarySparse version of a tf.keras.Metric.
 | 
			
		||||
 | 
			
		||||
  BinarySparse is an implementation where the update_state(y_true, y_pred) takes
 | 
			
		||||
  in shapes y_true=(batch_size, 1) and y_pred=(batch_size, 2). Note that this only
 | 
			
		||||
  supports the binary classification case, and that class_id=0 is the negative
 | 
			
		||||
  class and class_id=1 is the positive class.
 | 
			
		||||
 | 
			
		||||
  Currently supported tf.metrics.Metric classes:
 | 
			
		||||
    1. BinarySparseRecallAtPrecision
 | 
			
		||||
    2. BinarySparsePrecisionAtRecall
 | 
			
		||||
 | 
			
		||||
  Args:
 | 
			
		||||
    metric: A tf.metrics.Metric class for which we want to generate a
 | 
			
		||||
      BinarySparse version of this metric.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A class for the BinarySparse version of the specified tf.metrics.Metric.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  class BinarySparseMetric(metric):
 | 
			
		||||
    """A BinarySparse wrapper class for a tf.keras.Metric.
 | 
			
		||||
 | 
			
		||||
    This class has the same parameters and functions as the underlying
 | 
			
		||||
    metric class. For example, the parameters for BinarySparseRecallAtPrecision
 | 
			
		||||
    is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint
 | 
			
		||||
    is that class_id must be set to 1 (or not specified) for the Binary metric.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
      if 'class_id' in kwargs and kwargs['class_id'] != 1:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f'Custom BinarySparseMetric for class:{metric.__name__} is '
 | 
			
		||||
            'only supported for class_id=1, got class_id='
 | 
			
		||||
            f'{kwargs["class_id"]} instead'
 | 
			
		||||
        )
 | 
			
		||||
      else:
 | 
			
		||||
        kwargs['class_id'] = 1
 | 
			
		||||
      super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def update_state(self, y_true, y_pred, sample_weight=None):
 | 
			
		||||
      y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
 | 
			
		||||
      y_true_one_hot = tf.one_hot(y_true, 2)
 | 
			
		||||
      return super().update_state(
 | 
			
		||||
          y_true_one_hot, y_pred, sample_weight=sample_weight
 | 
			
		||||
      )
 | 
			
		||||
 | 
			
		||||
  return BinarySparseMetric
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_sparse_metric(metric: tf.metrics.Metric):
 | 
			
		||||
  """Helper method to create a Sparse version of a tf.keras.Metric.
 | 
			
		||||
 | 
			
		||||
  Sparse is an implementation where the update_state(y_true, y_pred) takes in
 | 
			
		||||
  shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes).
 | 
			
		||||
 | 
			
		||||
  Currently supported tf.metrics.Metric classes:
 | 
			
		||||
    1. tf.metrics.Recall
 | 
			
		||||
    2. tf.metrics.Precision
 | 
			
		||||
 | 
			
		||||
  Args:
 | 
			
		||||
    metric: A tf.metrics.Metric class for which we want to generate a Sparse
 | 
			
		||||
      version of this metric.
 | 
			
		||||
 | 
			
		||||
  Returns:
 | 
			
		||||
    A class for the Sparse version of the specified tf.keras.Metric.
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  class SparseMetric(metric):
 | 
			
		||||
    """A Sparse wrapper class for a tf.keras.Metric."""
 | 
			
		||||
 | 
			
		||||
    def update_state(self, y_true, y_pred, sample_weight=None):
 | 
			
		||||
      y_pred = tf.math.argmax(y_pred, axis=-1)
 | 
			
		||||
      return super().update_state(y_true, y_pred, sample_weight=sample_weight)
 | 
			
		||||
 | 
			
		||||
  return SparseMetric
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SparseRecall = _get_sparse_metric(tf.metrics.Recall)
 | 
			
		||||
SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
 | 
			
		||||
BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
 | 
			
		||||
    tf.metrics.RecallAtPrecision
 | 
			
		||||
)
 | 
			
		||||
BinarySparsePrecisionAtRecall = _get_binary_sparse_metric(
 | 
			
		||||
    tf.metrics.PrecisionAtRecall
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
74  mediapipe/model_maker/python/core/utils/metrics_test.py  Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,74 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import metrics
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
  def setUp(self):
 | 
			
		||||
    super().setUp()
 | 
			
		||||
    self.y_true = [0, 0, 1, 1, 0, 1]
 | 
			
		||||
    self.y_pred = [
 | 
			
		||||
        [0.9, 0.1],  # 0, 0 y
 | 
			
		||||
        [0.8, 0.2],  # 0, 0 y
 | 
			
		||||
        [0.7, 0.3],  # 0, 1 n
 | 
			
		||||
        [0.6, 0.4],  # 0, 1 n
 | 
			
		||||
        [0.3, 0.7],  # 1, 0 y
 | 
			
		||||
        [0.3, 0.7],  # 1, 1 y
 | 
			
		||||
    ]
 | 
			
		||||
    self.num_classes = 3
 | 
			
		||||
 | 
			
		||||
  def _assert_metric_equals(self, metric, value):
 | 
			
		||||
    metric.update_state(self.y_true, self.y_pred)
 | 
			
		||||
    self.assertEqual(metric.result(), value)
 | 
			
		||||
 | 
			
		||||
  def test_sparse_recall(self):
 | 
			
		||||
    metric = metrics.SparseRecall()
 | 
			
		||||
    self._assert_metric_equals(metric, 1 / 3)
 | 
			
		||||
 | 
			
		||||
  def test_sparse_precision(self):
 | 
			
		||||
    metric = metrics.SparsePrecision()
 | 
			
		||||
    self._assert_metric_equals(metric, 1 / 2)
 | 
			
		||||
 | 
			
		||||
  def test_binary_sparse_recall_at_precision(self):
 | 
			
		||||
    metric = metrics.BinarySparseRecallAtPrecision(1.0)
 | 
			
		||||
    self._assert_metric_equals(metric, 0.0)  # impossible to achieve precision=1
 | 
			
		||||
    metric = metrics.BinarySparseRecallAtPrecision(0.4)
 | 
			
		||||
    self._assert_metric_equals(metric, 1.0)
 | 
			
		||||
 | 
			
		||||
  def test_binary_sparse_precision_at_recall(self):
 | 
			
		||||
    metric = metrics.BinarySparsePrecisionAtRecall(1.0)
 | 
			
		||||
    self._assert_metric_equals(metric, 3 / 4)
 | 
			
		||||
    metric = metrics.BinarySparsePrecisionAtRecall(0.7)
 | 
			
		||||
    self._assert_metric_equals(metric, 3 / 4)
 | 
			
		||||
 | 
			
		||||
  def test_binary_sparse_precision_at_recall_class_id_error(self):
 | 
			
		||||
    # class_id=1 case should not error
 | 
			
		||||
    _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1)
 | 
			
		||||
    # class_id=2 case should error
 | 
			
		||||
    with self.assertRaisesRegex(
 | 
			
		||||
        ValueError,
 | 
			
		||||
        'Custom BinarySparseMetric for class:PrecisionAtRecall is only'
 | 
			
		||||
        ' supported for class_id=1, got class_id=2 instead',
 | 
			
		||||
    ):
 | 
			
		||||
      _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
  tf.test.main()
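
The wrappers above are consumed by the classifier's model.compile(metrics=self._metric_functions) call earlier in this change. Below is a minimal, hypothetical sketch of that wiring with a plain tf.keras model; the toy model, layer sizes, and metric arguments are illustrative assumptions, not part of this commit.

import tensorflow as tf

from mediapipe.model_maker.python.core.utils import metrics

# Toy two-class model; labels are sparse integers in {0, 1}.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, activation='softmax', input_shape=(4,)),
])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=[
        # Sparse wrappers: argmax the (batch, num_classes) predictions first.
        metrics.SparsePrecision(name='precision'),
        metrics.SparseRecall(name='recall'),
        # Binary-only wrapper: class_id is forced to 1 (the positive class).
        metrics.BinarySparseRecallAtPrecision(0.5, name='recall_at_precision'),
    ],
)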
 | 
			
		||||
| 
						 | 
				
			
			@ -31,11 +31,11 @@ py_library(
 | 
			
		|||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":text_classifier",
 | 
			
		||||
        ":text_classifier_options",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -45,12 +45,18 @@ py_library(
 | 
			
		|||
    deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "hyperparameters",
 | 
			
		||||
    srcs = ["hyperparameters.py"],
 | 
			
		||||
    deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "model_spec",
 | 
			
		||||
    srcs = ["model_spec.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:file_util",
 | 
			
		||||
        "//mediapipe/model_maker/python/text/core:bert_model_spec",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			@ -61,9 +67,9 @@ py_test(
 | 
			
		|||
    srcs = ["model_spec_test.py"],
 | 
			
		||||
    tags = ["requires-net:external"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -100,9 +106,9 @@ py_library(
 | 
			
		|||
    name = "text_classifier_options",
 | 
			
		||||
    srcs = ["text_classifier_options.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -111,13 +117,14 @@ py_library(
 | 
			
		|||
    srcs = ["text_classifier.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        ":hyperparameters",
 | 
			
		||||
        ":model_options",
 | 
			
		||||
        ":model_spec",
 | 
			
		||||
        ":preprocessor",
 | 
			
		||||
        ":text_classifier_options",
 | 
			
		||||
        "//mediapipe/model_maker/python/core:hyperparameters",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/data:dataset",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/tasks:classifier",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:metrics",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:model_util",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/utils:quantization",
 | 
			
		||||
        "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,19 +13,23 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
"""MediaPipe Public Python API for Text Classifier."""
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import dataset
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
 | 
			
		||||
 | 
			
		||||
HParams = hyperparameters.BaseHParams
 | 
			
		||||
 | 
			
		||||
AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
 | 
			
		||||
AverageWordEmbeddingModelOptions = (
 | 
			
		||||
    model_options.AverageWordEmbeddingModelOptions
 | 
			
		||||
)
 | 
			
		||||
BertOptimizer = hyperparameters.BertOptimizer
 | 
			
		||||
BertHParams = hyperparameters.BertHParams
 | 
			
		||||
BertModelOptions = model_options.BertModelOptions
 | 
			
		||||
CSVParams = dataset.CSVParameters
 | 
			
		||||
Dataset = dataset.Dataset
 | 
			
		||||
AverageWordEmbeddingModelOptions = (
 | 
			
		||||
    model_options.AverageWordEmbeddingModelOptions)
 | 
			
		||||
BertModelOptions = model_options.BertModelOptions
 | 
			
		||||
SupportedModels = model_spec.SupportedModels
 | 
			
		||||
TextClassifier = text_classifier.TextClassifier
 | 
			
		||||
TextClassifierOptions = text_classifier_options.TextClassifierOptions
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,54 @@
 | 
			
		|||
# Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Hyperparameters for training object detection models."""
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
import enum
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class AverageWordEmbeddingHParams(hp.BaseHParams):
 | 
			
		||||
  """The hyperparameters for an AverageWordEmbeddingClassifier."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@enum.unique
 | 
			
		||||
class BertOptimizer(enum.Enum):
 | 
			
		||||
  """Supported Optimizers for Bert Text Classifier."""
 | 
			
		||||
 | 
			
		||||
  ADAMW = "adamw"
 | 
			
		||||
  LAMB = "lamb"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class BertHParams(hp.BaseHParams):
 | 
			
		||||
  """The hyperparameters for a Bert Classifier.
 | 
			
		||||
 | 
			
		||||
  Attributes:
 | 
			
		||||
    learning_rate: Learning rate to use for gradient descent training.
 | 
			
		||||
    batch_size: Batch size for training.
 | 
			
		||||
    epochs: Number of training iterations over the dataset.
 | 
			
		||||
    optimizer: Optimizer to use for training. Only supported values are "adamw"
 | 
			
		||||
      and "lamb".
 | 
			
		||||
  """
 | 
			
		||||
 | 
			
		||||
  learning_rate: float = 3e-5
 | 
			
		||||
  batch_size: int = 48
 | 
			
		||||
  epochs: int = 2
 | 
			
		||||
  optimizer: BertOptimizer = BertOptimizer.ADAMW
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
HParams = Union[BertHParams, AverageWordEmbeddingHParams]
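
A hedged sketch (not part of the diff) of constructing the task-specific hyperparameter dataclasses defined above; the field values are illustrative only.

from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp

# BERT variant: the optimizer is selected through the BertOptimizer enum.
bert_hparams = hp.BertHParams(
    learning_rate=3e-5,
    batch_size=48,
    epochs=2,
    optimizer=hp.BertOptimizer.LAMB,  # default is BertOptimizer.ADAMW
)

# Average-word-embedding variant: same fields as the core BaseHParams.
awe_hparams = hp.AverageWordEmbeddingHParams(
    learning_rate=0, batch_size=32, epochs=10
)

# Either instance satisfies the HParams union alias declared above.
params: hp.HParams = bert_hparams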
 | 
			
		||||
| 
						 | 
				
			
			@ -17,13 +17,11 @@ import dataclasses
 | 
			
		|||
import enum
 | 
			
		||||
import functools
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import file_util
 | 
			
		||||
from mediapipe.model_maker.python.text.core import bert_model_spec
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
			
		||||
 | 
			
		||||
# BERT-based text classifier spec inherited from BertModelSpec
 | 
			
		||||
BertClassifierSpec = bert_model_spec.BertModelSpec
 | 
			
		||||
 | 
			
		||||
MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'text_classifier/mobilebert_tiny',
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
 | 
			
		|||
    is_folder=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
EXBERT_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'text_classifier/exbert',
 | 
			
		||||
    'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
 | 
			
		||||
    is_folder=True,
 | 
			
		||||
)


@dataclasses.dataclass
class AverageWordEmbeddingClassifierSpec:
@@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
  """

  # `learning_rate` is unused for the average word embedding model
  hparams: hp.BaseHParams = hp.BaseHParams(
      epochs=10, batch_size=32, learning_rate=0)
  hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
      epochs=10, batch_size=32, learning_rate=0
  )
  model_options: mo.AverageWordEmbeddingModelOptions = (
      mo.AverageWordEmbeddingModelOptions())
  name: str = 'AverageWordEmbedding'


average_word_embedding_classifier_spec = functools.partial(
    AverageWordEmbeddingClassifierSpec)


@dataclasses.dataclass
class BertClassifierSpec(bert_model_spec.BertModelSpec):
  """Specification for a Bert classifier model.

  Only overrides the hparams attribute since the rest of the attributes are
  inherited from the BertModelSpec.
  """

  hparams: hp.BertHParams = hp.BertHParams()


mobilebert_classifier_spec = functools.partial(
    BertClassifierSpec,
    downloaded_files=MOBILEBERT_TINY_FILES,
    hparams=hp.BaseHParams(
    hparams=hp.BertHParams(
        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
    ),
    name='MobileBert',
    tflite_input_name={
        'ids': 'serving_default_input_1:0',
        'mask': 'serving_default_input_3:0',
        'segment_ids': 'serving_default_input_2:0',
        'mask': 'serving_default_input_3:0',
    },
)

exbert_classifier_spec = functools.partial(
    BertClassifierSpec,
    downloaded_files=EXBERT_FILES,
    hparams=hp.BertHParams(
        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
    ),
    name='ExBert',
    tflite_input_name={
        'ids': 'serving_default_input_1:0',
        'segment_ids': 'serving_default_input_2:0',
        'mask': 'serving_default_input_3:0',
    },
)

@@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
  """Predefined text classifier model specs supported by Model Maker."""
  AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
  MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
  EXBERT_CLASSIFIER = exbert_classifier_spec
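# Illustrative usage of the new EXBERT entry (a sketch, not part of this file):
# the enum value is a functools.partial over BertClassifierSpec, so selecting
# it via the options dataclass picks up EXBERT_FILES and the BERT hparams
# above. The import path and the train/validation Datasets are assumed.
#
#   from mediapipe.model_maker.python.text import text_classifier
#
#   options = text_classifier.TextClassifierOptions(
#       supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER,
#       hparams=text_classifier.BertHParams(
#           epochs=3, batch_size=48, learning_rate=3e-5,
#           distribution_strategy='off',
#       ),
#   )
#   model = text_classifier.TextClassifier.create(
#       train_data, validation_data, options)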
@@ -19,7 +19,7 @@ from unittest import mock as unittest_mock

import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
            seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        model_spec_obj.hparams,
 | 
			
		||||
        hp.BaseHParams(
 | 
			
		||||
        hp.BertHParams(
 | 
			
		||||
            epochs=3,
 | 
			
		||||
            batch_size=48,
 | 
			
		||||
            learning_rate=3e-5,
 | 
			
		||||
            distribution_strategy='off'))
 | 
			
		||||
            distribution_strategy='off',
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  def test_predefined_average_word_embedding_spec(self):
 | 
			
		||||
    model_spec_obj = (
 | 
			
		||||
| 
						 | 
				
			
			@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
            dropout_rate=0.2))
 | 
			
		||||
    self.assertEqual(
 | 
			
		||||
        model_spec_obj.hparams,
 | 
			
		||||
        hp.BaseHParams(
 | 
			
		||||
        hp.AverageWordEmbeddingHParams(
 | 
			
		||||
            epochs=10,
 | 
			
		||||
            batch_size=32,
 | 
			
		||||
            learning_rate=0,
 | 
			
		||||
| 
						 | 
				
			
			@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
                     custom_bert_classifier_options)
 | 
			
		||||
 | 
			
		||||
  def test_custom_average_word_embedding_spec(self):
 | 
			
		||||
    custom_hparams = hp.BaseHParams(
 | 
			
		||||
    custom_hparams = hp.AverageWordEmbeddingHParams(
 | 
			
		||||
        learning_rate=0.4,
 | 
			
		||||
        batch_size=64,
 | 
			
		||||
        epochs=10,
 | 
			
		||||
| 
						 | 
				
			
			@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
 | 
			
		|||
        export_dir='foo/bar',
 | 
			
		||||
        distribution_strategy='mirrored',
 | 
			
		||||
        num_gpus=3,
 | 
			
		||||
        tpu='tpu/address')
 | 
			
		||||
        tpu='tpu/address',
 | 
			
		||||
    )
 | 
			
		||||
    custom_average_word_embedding_model_options = (
 | 
			
		||||
        classifier_model_options.AverageWordEmbeddingModelOptions(
 | 
			
		||||
            seq_len=512,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,14 +19,16 @@ import tempfile
 | 
			
		|||
from typing import Any, Optional, Sequence, Tuple
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
from tensorflow_addons import optimizers as tfa_optimizers
 | 
			
		||||
import tensorflow_hub as hub
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.core.data import dataset as ds
 | 
			
		||||
from mediapipe.model_maker.python.core.tasks import classifier
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import metrics
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import model_util
 | 
			
		||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
			
		||||
| 
						 | 
				
			
			@ -54,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
 | 
			
		|||
       ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
 | 
			
		||||
    raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
 | 
			
		||||
                     f" got {options.supported_model}")
 | 
			
		||||
  if (isinstance(options.model_options, mo.BertModelOptions) and
 | 
			
		||||
      (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
 | 
			
		||||
  if isinstance(options.model_options, mo.BertModelOptions) and (
 | 
			
		||||
      options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
			
		||||
      and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
 | 
			
		||||
  ):
 | 
			
		||||
    raise ValueError(
 | 
			
		||||
        f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
 | 
			
		||||
        "Expected a Bert Classifier(MobileBERT or EXBERT), got "
 | 
			
		||||
        f"{options.supported_model}"
 | 
			
		||||
    )
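# Illustrative sketch of the widened check above (uses the module aliases
# `ms`, `mo`, and `text_classifier_options`; option values are placeholder
# defaults): BertModelOptions may now pair with either BERT-backed model.
#
#   _validate(text_classifier_options.TextClassifierOptions(
#       supported_model=ms.SupportedModels.EXBERT_CLASSIFIER,
#       model_options=mo.BertModelOptions()))   # passes
#
# Pairing BertModelOptions with AVERAGE_WORD_EMBEDDING_CLASSIFIER still raises
# the ValueError above.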
class TextClassifier(classifier.Classifier):
 | 
			
		||||
  """API for creating and training a text classification model."""
 | 
			
		||||
 | 
			
		||||
  def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
 | 
			
		||||
               label_names: Sequence[str]):
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self, model_spec: Any, label_names: Sequence[str], shuffle: bool
 | 
			
		||||
  ):
 | 
			
		||||
    super().__init__(
 | 
			
		||||
        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
			
		||||
        model_spec=model_spec, label_names=label_names, shuffle=shuffle
 | 
			
		||||
    )
 | 
			
		||||
    self._model_spec = model_spec
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -106,7 +112,10 @@ class TextClassifier(classifier.Classifier):
 | 
			
		|||
    if options.hparams is None:
 | 
			
		||||
      options.hparams = options.supported_model.value().hparams
 | 
			
		||||
 | 
			
		||||
    if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
 | 
			
		||||
    if (
 | 
			
		||||
        options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
			
		||||
        or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
 | 
			
		||||
    ):
 | 
			
		||||
      text_classifier = (
 | 
			
		||||
          _BertClassifier.create_bert_classifier(train_data, validation_data,
 | 
			
		||||
                                                 options,
 | 
			
		||||
| 
						 | 
				
			
			@ -123,12 +132,24 @@ class TextClassifier(classifier.Classifier):
 | 
			
		|||
 | 
			
		||||
    return text_classifier
 | 
			
		||||
 | 
			
		||||
  def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
 | 
			
		||||
  def evaluate(
 | 
			
		||||
      self,
 | 
			
		||||
      data: ds.Dataset,
 | 
			
		||||
      batch_size: int = 32,
 | 
			
		||||
      desired_precisions: Optional[Sequence[float]] = None,
 | 
			
		||||
      desired_recalls: Optional[Sequence[float]] = None,
 | 
			
		||||
  ) -> Any:
 | 
			
		||||
    """Overrides Classifier.evaluate().
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
      data: Evaluation dataset. Must be a TextClassifier Dataset.
 | 
			
		||||
      batch_size: Number of samples per evaluation step.
 | 
			
		||||
      desired_precisions: If specified, adds a RecallAtPrecision metric per
 | 
			
		||||
        desired_precisions[i] entry which tracks the recall given the constraint
 | 
			
		||||
        on precision. Only supported for binary classification.
 | 
			
		||||
      desired_recalls: If specified, adds a PrecisionAtRecall metric per
 | 
			
		||||
        desired_recalls[i] entry which tracks the precision given the constraint
 | 
			
		||||
        on recall. Only supported for binary classification.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
      The loss value and accuracy.
 | 
			
		||||
| 
						 | 
				
			
			@ -144,6 +165,28 @@ class TextClassifier(classifier.Classifier):
 | 
			
		|||
 | 
			
		||||
    processed_data = self._text_preprocessor.preprocess(data)
 | 
			
		||||
    dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
 | 
			
		||||
 | 
			
		||||
    additional_metrics = []
 | 
			
		||||
    if desired_precisions and len(data.label_names) == 2:
 | 
			
		||||
      for precision in desired_precisions:
 | 
			
		||||
        additional_metrics.append(
 | 
			
		||||
            metrics.BinarySparseRecallAtPrecision(
 | 
			
		||||
                precision, name=f"recall_at_precision_{precision}"
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
    if desired_recalls and len(data.label_names) == 2:
 | 
			
		||||
      for recall in desired_recalls:
 | 
			
		||||
        additional_metrics.append(
 | 
			
		||||
            metrics.BinarySparsePrecisionAtRecall(
 | 
			
		||||
                recall, name=f"precision_at_recall_{recall}"
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
    metric_functions = self._metric_functions + additional_metrics
 | 
			
		||||
    self._model.compile(
 | 
			
		||||
        optimizer=self._optimizer,
 | 
			
		||||
        loss=self._loss_function,
 | 
			
		||||
        metrics=metric_functions,
 | 
			
		||||
    )
 | 
			
		||||
    return self._model.evaluate(dataset)
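  # Illustrative call of the extended evaluate() (a sketch assuming a trained
  # binary classifier `model` and a text classifier Dataset `test_data`;
  # threshold values are placeholders). Each listed value adds one metric to
  # the compile step above:
  #
  #   model.evaluate(
  #       test_data,
  #       batch_size=32,
  #       desired_precisions=[0.9],  # adds recall_at_precision_0.9
  #       desired_recalls=[0.9],     # adds precision_at_recall_0.9
  #   )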
  def export_model(
 | 
			
		||||
| 
						 | 
				
			
			@ -161,9 +204,8 @@ class TextClassifier(classifier.Classifier):
 | 
			
		|||
        path is {self._hparams.export_dir}/{model_name}.
 | 
			
		||||
      quantization_config: The configuration for model quantization.
 | 
			
		||||
    """
 | 
			
		||||
    if not tf.io.gfile.exists(self._hparams.export_dir):
 | 
			
		||||
      tf.io.gfile.makedirs(self._hparams.export_dir)
 | 
			
		||||
    tflite_file = os.path.join(self._hparams.export_dir, model_name)
 | 
			
		||||
    tf.io.gfile.makedirs(os.path.dirname(tflite_file))
 | 
			
		||||
    metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
 | 
			
		||||
 | 
			
		||||
    tflite_model = model_util.convert_to_tflite(
 | 
			
		||||
| 
						 | 
				
			
			@ -174,7 +216,7 @@ class TextClassifier(classifier.Classifier):
 | 
			
		|||
    writer = self._get_metadata_writer(tflite_model, vocab_filepath)
 | 
			
		||||
    tflite_model_with_metadata, metadata_json = writer.populate()
 | 
			
		||||
    model_util.save_tflite(tflite_model_with_metadata, tflite_file)
 | 
			
		||||
    with open(metadata_file, "w") as f:
 | 
			
		||||
    with tf.io.gfile.GFile(metadata_file, "w") as f:
 | 
			
		||||
      f.write(metadata_json)
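  # Illustrative export call (sketch; assumes a trained classifier `model` and
  # its default export_dir; the file name is a placeholder):
  #
  #   model.export_model(model_name="model.tflite")
  #   # -> {export_dir}/model.tflite and {export_dir}/metadata.json, with the
  #   #    directory created by tf.io.gfile.makedirs as above.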
  @abc.abstractmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -191,13 +233,23 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
 | 
			
		|||
 | 
			
		||||
  _DELIM_REGEX_PATTERN = r"[^\w\']+"
 | 
			
		||||
 | 
			
		||||
  def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self,
 | 
			
		||||
      model_spec: ms.AverageWordEmbeddingClassifierSpec,
 | 
			
		||||
      model_options: mo.AverageWordEmbeddingModelOptions,
 | 
			
		||||
               hparams: hp.BaseHParams, label_names: Sequence[str]):
 | 
			
		||||
    super().__init__(model_spec, hparams, label_names)
 | 
			
		||||
      hparams: hp.AverageWordEmbeddingHParams,
 | 
			
		||||
      label_names: Sequence[str],
 | 
			
		||||
  ):
 | 
			
		||||
    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._loss_function = "sparse_categorical_crossentropy"
 | 
			
		||||
    self._metric_function = "accuracy"
 | 
			
		||||
    self._metric_functions = [
 | 
			
		||||
        "accuracy",
 | 
			
		||||
        metrics.SparsePrecision(name="precision", dtype=tf.float32),
 | 
			
		||||
        metrics.SparseRecall(name="recall", dtype=tf.float32),
 | 
			
		||||
    ]
 | 
			
		||||
    self._text_preprocessor: (
 | 
			
		||||
        preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -306,16 +358,26 @@ class _BertClassifier(TextClassifier):
 | 
			
		|||
 | 
			
		||||
  _INITIALIZER_RANGE = 0.02
 | 
			
		||||
 | 
			
		||||
  def __init__(self, model_spec: ms.BertClassifierSpec,
 | 
			
		||||
               model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
 | 
			
		||||
               label_names: Sequence[str]):
 | 
			
		||||
    super().__init__(model_spec, hparams, label_names)
 | 
			
		||||
  def __init__(
 | 
			
		||||
      self,
 | 
			
		||||
      model_spec: ms.BertClassifierSpec,
 | 
			
		||||
      model_options: mo.BertModelOptions,
 | 
			
		||||
      hparams: hp.BertHParams,
 | 
			
		||||
      label_names: Sequence[str],
 | 
			
		||||
  ):
 | 
			
		||||
    super().__init__(model_spec, label_names, hparams.shuffle)
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    with self._hparams.get_strategy().scope():
 | 
			
		||||
      self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
 | 
			
		||||
      self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
 | 
			
		||||
      self._metric_functions = [
 | 
			
		||||
          tf.keras.metrics.SparseCategoricalAccuracy(
 | 
			
		||||
              "test_accuracy", dtype=tf.float32
 | 
			
		||||
      )
 | 
			
		||||
          ),
 | 
			
		||||
          metrics.SparsePrecision(name="precision", dtype=tf.float32),
 | 
			
		||||
          metrics.SparseRecall(name="recall", dtype=tf.float32),
 | 
			
		||||
      ]
 | 
			
		||||
    self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
| 
						 | 
				
			
			@ -438,11 +500,26 @@ class _BertClassifier(TextClassifier):
 | 
			
		|||
          initial_learning_rate=initial_lr,
 | 
			
		||||
          decay_schedule_fn=lr_schedule,
 | 
			
		||||
          warmup_steps=warmup_steps)
 | 
			
		||||
 | 
			
		||||
    if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
 | 
			
		||||
      self._optimizer = tf.keras.optimizers.experimental.AdamW(
 | 
			
		||||
        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
 | 
			
		||||
          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
 | 
			
		||||
      )
 | 
			
		||||
      self._optimizer.exclude_from_weight_decay(
 | 
			
		||||
        var_names=["LayerNorm", "layer_norm", "bias"])
 | 
			
		||||
          var_names=["LayerNorm", "layer_norm", "bias"]
 | 
			
		||||
      )
 | 
			
		||||
    elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
 | 
			
		||||
      self._optimizer = tfa_optimizers.LAMB(
 | 
			
		||||
          lr_schedule,
 | 
			
		||||
          weight_decay_rate=0.01,
 | 
			
		||||
          epsilon=1e-6,
 | 
			
		||||
          exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
 | 
			
		||||
          global_clipnorm=1.0,
 | 
			
		||||
      )
 | 
			
		||||
    else:
 | 
			
		||||
      raise ValueError(
 | 
			
		||||
          "BertHParams.optimizer must be set to ADAM or "
 | 
			
		||||
          f"LAMB. Got {self._hparams.optimizer}."
 | 
			
		||||
      )
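  # Illustrative sketch: the branch above is chosen through BertHParams
  # (optimizer enum values taken from this diff; other fields are assumed to
  # keep their defaults):
  #
  #   hparams = hp.BertHParams(
  #       epochs=3, batch_size=48, learning_rate=3e-5,
  #       optimizer=hp.BertOptimizer.LAMB,  # ADAMW selects the AdamW branch
  #   )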
  def _save_vocab(self, vocab_filepath: str):
 | 
			
		||||
    tf.io.gfile.copy(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,14 +66,16 @@ def run(data_dir,
 | 
			
		|||
  quantization_config = None
 | 
			
		||||
  if (supported_model ==
 | 
			
		||||
      text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
 | 
			
		||||
    hparams = text_classifier.HParams(
 | 
			
		||||
        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
 | 
			
		||||
    hparams = text_classifier.AverageWordEmbeddingHParams(
 | 
			
		||||
        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
 | 
			
		||||
    )
 | 
			
		||||
  # Warning: This takes extremely long to run on CPU
 | 
			
		||||
  elif (
 | 
			
		||||
      supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
 | 
			
		||||
    quantization_config = quantization.QuantizationConfig.for_dynamic()
 | 
			
		||||
    hparams = text_classifier.HParams(
 | 
			
		||||
        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
 | 
			
		||||
    hparams = text_classifier.BertHParams(
 | 
			
		||||
        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
  # Fine-tunes the model.
 | 
			
		||||
  options = text_classifier.TextClassifierOptions(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -16,7 +16,7 @@
 | 
			
		|||
import dataclasses
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
			
		||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -34,5 +34,5 @@ class TextClassifierOptions:
 | 
			
		|||
      architecture of the `supported_model`.
 | 
			
		||||
  """
 | 
			
		||||
  supported_model: ms.SupportedModels
 | 
			
		||||
  hparams: Optional[hp.BaseHParams] = None
 | 
			
		||||
  hparams: Optional[hp.HParams] = None
 | 
			
		||||
  model_options: Optional[mo.TextClassifierModelOptions] = None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
 | 
			
		|||
 | 
			
		||||
  def test_create_and_train_average_word_embedding_model(self):
 | 
			
		||||
    train_data, validation_data = self._get_data()
 | 
			
		||||
    options = (
 | 
			
		||||
        text_classifier.TextClassifierOptions(
 | 
			
		||||
            supported_model=(text_classifier.SupportedModels
 | 
			
		||||
                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
 | 
			
		||||
            hparams=text_classifier.HParams(
 | 
			
		||||
                epochs=1, batch_size=1, learning_rate=0)))
 | 
			
		||||
    options = text_classifier.TextClassifierOptions(
 | 
			
		||||
        supported_model=(
 | 
			
		||||
            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
 | 
			
		||||
        ),
 | 
			
		||||
        hparams=text_classifier.AverageWordEmbeddingHParams(
 | 
			
		||||
            epochs=1, batch_size=1, learning_rate=0
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    average_word_embedding_classifier = (
 | 
			
		||||
        text_classifier.TextClassifier.create(train_data, validation_data,
 | 
			
		||||
                                              options))
 | 
			
		||||
| 
						 | 
				
			
			@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
 | 
			
		|||
    options = text_classifier.TextClassifierOptions(
 | 
			
		||||
        supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
 | 
			
		||||
        model_options=text_classifier.BertModelOptions(
 | 
			
		||||
            do_fine_tuning=False, seq_len=2),
 | 
			
		||||
        hparams=text_classifier.HParams(
 | 
			
		||||
            do_fine_tuning=False, seq_len=2
 | 
			
		||||
        ),
 | 
			
		||||
        hparams=text_classifier.BertHParams(
 | 
			
		||||
            epochs=1,
 | 
			
		||||
            batch_size=1,
 | 
			
		||||
            learning_rate=3e-5,
 | 
			
		||||
            distribution_strategy='off'))
 | 
			
		||||
            distribution_strategy='off',
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    bert_classifier = text_classifier.TextClassifier.create(
 | 
			
		||||
        train_data, validation_data, options)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,13 +20,6 @@ licenses(["notice"])
 | 
			
		|||
 | 
			
		||||
package(default_visibility = ["//mediapipe:__subpackages__"])
 | 
			
		||||
 | 
			
		||||
filegroup(
 | 
			
		||||
    name = "testdata",
 | 
			
		||||
    srcs = glob([
 | 
			
		||||
        "testdata/**",
 | 
			
		||||
    ]),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_library(
 | 
			
		||||
    name = "constants",
 | 
			
		||||
    srcs = ["constants.py"],
 | 
			
		||||
| 
						 | 
				
			
			@ -72,18 +65,11 @@ py_library(
 | 
			
		|||
    name = "dataset",
 | 
			
		||||
    srcs = ["dataset.py"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":constants",
 | 
			
		||||
        "//mediapipe/model_maker/python/core/data:classification_dataset",
 | 
			
		||||
        "//mediapipe/model_maker/python/vision/core:image_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
py_test(
 | 
			
		||||
    name = "dataset_test",
 | 
			
		||||
    srcs = ["dataset_test.py"],
 | 
			
		||||
    data = [":testdata"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":dataset",
 | 
			
		||||
        "//mediapipe/tasks/python/test:test_utils",
 | 
			
		||||
        "//mediapipe/python:_framework_bindings",
 | 
			
		||||
        "//mediapipe/tasks/python/core:base_options",
 | 
			
		||||
        "//mediapipe/tasks/python/vision:face_aligner",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
 | 
			
		|||
    'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
 | 
			
		||||
    'face_stylizer/face_landmarker_v2.task',
 | 
			
		||||
    'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
 | 
			
		||||
    is_folder=False,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Dimension of the input style vector to the decoder
 | 
			
		||||
STYLE_DIM = 512
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,13 +13,37 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
"""Face stylizer dataset library."""
 | 
			
		||||
 | 
			
		||||
from typing import Sequence
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.core.data import classification_dataset
 | 
			
		||||
from mediapipe.model_maker.python.vision.core import image_utils
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import constants
 | 
			
		||||
from mediapipe.python._framework_bindings import image as image_module
 | 
			
		||||
from mediapipe.tasks.python.core import base_options as base_options_module
 | 
			
		||||
from mediapipe.tasks.python.vision import face_aligner
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _preprocess_face_dataset(
 | 
			
		||||
    all_image_paths: Sequence[str],
 | 
			
		||||
) -> Sequence[tf.Tensor]:
 | 
			
		||||
  """Preprocess face image dataset by aligning the face."""
 | 
			
		||||
  path = constants.FACE_ALIGNER_TASK_FILES.get_path()
 | 
			
		||||
  base_options = base_options_module.BaseOptions(model_asset_path=path)
 | 
			
		||||
  options = face_aligner.FaceAlignerOptions(base_options=base_options)
 | 
			
		||||
  aligner = face_aligner.FaceAligner.create_from_options(options)
 | 
			
		||||
 | 
			
		||||
  preprocessed_images = []
 | 
			
		||||
  for path in all_image_paths:
 | 
			
		||||
    tf.compat.v1.logging.info('Preprocess image %s', path)
 | 
			
		||||
    image = image_module.Image.create_from_file(path)
 | 
			
		||||
    aligned_image = aligner.align(image)
 | 
			
		||||
    aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
 | 
			
		||||
    preprocessed_images.append(aligned_image_tensor)
 | 
			
		||||
 | 
			
		||||
  return preprocessed_images
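# Illustrative sketch (assumes `all_image_paths` gathered as in from_folder()
# below): the aligned tensors feed straight into a tf.data pipeline.
#
#   image_data = _preprocess_face_dataset(all_image_paths)
#   image_ds = tf.data.Dataset.from_tensor_slices(image_data)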
# TODO: Change to an unlabeled dataset if it makes sense.
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
 | 
			
		|||
    ):
 | 
			
		||||
      raise ValueError('No images found under given directory')
 | 
			
		||||
 | 
			
		||||
    image_data = _preprocess_face_dataset(all_image_paths)
 | 
			
		||||
    label_names = sorted(
 | 
			
		||||
        name
 | 
			
		||||
        for name in os.listdir(data_root)
 | 
			
		||||
| 
						 | 
				
			
			@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
 | 
			
		|||
        for path in all_image_paths
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
 | 
			
		||||
 | 
			
		||||
    image_ds = path_ds.map(
 | 
			
		||||
        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
 | 
			
		||||
    )
 | 
			
		||||
    image_ds = tf.data.Dataset.from_tensor_slices(image_data)
 | 
			
		||||
 | 
			
		||||
    # Load label
 | 
			
		||||
    label_ds = tf.data.Dataset.from_tensor_slices(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,8 +12,10 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
 | 
			
		||||
from mediapipe.model_maker.python.vision.core import image_utils
 | 
			
		||||
from mediapipe.model_maker.python.vision.face_stylizer import dataset
 | 
			
		||||
from mediapipe.tasks.python.test import test_utils
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
 | 
			
		|||
 | 
			
		||||
  def setUp(self):
 | 
			
		||||
    super().setUp()
 | 
			
		||||
    self._test_data_dirname = 'input/style'
 | 
			
		||||
 | 
			
		||||
  def test_from_folder(self):
 | 
			
		||||
    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
 | 
			
		||||
    test_data_dirname = 'input/style'
 | 
			
		||||
    input_data_dir = test_utils.get_test_data_path(test_data_dirname)
 | 
			
		||||
    data = dataset.Dataset.from_folder(dirname=input_data_dir)
 | 
			
		||||
    self.assertEqual(data.num_classes, 2)
 | 
			
		||||
    self.assertEqual(data.label_names, ['cartoon', 'sketch'])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,7 @@
 | 
			
		|||
"""APIs to train face stylization model."""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
from typing import Callable, Optional
 | 
			
		||||
from typing import Any, Callable, Optional
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
| 
						 | 
				
			
			@ -54,7 +54,6 @@ class FaceStylizer(object):
 | 
			
		|||
    self._model_spec = model_spec
 | 
			
		||||
    self._model_options = model_options
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    # TODO: Support face alignment in image preprocessor.
 | 
			
		||||
    self._preprocessor = image_preprocessing.Preprocessor(
 | 
			
		||||
        input_shape=self._model_spec.input_image_shape,
 | 
			
		||||
        num_classes=1,
 | 
			
		||||
| 
						 | 
				
			
			@ -128,7 +127,7 @@ class FaceStylizer(object):
 | 
			
		|||
  def _train_model(
 | 
			
		||||
      self,
 | 
			
		||||
      train_data: classification_ds.ClassificationDataset,
 | 
			
		||||
      preprocessor: Optional[Callable[..., bool]] = None,
 | 
			
		||||
      preprocessor: Optional[Callable[..., Any]] = None,
 | 
			
		||||
  ):
 | 
			
		||||
    """Trains the face stylizer model.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier):
 | 
			
		|||
    self._model_options = model_options
 | 
			
		||||
    self._hparams = hparams
 | 
			
		||||
    self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
 | 
			
		||||
    self._metric_function = 'categorical_accuracy'
 | 
			
		||||
    self._metric_functions = ['categorical_accuracy']
 | 
			
		||||
    self._optimizer = 'adam'
 | 
			
		||||
    self._callbacks = self._get_callbacks()
 | 
			
		||||
    self._history = None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier):
 | 
			
		|||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
			
		||||
    self._loss_function = tf.keras.losses.CategoricalCrossentropy(
 | 
			
		||||
        label_smoothing=self._hparams.label_smoothing)
 | 
			
		||||
    self._metric_function = 'accuracy'
 | 
			
		||||
    self._metric_functions = ['accuracy']
 | 
			
		||||
    self._history = None  # Training history returned from `keras_model.fit`.
 | 
			
		||||
 | 
			
		||||
  @classmethod
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -101,11 +101,14 @@ class ObjectDetectorModel(tf.keras.Model):
 | 
			
		|||
    )
 | 
			
		||||
    return model_config
 | 
			
		||||
 | 
			
		||||
  def _build_model(self) -> tf.keras.Model:
 | 
			
		||||
  def _build_model(self, omit_l2=False) -> tf.keras.Model:
 | 
			
		||||
    """Builds a RetinaNet object detector model."""
 | 
			
		||||
    input_specs = tf.keras.layers.InputSpec(
 | 
			
		||||
        shape=[None] + self._model_spec.input_image_shape
 | 
			
		||||
    )
 | 
			
		||||
    if omit_l2:
 | 
			
		||||
      l2_regularizer = None
 | 
			
		||||
    else:
 | 
			
		||||
      l2_regularizer = tf.keras.regularizers.l2(
 | 
			
		||||
          self._model_options.l2_weight_decay / 2.0
 | 
			
		||||
      )
 | 
			
		||||
| 
						 | 
				
			
			@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
 | 
			
		|||
 | 
			
		||||
  def convert_to_qat(self) -> None:
 | 
			
		||||
    """Converts the model to a QAT RetinaNet model."""
 | 
			
		||||
    model = self._build_model()
 | 
			
		||||
    model = self._build_model(omit_l2=True)
 | 
			
		||||
    dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
 | 
			
		||||
    model(dummy_input, training=True)
 | 
			
		||||
    model.set_weights(self._model.get_weights())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,6 +43,7 @@ cc_library(
 | 
			
		|||
        ":base_audio_task_api",
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -27,6 +27,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "mediapipe/framework/calculator.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
| 
						 | 
				
			
			@ -60,13 +61,8 @@ class AudioTaskApiFactory {
 | 
			
		|||
            "Task graph config should only contain one task subgraph node.",
 | 
			
		||||
            MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
      } else {
 | 
			
		||||
        if (!node.options().HasExtension(Options::ext)) {
 | 
			
		||||
          return CreateStatusWithPayload(
 | 
			
		||||
              absl::StatusCode::kInvalidArgument,
 | 
			
		||||
              absl::StrCat(node.calculator(),
 | 
			
		||||
                           " is missing the required task options field."),
 | 
			
		||||
              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
        }
 | 
			
		||||
        MP_RETURN_IF_ERROR(
 | 
			
		||||
            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
 | 
			
		||||
        found_task_subgraph = true;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
 | 
			
		||||
        "@com_google_absl//absl/log",
 | 
			
		||||
        "@com_google_absl//absl/memory",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,15 +17,56 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
#include "absl/log/log.h"
 | 
			
		||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
namespace core {
 | 
			
		||||
 | 
			
		||||
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
 | 
			
		||||
    const BaseOptions::CpuOptions& options) {
 | 
			
		||||
  proto::Acceleration acceleration_proto = proto::Acceleration();
 | 
			
		||||
  acceleration_proto.mutable_tflite();
 | 
			
		||||
  return acceleration_proto;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
 | 
			
		||||
    const BaseOptions::GpuOptions& options) {
 | 
			
		||||
  proto::Acceleration acceleration_proto = proto::Acceleration();
 | 
			
		||||
  auto* gpu = acceleration_proto.mutable_gpu();
 | 
			
		||||
  gpu->set_use_advanced_gpu_api(true);
 | 
			
		||||
  gpu->set_cached_kernel_path(options.cached_kernel_path);
 | 
			
		||||
  gpu->set_serialized_model_dir(options.serialized_model_dir);
 | 
			
		||||
  gpu->set_model_token(options.model_token);
 | 
			
		||||
  return acceleration_proto;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
void SetDelegateOptionsOrDie(const BaseOptions* base_options,
 | 
			
		||||
                             proto::BaseOptions& base_options_proto) {
 | 
			
		||||
  if (base_options->delegate_options.has_value()) {
 | 
			
		||||
    if (!std::holds_alternative<T>(*base_options->delegate_options)) {
 | 
			
		||||
      LOG(FATAL) << "Specified Delegate type does not match the provided "
 | 
			
		||||
                    "delegate options.";
 | 
			
		||||
    } else {
 | 
			
		||||
      std::visit(
 | 
			
		||||
          [&base_options_proto](const auto& delegate_options) {
 | 
			
		||||
            proto::Acceleration acceleration_proto =
 | 
			
		||||
                ConvertDelegateOptionsToAccelerationProto(delegate_options);
 | 
			
		||||
            base_options_proto.mutable_acceleration()->Swap(
 | 
			
		||||
                &acceleration_proto);
 | 
			
		||||
          },
 | 
			
		||||
          *base_options->delegate_options);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
 | 
			
		||||
  proto::BaseOptions base_options_proto;
 | 
			
		||||
  if (!base_options->model_asset_path.empty()) {
 | 
			
		||||
| 
						 | 
				
			
			@ -53,11 +94,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
 | 
			
		|||
  switch (base_options->delegate) {
 | 
			
		||||
    case BaseOptions::Delegate::CPU:
 | 
			
		||||
      base_options_proto.mutable_acceleration()->mutable_tflite();
 | 
			
		||||
      SetDelegateOptionsOrDie<BaseOptions::CpuOptions>(base_options,
 | 
			
		||||
                                                       base_options_proto);
 | 
			
		||||
      break;
 | 
			
		||||
    case BaseOptions::Delegate::GPU:
 | 
			
		||||
      base_options_proto.mutable_acceleration()
 | 
			
		||||
          ->mutable_gpu()
 | 
			
		||||
          ->set_use_advanced_gpu_api(true);
 | 
			
		||||
      SetDelegateOptionsOrDie<BaseOptions::GpuOptions>(base_options,
 | 
			
		||||
                                                       base_options_proto);
 | 
			
		||||
      break;
 | 
			
		||||
    case BaseOptions::Delegate::EDGETPU_NNAPI:
 | 
			
		||||
      base_options_proto.mutable_acceleration()
 | 
			
		||||
| 
						 | 
				
			
			@ -65,7 +110,6 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
 | 
			
		|||
          ->set_accelerator_name("google-edgetpu");
 | 
			
		||||
      break;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return base_options_proto;
 | 
			
		||||
}
 | 
			
		||||
}  // namespace core
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,7 +17,9 @@ limitations under the License.
 | 
			
		|||
#define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
#include "absl/memory/memory.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -38,7 +40,8 @@ struct BaseOptions {
 | 
			
		|||
  std::string model_asset_path = "";
 | 
			
		||||
 | 
			
		||||
  // The delegate to run MediaPipe. If the delegate is not set, the default
 | 
			
		||||
  // delegate CPU is used.
 | 
			
		||||
  // delegate CPU is used. Use `delegate_options` to configure advanced
 | 
			
		||||
  // features of the selected delegate.
 | 
			
		||||
  enum Delegate {
 | 
			
		||||
    CPU = 0,
 | 
			
		||||
    GPU = 1,
 | 
			
		||||
| 
						 | 
				
			
			@ -48,6 +51,30 @@ struct BaseOptions {
 | 
			
		|||
 | 
			
		||||
  Delegate delegate = CPU;
 | 
			
		||||
 | 
			
		||||
  // Options for CPU.
 | 
			
		||||
  struct CpuOptions {};
 | 
			
		||||
 | 
			
		||||
  // Options for GPU.
 | 
			
		||||
  struct GpuOptions {
 | 
			
		||||
    // Load pre-compiled serialized binary cache to accelerate init process.
 | 
			
		||||
    // Only available on Android. Kernel caching will only be enabled if this
 | 
			
		||||
    // path is set. NOTE: binary cache usage may be skipped if valid serialized
 | 
			
		||||
    // model, specified by "serialized_model_dir", exists.
 | 
			
		||||
    std::string cached_kernel_path;
 | 
			
		||||
 | 
			
		||||
    // A dir to load from and save to a pre-compiled serialized model used to
 | 
			
		||||
    // accelerate init process.
 | 
			
		||||
    // NOTE: serialized model takes precedence over binary cache
 | 
			
		||||
    // specified by "cached_kernel_path", which still can be used if
 | 
			
		||||
    // serialized model is invalid or missing.
 | 
			
		||||
    std::string serialized_model_dir;
 | 
			
		||||
 | 
			
		||||
    // Unique token identifying the model. Used in conjunction with
 | 
			
		||||
    // "serialized_model_dir". It is the caller's responsibility to ensure
 | 
			
		||||
    // there is no clash of the tokens.
 | 
			
		||||
    std::string model_token;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // The file descriptor to a file opened with open(2), with optional additional
 | 
			
		||||
  // offset and length information.
 | 
			
		||||
  struct FileDescriptorMeta {
 | 
			
		||||
| 
						 | 
				
			
			@ -67,6 +94,10 @@ struct BaseOptions {
 | 
			
		|||
  // built-in Ops.
 | 
			
		||||
  std::unique_ptr<tflite::OpResolver> op_resolver =
 | 
			
		||||
      absl::make_unique<MediaPipeBuiltinOpResolver>();
 | 
			
		||||
 | 
			
		||||
  // Options for the chosen delegate. If not set, the default delegate options
 | 
			
		||||
  // is used.
 | 
			
		||||
  std::optional<std::variant<CpuOptions, GpuOptions>> delegate_options;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Converts a BaseOptions to a BaseOptionsProto.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,9 @@
 | 
			
		|||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -11,6 +14,8 @@
 | 
			
		|||
 | 
			
		||||
constexpr char kTestModelBundlePath[] =
 | 
			
		||||
    "mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
 | 
			
		||||
constexpr char kCachedModelDir[] = "/data/local/tmp";
 | 
			
		||||
constexpr char kModelToken[] = "dummy_model_token";
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace tasks {
 | 
			
		||||
| 
						 | 
				
			
			@ -40,6 +45,44 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) {
 | 
			
		|||
  EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(DelegateOptionsTest, SucceedCpuOptions) {
 | 
			
		||||
  BaseOptions base_options;
 | 
			
		||||
  base_options.delegate = BaseOptions::Delegate::CPU;
 | 
			
		||||
  BaseOptions::CpuOptions cpu_options;
 | 
			
		||||
  base_options.delegate_options = cpu_options;
 | 
			
		||||
  proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
 | 
			
		||||
  EXPECT_TRUE(proto.acceleration().has_tflite());
 | 
			
		||||
  ASSERT_FALSE(proto.acceleration().has_gpu());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(DelegateOptionsTest, SucceedGpuOptions) {
 | 
			
		||||
  BaseOptions base_options;
 | 
			
		||||
  base_options.delegate = BaseOptions::Delegate::GPU;
 | 
			
		||||
  BaseOptions::GpuOptions gpu_options;
 | 
			
		||||
  gpu_options.cached_kernel_path = kCachedModelDir;
 | 
			
		||||
  gpu_options.model_token = kModelToken;
 | 
			
		||||
  base_options.delegate_options = gpu_options;
 | 
			
		||||
  proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
 | 
			
		||||
  ASSERT_TRUE(proto.acceleration().has_gpu());
 | 
			
		||||
  ASSERT_FALSE(proto.acceleration().has_tflite());
 | 
			
		||||
  EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
 | 
			
		||||
  EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir);
 | 
			
		||||
  EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) {
 | 
			
		||||
  BaseOptions base_options;
 | 
			
		||||
  base_options.delegate = BaseOptions::Delegate::CPU;
 | 
			
		||||
  BaseOptions::GpuOptions gpu_options;
 | 
			
		||||
  gpu_options.cached_kernel_path = kCachedModelDir;
 | 
			
		||||
  gpu_options.model_token = kModelToken;
 | 
			
		||||
  base_options.delegate_options = gpu_options;
 | 
			
		||||
  ASSERT_DEATH(
 | 
			
		||||
      { proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); },
 | 
			
		||||
      "Specified Delegate type does not match the provided "
 | 
			
		||||
      "delegate options.");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace core
 | 
			
		||||
}  // namespace tasks
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -81,7 +81,6 @@ class TaskApiFactory {
 | 
			
		|||
    return std::make_unique<T>(std::move(runner));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  template <typename Options>
 | 
			
		||||
  static absl::Status CheckHasValidOptions(
 | 
			
		||||
      const CalculatorGraphConfig::Node& node) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -86,10 +86,9 @@ cc_test(
 | 
			
		|||
        "//mediapipe/tasks/cc/components/containers:classification_result",
 | 
			
		||||
        "@com_google_absl//absl/flags:flag",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@com_google_absl//absl/strings:cord",
 | 
			
		||||
        "@com_google_sentencepiece//src:sentencepiece_processor",
 | 
			
		||||
        "@com_google_sentencepiece//src:sentencepiece_processor",  # fixdeps: keep
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite:test_util",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,8 +15,6 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
 | 
			
		||||
 | 
			
		||||
#include <cmath>
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
#include <string>
 | 
			
		||||
| 
						 | 
				
			
			@ -24,7 +22,6 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
#include "absl/flags/flag.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/cord.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -45,7 +45,7 @@ constexpr char kUniversalSentenceEncoderModel[] =
 | 
			
		|||
// Tolerance for embedding vector coordinate values.
 | 
			
		||||
constexpr float kEpsilon = 1e-4;
 | 
			
		||||
// Tolerance for cosine similarity evaluation.
 | 
			
		||||
constexpr double kSimilarityTolerancy = 1e-6;
 | 
			
		||||
constexpr double kSimilarityTolerancy = 2e-2;
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::file::JoinPath;
 | 
			
		||||
using ::testing::HasSubstr;
 | 
			
		||||
| 
						 | 
				
			
			@ -79,6 +79,8 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
 | 
			
		|||
  ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
 | 
			
		||||
#ifdef _WIN32
 | 
			
		||||
  ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
 | 
			
		||||
#elif defined(__FMA__)
 | 
			
		||||
  ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.3605f, kEpsilon);
 | 
			
		||||
#else
 | 
			
		||||
  ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
 | 
			
		||||
#endif  // _WIN32
 | 
			
		||||
| 
						 | 
				
			
			@ -87,7 +89,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
 | 
			
		|||
      auto result1, text_embedder->Embed("what a great and fantastic trip"));
 | 
			
		||||
  ASSERT_EQ(result1.embeddings.size(), 1);
 | 
			
		||||
  ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
 | 
			
		||||
#ifdef __FMA__
 | 
			
		||||
  ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 21.254150f, kEpsilon);
 | 
			
		||||
#else
 | 
			
		||||
  ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Check cosine similarity.
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,6 +43,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:rect",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:base_task_api",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_runner",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +59,7 @@ cc_library(
 | 
			
		|||
        ":base_vision_task_api",
 | 
			
		||||
        "//mediapipe/calculators/core:flow_limiter_calculator",
 | 
			
		||||
        "//mediapipe/framework:calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core:task_api_factory",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,6 +26,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "mediapipe/framework/calculator.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 | 
			
		||||
#include "tensorflow/lite/core/api/op_resolver.h"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -60,13 +61,8 @@ class VisionTaskApiFactory {
 | 
			
		|||
            "Task graph config should only contain one task subgraph node.",
 | 
			
		||||
            MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
      } else {
 | 
			
		||||
        if (!node.options().HasExtension(Options::ext)) {
 | 
			
		||||
          return CreateStatusWithPayload(
 | 
			
		||||
              absl::StatusCode::kInvalidArgument,
 | 
			
		||||
              absl::StrCat(node.calculator(),
 | 
			
		||||
                           " is missing the required task options field."),
 | 
			
		||||
              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
 | 
			
		||||
        }
 | 
			
		||||
        MP_RETURN_IF_ERROR(
 | 
			
		||||
            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
 | 
			
		||||
        found_task_subgraph = true;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -153,6 +153,8 @@ cc_library(
 | 
			
		|||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# TODO: open source hand joints graph
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "hand_landmarker_result",
 | 
			
		||||
    srcs = ["hand_landmarker_result.cc"],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,3 +41,5 @@ mediapipe_proto_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# TODO: open source hand joints graph
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,6 +52,7 @@ cc_library(
 | 
			
		|||
    name = "interactive_segmenter_graph",
 | 
			
		||||
    srcs = ["interactive_segmenter_graph.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/image:image_transformation_calculator",
 | 
			
		||||
        "//mediapipe/calculators/image:set_alpha_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:annotation_overlay_calculator",
 | 
			
		||||
        "//mediapipe/calculators/util:flat_color_image_calculator",
 | 
			
		||||
| 
						 | 
				
			
			@ -60,6 +61,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/calculators/util:to_image_calculator",
 | 
			
		||||
        "//mediapipe/framework:calculator_framework",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/api2:node",
 | 
			
		||||
        "//mediapipe/framework/api2:port",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 | 
			
		|||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/api2/node.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
| 
						 | 
				
			
@@ -35,6 +37,51 @@ namespace mediapipe {
namespace tasks {
namespace vision {
namespace interactive_segmenter {
namespace internal {

// A calculator to add thickness to the render data according to the image size,
// so that the render data is scale invariant to the image size. If the render
// data already has thickness, it will be kept as is.
class AddThicknessToRenderDataCalculator : public api2::Node {
 public:
  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
  static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{
      "RENDER_DATA"};
  static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{
      "RENDER_DATA"};

  static constexpr int kModelInputTensorWidth = 512;
  static constexpr int kModelInputTensorHeight = 512;

  MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut);

  absl::Status Process(CalculatorContext* cc) final {
    mediapipe::RenderData render_data = kRenderDataIn(cc).Get();
    Image image = kImageIn(cc).Get();
    double thickness = std::max(
        std::max(image.width() / static_cast<double>(kModelInputTensorWidth),
                 image.height() / static_cast<double>(kModelInputTensorHeight)),
        1.0);

    for (auto& annotation : *render_data.mutable_render_annotations()) {
      if (!annotation.has_thickness()) {
        annotation.set_thickness(thickness);
      }
    }
    kRenderDataOut(cc).Send(render_data);
    return absl::OkStatus();
  }
};

// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
// moved to next line.
// clang-format off
MEDIAPIPE_REGISTER_NODE(
    ::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator);
// clang-format on
// NOLINTEND

}  // namespace internal

namespace {

@@ -59,6 +106,7 @@ constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
constexpr absl::string_view kNormRectTag{"NORM_RECT"};
constexpr absl::string_view kRoiTag{"ROI"};
constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
constexpr absl::string_view kRenderDataTag{"RENDER_DATA"};

// Updates the graph to return `roi` stream which has same dimension as
// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is

@@ -69,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
  const absl::string_view image_tag_with_suffix =
      use_gpu ? kImageGpuTag : kImageCpuTag;

  // Adds thickness to the render data so that the render data is scale
  // invariant to the input image size.
  auto& add_thickness = graph.AddNode(
      "mediapipe::tasks::vision::interactive_segmenter::internal::"
      "AddThicknessToRenderDataCalculator");
  image >> add_thickness.In(kImageTag);
  roi >> add_thickness.In(kRenderDataTag);
  auto roi_with_thickness = add_thickness.Out(kRenderDataTag);

  // Generates a blank canvas with same size as input image.
  auto& flat_color = graph.AddNode("FlatColorImageCalculator");
  auto& flat_color_options =
      flat_color.GetOptions<FlatColorImageCalculatorOptions>();
  // SetAlphaCalculator only takes 1st channel.
  flat_color_options.mutable_color()->set_r(0);
  image >> flat_color.In(kImageTag)[0];
  auto blank_canvas = flat_color.Out(kImageTag)[0];
  image >> flat_color.In(kImageTag);
  auto blank_canvas = flat_color.Out(kImageTag);

  auto& from_mp_image = graph.AddNode("FromImageCalculator");
  blank_canvas >> from_mp_image.In(kImageTag);

@@ -85,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
  auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
  blank_canvas_in_cpu_or_gpu >>
      roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
  roi >> roi_to_alpha.In(0);
  roi_with_thickness >> roi_to_alpha.In(0);
  auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);

  return alpha;
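
As a worked example of the scale-invariant thickness introduced above (illustrative numbers, not taken from this commit): for a 2048x1536 input image and the 512x512 model input tensor size used by AddThicknessToRenderDataCalculator, the default thickness is max(2048/512, 1536/512, 1.0) = 4.0, so a stroke drawn at thickness 4 on the full-resolution image still covers roughly one pixel once the image is scaled down to the model input size.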
@@ -163,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
    image >> from_mp_image.In(kImageTag);
    auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);

    // Creates an RGBA image with model input tensor size.
    auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);

    auto& set_alpha = graph.AddNode("SetAlphaCalculator");
@@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"

@@ -70,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
// Golden mask for the dogs in cats_and_dogs.jpg.
constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"};
constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"};
constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"};
constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"};

constexpr float kGoldenMaskSimilarity = 0.97;

@@ -183,6 +188,7 @@ struct InteractiveSegmenterTestParams {
  std::string test_name;
  RegionOfInterest::Format format;
  std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
  absl::string_view input_image_file;
  absl::string_view golden_mask_file;
  float similarity_threshold;
};
@@ -220,8 +226,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
  const InteractiveSegmenterTestParams& params = GetParam();

  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                params.input_image_file)));
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);

@@ -244,6 +250,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
  EXPECT_THAT(actual_mask,
              SimilarToUint8Mask(expected_mask, params.similarity_threshold,
                                 kGoldenMaskMagnificationFactor));

  cv::Mat visualized_mask;
  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
                              visualized_mask.cols, visualized_mask.rows,
                              visualized_mask.step, visualized_mask.data,
                              [visualized_mask](uint8_t[]) {});
  MP_EXPECT_OK(SavePngTestOutput(
      visualized_image, absl::StrFormat("%s_category_mask", params.test_name)));
}

TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {

@@ -252,8 +267,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
  const InteractiveSegmenterTestParams& params = GetParam();

  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                params.input_image_file)));
  auto options = std::make_unique<InteractiveSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kPtmModel);

@@ -275,6 +290,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
      result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
  EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
                                              params.similarity_threshold));
  cv::Mat visualized_mask;
  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
                              visualized_mask.cols, visualized_mask.rows,
                              visualized_mask.step, visualized_mask.data,
                              [visualized_mask](uint8_t[]) {});
  MP_EXPECT_OK(SavePngTestOutput(
      visualized_image,
      absl::StrFormat("%s_confidence_mask", params.test_name)));
}

INSTANTIATE_TEST_SUITE_P(
@@ -282,21 +306,28 @@ INSTANTIATE_TEST_SUITE_P(
    ::testing::ValuesIn<InteractiveSegmenterTestParams>(
        {// Keypoint input.
         {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1,
          0.84f},
         {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2,
          kGoldenMaskSimilarity},
         {"PenguinsSmall", RegionOfInterest::Format::kKeyPoint,
          NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask,
          0.9f},
         {"PenguinsLarge", RegionOfInterest::Format::kKeyPoint,
          NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask,
          0.9f},
         // Scribble input.
         {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
          std::vector{NormalizedKeypoint{0.44, 0.70},
                      NormalizedKeypoint{0.44, 0.71},
                      NormalizedKeypoint{0.44, 0.72}},
          kCatsAndDogsMaskDog1, 0.84f},
          kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f},
         {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
          std::vector{NormalizedKeypoint{0.66, 0.66},
                      NormalizedKeypoint{0.66, 0.67},
                      NormalizedKeypoint{0.66, 0.68}},
          kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
          kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
    [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
           info) { return info.param.test_name; });
@@ -60,3 +60,9 @@ objc_library(
    srcs = ["sources/MPPLandmark.m"],
    hdrs = ["sources/MPPLandmark.h"],
)

objc_library(
    name = "MPPConnection",
    srcs = ["sources/MPPConnection.m"],
    hdrs = ["sources/MPPConnection.h"],
)
@@ -0,0 +1,44 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>

NS_ASSUME_NONNULL_BEGIN

/** The value class representing a landmark connection. */
NS_SWIFT_NAME(Connection)
@interface MPPConnection : NSObject

@property(nonatomic, readonly) NSUInteger start;

@property(nonatomic, readonly) NSUInteger end;

/**
 * Initializes a new `MPPConnection` with the start and end landmarks integer constants.
 *
 * @param start The integer representing the starting landmark of the connection.
 * @param end The integer representing the ending landmark of the connection.
 *
 * @return An instance of `MPPConnection` initialized with the given start and end landmarks integer
 * constants.
 */
- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end NS_DESIGNATED_INITIALIZER;

- (instancetype)init NS_UNAVAILABLE;

+ (instancetype)new NS_UNAVAILABLE;

@end

NS_ASSUME_NONNULL_END
@@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"

@implementation MPPConnection

- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end {
  self = [super init];
  if (self) {
    _start = start;
    _end = end;
  }
  return self;
}

@end
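
A minimal usage sketch for the new MPPConnection class (illustrative only, not part of this commit; the landmark indices below are arbitrary):

  // Hypothetical example: connect landmark 0 to landmark 1.
  MPPConnection *connection = [[MPPConnection alloc] initWithStart:0 end:1];
  NSLog(@"connection %lu -> %lu", (unsigned long)connection.start, (unsigned long)connection.end);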
@@ -54,3 +54,20 @@ ios_unit_test(
        ":MPPImageObjcTestLibrary",
    ],
)

objc_library(
    name = "MPPMaskObjcTestLibrary",
    testonly = 1,
    srcs = ["MPPMaskTests.m"],
    deps = ["//mediapipe/tasks/ios/vision/core:MPPMask"],
)

ios_unit_test(
    name = "MPPMaskObjcTest",
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
    runner = tflite_ios_lab_runner("IOS_LATEST"),
    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
    deps = [
        ":MPPMaskObjcTestLibrary",
    ],
)

127 mediapipe/tasks/ios/test/vision/core/MPPMaskTests.m (new file)
@@ -0,0 +1,127 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"

#import <XCTest/XCTest.h>

/** Unit tests for `MPPMask`. */
@interface MPPMaskTests : XCTestCase

@end

@implementation MPPMaskTests

#pragma mark - Tests

- (void)testInitWithUInt8ArrayNoCopySucceeds {

  NSInteger width = 2;
  NSInteger height = 3;

  UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
  float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};

  MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:NO];

  XCTAssertEqual(mask.width, width);
  XCTAssertEqual(mask.height, height);

  // Test if UInt8 mask is not copied.
  XCTAssertEqual(mask.uint8Data, (const UInt8*)uint8Data);
  XCTAssertNotEqual(mask.float32Data, NULL);

  for (int i = 0 ; i < width * height ; i ++) {
    XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f, @"index i = %d", i);
  }

  // Test if repeated Float32 mask accesses return the same array in memory.
  XCTAssertEqual(mask.float32Data, mask.float32Data);
}

- (void)testInitWithUInt8ArrayCopySucceeds {

  NSInteger width = 2;
  NSInteger height = 3;

  UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
  float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};

  MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:YES];

  XCTAssertEqual(mask.width, width);
  XCTAssertEqual(mask.height, height);

  // Test if UInt8 mask is copied.
  XCTAssertNotEqual(mask.uint8Data, (const UInt8*)uint8Data);
  XCTAssertNotEqual(mask.float32Data, NULL);

  for (int i = 0 ; i < width * height ; i ++) {
    XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f);
  }

  // Test if repeated Float32 mask accesses return the same array in memory.
  XCTAssertEqual(mask.float32Data, mask.float32Data);
}

- (void)testInitWithFloat32ArrayNoCopySucceeds {

  NSInteger width = 2;
  NSInteger height = 3;

  UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
  float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
  MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:NO];

  XCTAssertEqual(mask.width, width);
  XCTAssertEqual(mask.height, height);

  // Test if Float32 mask is not copied.
  XCTAssertEqual(mask.float32Data, (const float*)float32Data);
  XCTAssertNotEqual(mask.uint8Data, NULL);

  for (int i = 0 ; i < width * height ; i ++) {
    XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
  }

  // Test if repeated UInt8 mask accesses return the same array in memory.
  XCTAssertEqual(mask.uint8Data, mask.uint8Data);
}

- (void)testInitWithFloat32ArrayCopySucceeds {

  NSInteger width = 2;
  NSInteger height = 3;

  UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
  float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};

  MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:YES];

  XCTAssertEqual(mask.width, width);
  XCTAssertEqual(mask.height, height);

  // Test if Float32 mask is copied.
  XCTAssertNotEqual(mask.float32Data, (const float*)float32Data);
  XCTAssertNotEqual(mask.uint8Data, NULL);

  for (int i = 0 ; i < width * height ; i ++) {
    XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
  }

  // Test if repeated UInt8 mask accesses return the same array in memory.
  XCTAssertEqual(mask.uint8Data, mask.uint8Data);
}

@end
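
For reference on the constants used in these tests (assuming the usual value/255 mapping between the UInt8 and Float32 mask representations, which is an assumption rather than something stated in this diff): 128 / 255 ≈ 0.502, which is why the expected float values are 0.501 with a 1e-3 tolerance, and 0.52 * 255 = 132.6, which truncates to the expected UInt8 value of 132.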
@@ -155,12 +155,12 @@ static const float kKeypointErrorThreshold = 1e-2;
  NSInteger iterationCount = 100;

  // Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
  // normal expectation will fail if expectation.fullfill() is not called
  // normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
  // only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since it is not possible to predict how many times the expectation is supposed to be
  // fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
  // fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
  // `iterationCount` times.
  XCTestExpectation *expectation = [[XCTestExpectation alloc]
      initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
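
For context, a minimal sketch of the inverted-expectation pattern the comment above describes (illustrative, not part of this commit; the timeout variable is assumed):

  XCTestExpectation *expectation = [[XCTestExpectation alloc]
      initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
  expectation.expectedFulfillmentCount = iterationCount + 1;
  // An inverted expectation fails the test only if it is fulfilled expectedFulfillmentCount times.
  expectation.inverted = YES;
  // ... run the live stream API iterationCount times, calling [expectation fulfill] in each callback ...
  [self waitForExpectations:@[ expectation ] timeout:timeout];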
@@ -385,13 +385,13 @@ static const float kKeypointErrorThreshold = 1e-2;
  NSInteger iterationCount = 100;

  // Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
  // normal expectation will fail if expectation.fullfill() is not called times. An normal
  // expectation will fail if expectation.fullfill() is not called
  // normal expectation will fail if expectation.fulfill() is not called times. An normal
  // expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
  // only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since it it not possible to determine how many times the expectation is supposed to be
  // fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
  // fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
  // `iterationCount` times.
  XCTestExpectation *expectation = [[XCTestExpectation alloc]
      initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];

@@ -174,12 +174,12 @@ constexpr float kFacialTransformationMatrixErrorThreshold = 0.2f;
  NSInteger iterationCount = 100;

  // Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
  // normal expectation will fail if expectation.fullfill() is not called
  // normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
  // only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since it is not possible to predict how many times the expectation is supposed to be
  // fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
  // fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
  // `iterationCount` times.
  XCTestExpectation *expectation = [[XCTestExpectation alloc]
      initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];

62 mediapipe/tasks/ios/test/vision/gesture_recognizer/BUILD (new file)
@@ -0,0 +1,62 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load(
    "//mediapipe/framework/tool:ios.bzl",
    "MPP_TASK_MINIMUM_OS_VERSION",
)
load(
    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
    "tflite_ios_lab_runner",
)

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
    "apple",
]

# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
    "noasan",
    "nomsan",
    "notsan",
]

objc_library(
    name = "MPPGestureRecognizerObjcTestLibrary",
    testonly = 1,
    srcs = ["MPPGestureRecognizerTests.m"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    data = [
        "//mediapipe/tasks/testdata/vision:gesture_recognizer.task",
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_protos",
    ],
    deps = [
        "//mediapipe/tasks/ios/common:MPPCommon",
        "//mediapipe/tasks/ios/test/vision/gesture_recognizer/utils:MPPGestureRecognizerResultProtobufHelpers",
        "//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
        "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer",
    ] + select({
        "//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//conditions:default": ["@ios_opencv//:OpencvFramework"],
    }),
)

ios_unit_test(
    name = "MPPGestureRecognizerObjcTest",
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
    runner = tflite_ios_lab_runner("IOS_LATEST"),
    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
    deps = [
        ":MPPGestureRecognizerObjcTestLibrary",
    ],
)
			@ -0,0 +1,706 @@
 | 
			
		|||
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import <XCTest/XCTest.h>
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizer.h"
 | 
			
		||||
 | 
			
		||||
static NSString *const kPbFileExtension = @"pbtxt";
 | 
			
		||||
 | 
			
		||||
typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;
 | 
			
		||||
 | 
			
		||||
static ResourceFileInfo *const kGestureRecognizerBundleAssetFile =
 | 
			
		||||
    @{@"name" : @"gesture_recognizer", @"type" : @"task"};
 | 
			
		||||
 | 
			
		||||
static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kFistImage = @{@"name" : @"fist", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kPointingUpRotatedImage =
 | 
			
		||||
    @{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};
 | 
			
		||||
 | 
			
		||||
static ResourceFileInfo *const kExpectedFistLandmarksFile =
 | 
			
		||||
    @{@"name" : @"fist_landmarks", @"type" : kPbFileExtension};
 | 
			
		||||
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
 | 
			
		||||
    @{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
 | 
			
		||||
 | 
			
		||||
static NSString *const kFistLabel = @"Closed_Fist";
 | 
			
		||||
static NSString *const kExpectedThumbUpLabel = @"Thumb_Up";
 | 
			
		||||
static NSString *const kExpectedPointingUpLabel = @"Pointing_Up";
 | 
			
		||||
static NSString *const kRockLabel = @"Rock";
 | 
			
		||||
 | 
			
		||||
static const NSInteger kGestureExpectedIndex = -1;
 | 
			
		||||
 | 
			
		||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 | 
			
		||||
static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		||||
 | 
			
		||||
static NSString *const kLiveStreamTestsDictGestureRecognizerKey = @"gesture_recognizer";
 | 
			
		||||
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		||||
 | 
			
		||||
#define AssertEqualErrors(error, expectedError)              \
 | 
			
		||||
  XCTAssertNotNil(error);                                    \
 | 
			
		||||
  XCTAssertEqualObjects(error.domain, expectedError.domain); \
 | 
			
		||||
  XCTAssertEqual(error.code, expectedError.code);            \
 | 
			
		||||
  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 | 
			
		||||
 | 
			
		||||
#define AssertEqualGestures(gesture, expectedGesture, handIndex, gestureIndex)                  \
 | 
			
		||||
  XCTAssertEqual(gesture.index, kGestureExpectedIndex, @"hand index = %d gesture index j = %d", \
 | 
			
		||||
                 handIndex, gestureIndex);                                                      \
 | 
			
		||||
  XCTAssertEqualObjects(gesture.categoryName, expectedGesture.categoryName,                     \
 | 
			
		||||
                        @"hand index = %d gesture index j = %d", handIndex, gestureIndex);
 | 
			
		||||
 | 
			
		||||
#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex)   \
 | 
			
		||||
  XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance,            \
 | 
			
		||||
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
 | 
			
		||||
  XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance,            \
 | 
			
		||||
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
 | 
			
		||||
 | 
			
		||||
#define AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult) \
 | 
			
		||||
  XCTAssertTrue(gestureRecognizerResult.gestures.count == 0);         \
 | 
			
		||||
  XCTAssertTrue(gestureRecognizerResult.handedness.count == 0);       \
 | 
			
		||||
  XCTAssertTrue(gestureRecognizerResult.landmarks.count == 0);        \
 | 
			
		||||
  XCTAssertTrue(gestureRecognizerResult.worldLandmarks.count == 0);
 | 
			
		||||
 | 
			
		||||
@interface MPPGestureRecognizerTests : XCTestCase <MPPGestureRecognizerLiveStreamDelegate> {
 | 
			
		||||
  NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
 | 
			
		||||
  NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
 | 
			
		||||
}
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
@implementation MPPGestureRecognizerTests
 | 
			
		||||
 | 
			
		||||
#pragma mark Expected Results
 | 
			
		||||
 | 
			
		||||
+ (MPPGestureRecognizerResult *)emptyGestureRecognizerResult {
 | 
			
		||||
  return [[MPPGestureRecognizerResult alloc] initWithGestures:@[]
 | 
			
		||||
                                                   handedness:@[]
 | 
			
		||||
                                                    landmarks:@[]
 | 
			
		||||
                                               worldLandmarks:@[]
 | 
			
		||||
                                      timestampInMilliseconds:0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (MPPGestureRecognizerResult *)thumbUpGestureRecognizerResult {
 | 
			
		||||
  NSString *filePath =
 | 
			
		||||
      [MPPGestureRecognizerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
 | 
			
		||||
 | 
			
		||||
  return [MPPGestureRecognizerResult
 | 
			
		||||
      gestureRecognizerResultsFromProtobufFileWithName:filePath
 | 
			
		||||
                                          gestureLabel:kExpectedThumbUpLabel
 | 
			
		||||
                                 shouldRemoveZPosition:YES];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (MPPGestureRecognizerResult *)fistGestureRecognizerResultWithLabel:(NSString *)gestureLabel {
 | 
			
		||||
  NSString *filePath = [MPPGestureRecognizerTests filePathWithFileInfo:kExpectedFistLandmarksFile];
 | 
			
		||||
 | 
			
		||||
  return [MPPGestureRecognizerResult gestureRecognizerResultsFromProtobufFileWithName:filePath
 | 
			
		||||
                                                                         gestureLabel:gestureLabel
 | 
			
		||||
                                                                shouldRemoveZPosition:YES];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Assert Gesture Recognizer Results
 | 
			
		||||
 | 
			
		||||
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
 | 
			
		||||
    areApproximatelyEqualToExpectedMultiHandLandmarks:
 | 
			
		||||
        (NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
 | 
			
		||||
  XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
 | 
			
		||||
  if (multiHandLandmarks.count == 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
 | 
			
		||||
  NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
 | 
			
		||||
  for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
 | 
			
		||||
    MPPNormalizedLandmark *landmark = topHandLandmarks[i];
 | 
			
		||||
    XCTAssertNotNil(landmark);
 | 
			
		||||
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
 | 
			
		||||
    areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
 | 
			
		||||
        (NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
 | 
			
		||||
  XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
 | 
			
		||||
  if (expectedMultiHandWorldLandmarks.count == 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
 | 
			
		||||
  NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
 | 
			
		||||
  for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
 | 
			
		||||
    MPPLandmark *landmark = topHandWorldLandmarks[i];
 | 
			
		||||
    XCTAssertNotNil(landmark);
 | 
			
		||||
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertMultiHandGestures:(NSArray<NSArray<MPPCategory *> *> *)multiHandGestures
 | 
			
		||||
    areApproximatelyEqualToExpectedMultiHandGestures:
 | 
			
		||||
        (NSArray<NSArray<MPPCategory *> *> *)expectedMultiHandGestures {
 | 
			
		||||
  XCTAssertEqual(multiHandGestures.count, expectedMultiHandGestures.count);
 | 
			
		||||
  if (multiHandGestures.count == 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSArray<MPPCategory *> *topHandGestures = multiHandGestures[0];
 | 
			
		||||
  NSArray<MPPCategory *> *expectedTopHandGestures = expectedMultiHandGestures[0];
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(topHandGestures.count, expectedTopHandGestures.count);
 | 
			
		||||
  for (int i = 0; i < expectedTopHandGestures.count; i++) {
 | 
			
		||||
    MPPCategory *gesture = topHandGestures[i];
 | 
			
		||||
    XCTAssertNotNil(gesture);
 | 
			
		||||
    AssertEqualGestures(gesture, expectedTopHandGestures[i], 0, i);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertGestureRecognizerResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
 | 
			
		||||
    isApproximatelyEqualToExpectedResult:
 | 
			
		||||
        (MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
 | 
			
		||||
  [self assertMultiHandLandmarks:gestureRecognizerResult.landmarks
 | 
			
		||||
      areApproximatelyEqualToExpectedMultiHandLandmarks:expectedGestureRecognizerResult.landmarks];
 | 
			
		||||
  [self assertMultiHandWorldLandmarks:gestureRecognizerResult.worldLandmarks
 | 
			
		||||
      areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedGestureRecognizerResult
 | 
			
		||||
                                                                 .worldLandmarks];
 | 
			
		||||
  [self assertMultiHandGestures:gestureRecognizerResult.gestures
 | 
			
		||||
      areApproximatelyEqualToExpectedMultiHandGestures:expectedGestureRecognizerResult.gestures];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertResultsOfRecognizeImageWithFileInfo:(ResourceFileInfo *)fileInfo
 | 
			
		||||
                           usingGestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
 | 
			
		||||
       approximatelyEqualsGestureRecognizerResult:
 | 
			
		||||
           (MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult =
 | 
			
		||||
      [self recognizeImageWithFileInfo:fileInfo usingGestureRecognizer:gestureRecognizer];
 | 
			
		||||
  [self assertGestureRecognizerResult:gestureRecognizerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:expectedGestureRecognizerResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark File
 | 
			
		||||
 | 
			
		||||
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
 | 
			
		||||
  NSString *filePath = [MPPGestureRecognizerTests filePathWithName:fileInfo[@"name"]
 | 
			
		||||
                                                         extension:fileInfo[@"type"]];
 | 
			
		||||
  return filePath;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
 | 
			
		||||
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
 | 
			
		||||
                                                                      ofType:extension];
 | 
			
		||||
  return filePath;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Gesture Recognizer Initializers
 | 
			
		||||
 | 
			
		||||
- (MPPGestureRecognizerOptions *)gestureRecognizerOptionsWithModelFileInfo:
 | 
			
		||||
    (ResourceFileInfo *)modelFileInfo {
 | 
			
		||||
  NSString *modelPath = [MPPGestureRecognizerTests filePathWithFileInfo:modelFileInfo];
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [[MPPGestureRecognizerOptions alloc] init];
 | 
			
		||||
  gestureRecognizerOptions.baseOptions.modelAssetPath = modelPath;
 | 
			
		||||
 | 
			
		||||
  return gestureRecognizerOptions;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPGestureRecognizer *)createGestureRecognizerWithOptionsSucceeds:
 | 
			
		||||
    (MPPGestureRecognizerOptions *)gestureRecognizerOptions {
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:nil];
 | 
			
		||||
  XCTAssertNotNil(gestureRecognizer);
 | 
			
		||||
 | 
			
		||||
  return gestureRecognizer;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertCreateGestureRecognizerWithOptions:
 | 
			
		||||
            (MPPGestureRecognizerOptions *)gestureRecognizerOptions
 | 
			
		||||
                          failsWithExpectedError:(NSError *)expectedError {
 | 
			
		||||
  NSError *error = nil;
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:&error];
 | 
			
		||||
 | 
			
		||||
  XCTAssertNil(gestureRecognizer);
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Recognize Helpers
 | 
			
		||||
 | 
			
		||||
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
 | 
			
		||||
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
 | 
			
		||||
                                              fileName:fileInfo[@"name"]
 | 
			
		||||
                                                ofType:fileInfo[@"type"]];
 | 
			
		||||
  XCTAssertNotNil(image);
 | 
			
		||||
 | 
			
		||||
  return image;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
 | 
			
		||||
                    orientation:(UIImageOrientation)orientation {
 | 
			
		||||
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
 | 
			
		||||
                                              fileName:fileInfo[@"name"]
 | 
			
		||||
                                                ofType:fileInfo[@"type"]
 | 
			
		||||
                                           orientation:orientation];
 | 
			
		||||
  XCTAssertNotNil(image);
 | 
			
		||||
 | 
			
		||||
  return image;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPGestureRecognizerResult *)recognizeImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
 | 
			
		||||
                                    usingGestureRecognizer:
 | 
			
		||||
                                        (MPPGestureRecognizer *)gestureRecognizer {
 | 
			
		||||
  MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
 | 
			
		||||
                                                                                    error:nil];
 | 
			
		||||
  XCTAssertNotNil(gestureRecognizerResult);
 | 
			
		||||
 | 
			
		||||
  return gestureRecognizerResult;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark General Tests
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithModelPathSucceeds {
 | 
			
		||||
  NSString *modelPath =
 | 
			
		||||
      [MPPGestureRecognizerTests filePathWithFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [[MPPGestureRecognizer alloc] initWithModelPath:modelPath error:nil];
 | 
			
		||||
  XCTAssertNotNil(gestureRecognizer);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfRecognizeImageWithFileInfo:kThumbUpImage
 | 
			
		||||
                           usingGestureRecognizer:gestureRecognizer
 | 
			
		||||
       approximatelyEqualsGestureRecognizerResult:[MPPGestureRecognizerTests
 | 
			
		||||
                                                      thumbUpGestureRecognizerResult]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithEmptyResultsSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult =
 | 
			
		||||
      [self recognizeImageWithFileInfo:kNoHandsImage usingGestureRecognizer:gestureRecognizer];
 | 
			
		||||
  AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithScoreThresholdSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult =
 | 
			
		||||
      [self recognizeImageWithFileInfo:kThumbUpImage usingGestureRecognizer:gestureRecognizer];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizerResult *expectedGestureRecognizerResult =
 | 
			
		||||
      [MPPGestureRecognizerTests thumbUpGestureRecognizerResult];
 | 
			
		||||
 | 
			
		||||
  XCTAssertTrue(gestureRecognizerResult.gestures.count == 1);
 | 
			
		||||
  AssertEqualGestures(gestureRecognizerResult.gestures[0][0],
 | 
			
		||||
                      expectedGestureRecognizerResult.gestures[0][0], 0, 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithNumHandsSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  const NSInteger numHands = 2;
 | 
			
		||||
  gestureRecognizerOptions.numHands = numHands;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult =
 | 
			
		||||
      [self recognizeImageWithFileInfo:kTwoHandsImage usingGestureRecognizer:gestureRecognizer];
 | 
			
		||||
 | 
			
		||||
  XCTAssertTrue(gestureRecognizerResult.handedness.count == numHands);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithRotationSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  gestureRecognizerOptions.numHands = 1;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
  MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
 | 
			
		||||
                                   orientation:UIImageOrientationRight];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
 | 
			
		||||
                                                                                    error:nil];
 | 
			
		||||
 | 
			
		||||
  XCTAssertNotNil(gestureRecognizerResult);
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(gestureRecognizerResult.gestures.count, 1);
 | 
			
		||||
  XCTAssertEqualObjects(gestureRecognizerResult.gestures[0][0].categoryName,
 | 
			
		||||
                        kExpectedPointingUpLabel);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithCannedGestureFistSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  gestureRecognizerOptions.numHands = 1;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfRecognizeImageWithFileInfo:kFistImage
 | 
			
		||||
                           usingGestureRecognizer:gestureRecognizer
 | 
			
		||||
       approximatelyEqualsGestureRecognizerResult:
 | 
			
		||||
           [MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithAllowGestureFistSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
 | 
			
		||||
 | 
			
		||||
  gestureRecognizerOptions.numHands = 1;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfRecognizeImageWithFileInfo:kFistImage
 | 
			
		||||
                           usingGestureRecognizer:gestureRecognizer
 | 
			
		||||
       approximatelyEqualsGestureRecognizerResult:
 | 
			
		||||
           [MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithDenyGestureFistSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];
 | 
			
		||||
 | 
			
		||||
  gestureRecognizerOptions.numHands = 1;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
  MPPGestureRecognizerResult *gestureRecognizerResult =
 | 
			
		||||
      [self recognizeImageWithFileInfo:kFistImage usingGestureRecognizer:gestureRecognizer];
 | 
			
		||||
  AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithPreferAllowlistOverDenylistSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
 | 
			
		||||
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];
 | 
			
		||||
 | 
			
		||||
  gestureRecognizerOptions.numHands = 1;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfRecognizeImageWithFileInfo:kFistImage
 | 
			
		||||
                           usingGestureRecognizer:gestureRecognizer
 | 
			
		||||
       approximatelyEqualsGestureRecognizerResult:
 | 
			
		||||
           [MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Running Mode Tests
 | 
			
		||||
 | 
			
		||||
- (void)testCreateGestureRecognizerFailsWithDelegateInNonLiveStreamMode {
 | 
			
		||||
  MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
 | 
			
		||||
  for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
 | 
			
		||||
    MPPGestureRecognizerOptions *options =
 | 
			
		||||
        [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
    options.runningMode = runningModesToTest[i];
 | 
			
		||||
    options.gestureRecognizerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
    [self assertCreateGestureRecognizerWithOptions:options
 | 
			
		||||
                            failsWithExpectedError:
 | 
			
		||||
                                [NSError
 | 
			
		||||
                                    errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                                               code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                                           userInfo:@{
 | 
			
		||||
                                             NSLocalizedDescriptionKey :
 | 
			
		||||
                                                 @"The vision task is in image or video mode. The "
 | 
			
		||||
                                                 @"delegate must not be set in the task's options."
 | 
			
		||||
                                           }]];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testCreateGestureRecognizerFailsWithMissingDelegateInLiveStreamMode {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
 | 
			
		||||
  [self
 | 
			
		||||
      assertCreateGestureRecognizerWithOptions:options
 | 
			
		||||
                        failsWithExpectedError:
 | 
			
		||||
                            [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                                                code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                                            userInfo:@{
 | 
			
		||||
                                              NSLocalizedDescriptionKey :
 | 
			
		||||
                                                  @"The vision task is in live stream mode. An "
 | 
			
		||||
                                                  @"object must be set as the delegate of the task "
 | 
			
		||||
                                                  @"in its options to ensure asynchronous delivery "
 | 
			
		||||
                                                  @"of results."
 | 
			
		||||
                                            }]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeFailsWithCallingWrongApiInImageMode {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kFistImage];
 | 
			
		||||
 | 
			
		||||
  NSError *liveStreamApiCallError;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
 | 
			
		||||
                                timestampInMilliseconds:0
 | 
			
		||||
                                                  error:&liveStreamApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedLiveStreamApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
 | 
			
		||||
                                                    @"stream mode. Current Running Mode: Image"
 | 
			
		||||
                      }];
 | 
			
		||||
 | 
			
		||||
  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
 | 
			
		||||
 | 
			
		||||
  NSError *videoApiCallError;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
 | 
			
		||||
                                timestampInMilliseconds:0
 | 
			
		||||
                                                  error:&videoApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedVideoApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"video mode. Current Running Mode: Image"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeFailsWithCallingWrongApiInVideoMode {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeVideo;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kFistImage];
 | 
			
		||||
 | 
			
		||||
  NSError *liveStreamApiCallError;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
 | 
			
		||||
                                timestampInMilliseconds:0
 | 
			
		||||
                                                  error:&liveStreamApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedLiveStreamApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
 | 
			
		||||
                                                    @"stream mode. Current Running Mode: Video"
 | 
			
		||||
                      }];
 | 
			
		||||
 | 
			
		||||
  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
 | 
			
		||||
 | 
			
		||||
  NSError *imageApiCallError;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedImageApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"image mode. Current Running Mode: Video"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeFailsWithCallingWrongApiInLiveStreamMode {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
  options.gestureRecognizerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kFistImage];
 | 
			
		||||
 | 
			
		||||
  NSError *imageApiCallError;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedImageApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"image mode. Current Running Mode: Live Stream"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
 | 
			
		||||
 | 
			
		||||
  NSError *videoApiCallError;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
 | 
			
		||||
                                timestampInMilliseconds:0
 | 
			
		||||
                                                  error:&videoApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedVideoApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"video mode. Current Running Mode: Live Stream"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithVideoModeSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeVideo;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  for (int i = 0; i < 3; i++) {
 | 
			
		||||
    MPPGestureRecognizerResult *gestureRecognizerResult =
 | 
			
		||||
        [gestureRecognizer recognizeVideoFrame:image timestampInMilliseconds:i error:nil];
 | 
			
		||||
    [self assertGestureRecognizerResult:gestureRecognizerResult
 | 
			
		||||
        isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
 | 
			
		||||
                                                 thumbUpGestureRecognizerResult]];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithOutOfOrderTimestampsAndLiveStreamModeFails {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
  options.gestureRecognizerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
  XCTestExpectation *expectation = [[XCTestExpectation alloc]
 | 
			
		||||
      initWithDescription:@"recognizeWithOutOfOrderTimestampsAndLiveStream"];
 | 
			
		||||
 | 
			
		||||
  expectation.expectedFulfillmentCount = 1;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  _outOfOrderTimestampTestDict = @{
 | 
			
		||||
    kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
 | 
			
		||||
    kLiveStreamTestsDictExpectationKey : expectation
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image timestampInMilliseconds:1 error:nil]);
 | 
			
		||||
 | 
			
		||||
  NSError *error;
 | 
			
		||||
  XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
 | 
			
		||||
                                timestampInMilliseconds:0
 | 
			
		||||
                                                  error:&error]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey :
 | 
			
		||||
                            @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
 | 
			
		||||
  NSTimeInterval timeout = 0.5f;
 | 
			
		||||
  [self waitForExpectations:@[ expectation ] timeout:timeout];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testRecognizeWithLiveStreamModeSucceeds {
 | 
			
		||||
  MPPGestureRecognizerOptions *options =
 | 
			
		||||
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
  options.gestureRecognizerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
  NSInteger iterationCount = 100;
 | 
			
		||||
 | 
			
		||||
  // Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
  // times. A normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since in our case we cannot predict how many times the expectation is supposed to be fulfilled,
  // setting `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that the test succeeds if the expectation is fulfilled
  // at most `iterationCount` times.
 | 
			
		||||
  XCTestExpectation *expectation =
 | 
			
		||||
      [[XCTestExpectation alloc] initWithDescription:@"recognizeWithLiveStream"];
 | 
			
		||||
 | 
			
		||||
  expectation.expectedFulfillmentCount = iterationCount + 1;
 | 
			
		||||
  expectation.inverted = YES;
 | 
			
		||||
 | 
			
		||||
  MPPGestureRecognizer *gestureRecognizer =
 | 
			
		||||
      [self createGestureRecognizerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  _liveStreamSucceedsTestDict = @{
 | 
			
		||||
    kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
 | 
			
		||||
    kLiveStreamTestsDictExpectationKey : expectation
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
 | 
			
		||||
  // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
 | 
			
		||||
  // `CMSampleBuffer`.
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  for (int i = 0; i < iterationCount; i++) {
 | 
			
		||||
    XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image
 | 
			
		||||
                                 timestampInMilliseconds:i
 | 
			
		||||
                                                   error:nil]);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSTimeInterval timeout = 0.5f;
 | 
			
		||||
  [self waitForExpectations:@[ expectation ] timeout:timeout];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)gestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
 | 
			
		||||
    didFinishRecognitionWithResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
 | 
			
		||||
           timestampInMilliseconds:(NSInteger)timestampInMilliseconds
 | 
			
		||||
                             error:(NSError *)error {
 | 
			
		||||
  [self assertGestureRecognizerResult:gestureRecognizerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
 | 
			
		||||
                                               thumbUpGestureRecognizerResult]];
 | 
			
		||||
 | 
			
		||||
  if (gestureRecognizer == _outOfOrderTimestampTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
 | 
			
		||||
    [_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
 | 
			
		||||
  } else if (gestureRecognizer ==
 | 
			
		||||
             _liveStreamSucceedsTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
 | 
			
		||||
    [_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
@@ -0,0 +1,22 @@
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
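# Helper category used by tests to build expected MPPGestureRecognizerResult objects from pbtxt
# protos.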
objc_library(
 | 
			
		||||
    name = "MPPGestureRecognizerResultProtobufHelpers",
 | 
			
		||||
    srcs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.mm"],
 | 
			
		||||
    hdrs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.h"],
 | 
			
		||||
    copts = [
 | 
			
		||||
        "-ObjC++",
 | 
			
		||||
        "-std=c++17",
 | 
			
		||||
        "-x objective-c++",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/framework/formats:classification_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
 | 
			
		||||
        "//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizerResult",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
@@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import <Foundation/Foundation.h>
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.h"
 | 
			
		||||
 | 
			
		||||
NS_ASSUME_NONNULL_BEGIN
 | 
			
		||||
@interface MPPGestureRecognizerResult (ProtobufHelpers)
 | 
			
		||||
 | 
			
		||||
+ (MPPGestureRecognizerResult *)
 | 
			
		||||
    gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
                                        gestureLabel:(NSString *)gestureLabel
 | 
			
		||||
                               shouldRemoveZPosition:(BOOL)removeZPosition;
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
NS_ASSUME_NONNULL_END
 | 
			
		||||
@@ -0,0 +1,65 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+Helpers.h"
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/formats/classification.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
using ClassificationListProto = ::mediapipe::ClassificationList;
 | 
			
		||||
using ClassificationProto = ::mediapipe::Classification;
 | 
			
		||||
using LandmarksDetectionResultProto =
 | 
			
		||||
    ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 | 
			
		||||
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
 | 
			
		||||
}  // anonymous namespace
 | 
			
		||||
 | 
			
		||||
@implementation MPPGestureRecognizerResult (ProtobufHelpers)
 | 
			
		||||
 | 
			
		||||
+ (MPPGestureRecognizerResult *)
 | 
			
		||||
    gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
                                        gestureLabel:(NSString *)gestureLabel
 | 
			
		||||
                               shouldRemoveZPosition:(BOOL)removeZPosition {
 | 
			
		||||
  LandmarksDetectionResultProto landmarkDetectionResultProto;
 | 
			
		||||
 | 
			
		||||
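  // Parse the expected landmarks detection result from the pbtxt test file; return nil on failure.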
  if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
 | 
			
		||||
    return nil;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (removeZPosition) {
 | 
			
		||||
    // Remove the z position of landmarks, because they are not used in correctness testing. For
    // video or live stream mode, the z positions vary a lot during tracking from frame to frame.
 | 
			
		||||
    for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
 | 
			
		||||
      auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
 | 
			
		||||
      landmark.clear_z();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
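  // Wrap the expected gesture label in a ClassificationList proto for the result helper below.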
  ClassificationListProto gesturesProto;
 | 
			
		||||
  ClassificationProto *classificationProto = gesturesProto.add_classification();
 | 
			
		||||
  classificationProto->set_label([gestureLabel UTF8String]);
 | 
			
		||||
 | 
			
		||||
  return [MPPGestureRecognizerResult
 | 
			
		||||
      gestureRecognizerResultWithHandGesturesProto:{gesturesProto}
 | 
			
		||||
                                   handednessProto:{landmarkDetectionResultProto.classifications()}
 | 
			
		||||
                                handLandmarksProto:{landmarkDetectionResultProto.landmarks()}
 | 
			
		||||
                               worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
 | 
			
		||||
                           timestampInMilliseconds:0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
							
								
								
									
62	mediapipe/tasks/ios/test/vision/hand_landmarker/BUILD	Normal file
@@ -0,0 +1,62 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
 | 
			
		||||
load(
 | 
			
		||||
    "//mediapipe/framework/tool:ios.bzl",
 | 
			
		||||
    "MPP_TASK_MINIMUM_OS_VERSION",
 | 
			
		||||
)
 | 
			
		||||
load(
 | 
			
		||||
    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
 | 
			
		||||
    "tflite_ios_lab_runner",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
 | 
			
		||||
TFL_DEFAULT_TAGS = [
 | 
			
		||||
    "apple",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
# Following sanitizer tests are not supported by iOS test targets.
 | 
			
		||||
TFL_DISABLED_SANITIZER_TAGS = [
 | 
			
		||||
    "noasan",
 | 
			
		||||
    "nomsan",
 | 
			
		||||
    "notsan",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
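# Objective-C test library for MPPHandLandmarker, bundled with the hand_landmarker.task model and
# the image/proto test data it needs.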
objc_library(
 | 
			
		||||
    name = "MPPHandLandmarkerObjcTestLibrary",
 | 
			
		||||
    testonly = 1,
 | 
			
		||||
    srcs = ["MPPHandLandmarkerTests.m"],
 | 
			
		||||
    copts = [
 | 
			
		||||
        "-ObjC++",
 | 
			
		||||
        "-std=c++17",
 | 
			
		||||
        "-x objective-c++",
 | 
			
		||||
    ],
 | 
			
		||||
    data = [
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:hand_landmarker.task",
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:test_images",
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:test_protos",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/ios/common:MPPCommon",
 | 
			
		||||
        "//mediapipe/tasks/ios/test/vision/hand_landmarker/utils:MPPHandLandmarkerResultProtobufHelpers",
 | 
			
		||||
        "//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker",
 | 
			
		||||
    ] + select({
 | 
			
		||||
        "//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
 | 
			
		||||
        "//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
 | 
			
		||||
        "//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
 | 
			
		||||
        "//conditions:default": ["@ios_opencv//:OpencvFramework"],
 | 
			
		||||
    }),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
ios_unit_test(
 | 
			
		||||
    name = "MPPHandLandmarkerObjcTest",
 | 
			
		||||
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
 | 
			
		||||
    runner = tflite_ios_lab_runner("IOS_LATEST"),
 | 
			
		||||
    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":MPPHandLandmarkerObjcTestLibrary",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
@@ -0,0 +1,557 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import <XCTest/XCTest.h>
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarker.h"
 | 
			
		||||
 | 
			
		||||
static NSString *const kPbFileExtension = @"pbtxt";
 | 
			
		||||
 | 
			
		||||
typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;
 | 
			
		||||
 | 
			
		||||
static ResourceFileInfo *const kHandLandmarkerBundleAssetFile =
 | 
			
		||||
    @{@"name" : @"hand_landmarker", @"type" : @"task"};
 | 
			
		||||
 | 
			
		||||
static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
 | 
			
		||||
static ResourceFileInfo *const kPointingUpRotatedImage =
 | 
			
		||||
    @{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};
 | 
			
		||||
 | 
			
		||||
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
 | 
			
		||||
    @{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
 | 
			
		||||
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
 | 
			
		||||
    @{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};
 | 
			
		||||
 | 
			
		||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 | 
			
		||||
static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		||||
 | 
			
		||||
static NSString *const kLiveStreamTestsDictHandLandmarkerKey = @"hand_landmarker";
 | 
			
		||||
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
 | 
			
		||||
 | 
			
		||||
#define AssertEqualErrors(error, expectedError)              \
 | 
			
		||||
  XCTAssertNotNil(error);                                    \
 | 
			
		||||
  XCTAssertEqualObjects(error.domain, expectedError.domain); \
 | 
			
		||||
  XCTAssertEqual(error.code, expectedError.code);            \
 | 
			
		||||
  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)
 | 
			
		||||
 | 
			
		||||
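// Compares only the x and y coordinates; z is stripped from the expected results and is not
// checked.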
#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex)   \
 | 
			
		||||
  XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance,            \
 | 
			
		||||
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
 | 
			
		||||
  XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance,            \
 | 
			
		||||
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
 | 
			
		||||
 | 
			
		||||
#define AssertHandLandmarkerResultIsEmpty(handLandmarkerResult) \
 | 
			
		||||
  XCTAssertTrue(handLandmarkerResult.handedness.count == 0);    \
 | 
			
		||||
  XCTAssertTrue(handLandmarkerResult.landmarks.count == 0);     \
 | 
			
		||||
  XCTAssertTrue(handLandmarkerResult.worldLandmarks.count == 0);
 | 
			
		||||
 | 
			
		||||
@interface MPPHandLandmarkerTests : XCTestCase <MPPHandLandmarkerLiveStreamDelegate> {
 | 
			
		||||
  NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
 | 
			
		||||
  NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
 | 
			
		||||
}
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
@implementation MPPHandLandmarkerTests
 | 
			
		||||
 | 
			
		||||
#pragma mark Results
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)emptyHandLandmarkerResult {
 | 
			
		||||
  return [[MPPHandLandmarkerResult alloc] initWithLandmarks:@[]
 | 
			
		||||
                                             worldLandmarks:@[]
 | 
			
		||||
                                                 handedness:@[]
                                    timestampInMilliseconds:0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
 | 
			
		||||
  NSString *filePath = [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
 | 
			
		||||
 | 
			
		||||
  return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
 | 
			
		||||
                                                         shouldRemoveZPosition:YES];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)pointingUpRotatedHandLandmarkerResult {
 | 
			
		||||
  NSString *filePath =
 | 
			
		||||
      [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];
 | 
			
		||||
 | 
			
		||||
  return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
 | 
			
		||||
                                                         shouldRemoveZPosition:YES];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
 | 
			
		||||
    areApproximatelyEqualToExpectedMultiHandLandmarks:
 | 
			
		||||
        (NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
 | 
			
		||||
  XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
 | 
			
		||||
  if (multiHandLandmarks.count == 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
 | 
			
		||||
  NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
 | 
			
		||||
  for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
 | 
			
		||||
    MPPNormalizedLandmark *landmark = topHandLandmarks[i];
 | 
			
		||||
    XCTAssertNotNil(landmark);
 | 
			
		||||
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
 | 
			
		||||
    areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
 | 
			
		||||
        (NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
 | 
			
		||||
  XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
 | 
			
		||||
  if (expectedMultiHandWorldLandmarks.count == 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
 | 
			
		||||
  NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];
 | 
			
		||||
 | 
			
		||||
  XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
 | 
			
		||||
  for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
 | 
			
		||||
    MPPLandmark *landmark = topHandWorldLandmarks[i];
 | 
			
		||||
    XCTAssertNotNil(landmark);
 | 
			
		||||
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
 | 
			
		||||
    isApproximatelyEqualToExpectedResult:(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
 | 
			
		||||
  [self assertMultiHandLandmarks:handLandmarkerResult.landmarks
 | 
			
		||||
      areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
 | 
			
		||||
  [self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
 | 
			
		||||
      areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedHandLandmarkerResult
 | 
			
		||||
                                                                 .worldLandmarks];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark File
 | 
			
		||||
 | 
			
		||||
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
 | 
			
		||||
  NSString *filePath = [MPPHandLandmarkerTests filePathWithName:fileInfo[@"name"]
 | 
			
		||||
                                                      extension:fileInfo[@"type"]];
 | 
			
		||||
  return filePath;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
 | 
			
		||||
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
 | 
			
		||||
                                                                      ofType:extension];
 | 
			
		||||
  return filePath;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Hand Landmarker Initializers
 | 
			
		||||
 | 
			
		||||
- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
 | 
			
		||||
    (ResourceFileInfo *)modelFileInfo {
 | 
			
		||||
  NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
 | 
			
		||||
  MPPHandLandmarkerOptions *handLandmarkerOptions = [[MPPHandLandmarkerOptions alloc] init];
 | 
			
		||||
  handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;
 | 
			
		||||
 | 
			
		||||
  return handLandmarkerOptions;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
 | 
			
		||||
    (MPPHandLandmarkerOptions *)handLandmarkerOptions {
 | 
			
		||||
  NSError *error;
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
 | 
			
		||||
  XCTAssertNotNil(handLandmarker);
 | 
			
		||||
  XCTAssertNil(error);
 | 
			
		||||
 | 
			
		||||
  return handLandmarker;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertCreateHandLandmarkerWithOptions:(MPPHandLandmarkerOptions *)handLandmarkerOptions
 | 
			
		||||
                       failsWithExpectedError:(NSError *)expectedError {
 | 
			
		||||
  NSError *error = nil;
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
 | 
			
		||||
 | 
			
		||||
  XCTAssertNil(handLandmarker);
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Assert Hand Landmarker Results
 | 
			
		||||
 | 
			
		||||
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
 | 
			
		||||
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
 | 
			
		||||
                                              fileName:fileInfo[@"name"]
 | 
			
		||||
                                                ofType:fileInfo[@"type"]];
 | 
			
		||||
  XCTAssertNotNil(image);
 | 
			
		||||
 | 
			
		||||
  return image;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
 | 
			
		||||
                    orientation:(UIImageOrientation)orientation {
 | 
			
		||||
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
 | 
			
		||||
                                              fileName:fileInfo[@"name"]
 | 
			
		||||
                                                ofType:fileInfo[@"type"]
 | 
			
		||||
                                           orientation:orientation];
 | 
			
		||||
  XCTAssertNotNil(image);
 | 
			
		||||
 | 
			
		||||
  return image;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
 | 
			
		||||
                                   usingHandLandmarker:(MPPHandLandmarker *)handLandmarker {
 | 
			
		||||
  MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
 | 
			
		||||
  XCTAssertNotNil(handLandmarkerResult);
 | 
			
		||||
 | 
			
		||||
  return handLandmarkerResult;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertResultsOfDetectInImageWithFileInfo:(ResourceFileInfo *)fileInfo
 | 
			
		||||
                             usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
 | 
			
		||||
         approximatelyEqualsHandLandmarkerResult:
 | 
			
		||||
             (MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:fileInfo
 | 
			
		||||
                                                              usingHandLandmarker:handLandmarker];
 | 
			
		||||
  [self assertHandLandmarkerResult:handLandmarkerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark General Tests
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithModelPathSucceeds {
 | 
			
		||||
  NSString *modelPath =
 | 
			
		||||
      [MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [[MPPHandLandmarker alloc] initWithModelPath:modelPath
 | 
			
		||||
                                                                             error:nil];
 | 
			
		||||
  XCTAssertNotNil(handLandmarker);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
 | 
			
		||||
                             usingHandLandmarker:handLandmarker
 | 
			
		||||
         approximatelyEqualsHandLandmarkerResult:[MPPHandLandmarkerTests
 | 
			
		||||
                                                     thumbUpHandLandmarkerResult]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithEmptyResultsSucceeds {
 | 
			
		||||
  MPPHandLandmarkerOptions *handLandmarkerOptions =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kNoHandsImage
 | 
			
		||||
                                                              usingHandLandmarker:handLandmarker];
 | 
			
		||||
  AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithNumHandsSucceeds {
 | 
			
		||||
  MPPHandLandmarkerOptions *handLandmarkerOptions =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  const NSInteger numHands = 2;
 | 
			
		||||
  handLandmarkerOptions.numHands = numHands;
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kTwoHandsImage
 | 
			
		||||
                                                              usingHandLandmarker:handLandmarker];
 | 
			
		||||
 | 
			
		||||
  XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithRotationSucceeds {
 | 
			
		||||
  MPPHandLandmarkerOptions *handLandmarkerOptions =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
 | 
			
		||||
                                   orientation:UIImageOrientationRight];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
 | 
			
		||||
 | 
			
		||||
  [self assertHandLandmarkerResult:handLandmarkerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests
 | 
			
		||||
                                               pointingUpRotatedHandLandmarkerResult]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Running Mode Tests
 | 
			
		||||
 | 
			
		||||
- (void)testCreateHandLandmarkerFailsWithDelegateInNonLiveStreamMode {
 | 
			
		||||
  MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
 | 
			
		||||
  for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
 | 
			
		||||
    MPPHandLandmarkerOptions *options =
 | 
			
		||||
        [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
    options.runningMode = runningModesToTest[i];
 | 
			
		||||
    options.handLandmarkerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
    [self
 | 
			
		||||
        assertCreateHandLandmarkerWithOptions:options
 | 
			
		||||
                       failsWithExpectedError:
 | 
			
		||||
                           [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                                               code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                                           userInfo:@{
 | 
			
		||||
                                             NSLocalizedDescriptionKey :
 | 
			
		||||
                                                 @"The vision task is in image or video mode. The "
 | 
			
		||||
                                                 @"delegate must not be set in the task's options."
 | 
			
		||||
                                           }]];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testCreateHandLandmarkerFailsWithMissingDelegateInLiveStreamMode {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
 | 
			
		||||
  [self assertCreateHandLandmarkerWithOptions:options
 | 
			
		||||
                       failsWithExpectedError:
 | 
			
		||||
                           [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                                               code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                                           userInfo:@{
 | 
			
		||||
                                             NSLocalizedDescriptionKey :
 | 
			
		||||
                                                 @"The vision task is in live stream mode. An "
 | 
			
		||||
                                                 @"object must be set as the delegate of the task "
 | 
			
		||||
                                                 @"in its options to ensure asynchronous delivery "
 | 
			
		||||
                                                 @"of results."
 | 
			
		||||
                                           }]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectFailsWithCallingWrongApiInImageMode {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  NSError *liveStreamApiCallError;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectAsyncInImage:image
 | 
			
		||||
                            timestampInMilliseconds:0
 | 
			
		||||
                                              error:&liveStreamApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedLiveStreamApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
 | 
			
		||||
                                                    @"stream mode. Current Running Mode: Image"
 | 
			
		||||
                      }];
 | 
			
		||||
 | 
			
		||||
  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
 | 
			
		||||
 | 
			
		||||
  NSError *videoApiCallError;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectInVideoFrame:image
 | 
			
		||||
                            timestampInMilliseconds:0
 | 
			
		||||
                                              error:&videoApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedVideoApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"video mode. Current Running Mode: Image"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectFailsWithCallingWrongApiInVideoMode {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeVideo;
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  NSError *liveStreamApiCallError;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectAsyncInImage:image
 | 
			
		||||
                            timestampInMilliseconds:0
 | 
			
		||||
                                              error:&liveStreamApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedLiveStreamApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
 | 
			
		||||
                                                    @"stream mode. Current Running Mode: Video"
 | 
			
		||||
                      }];
 | 
			
		||||
 | 
			
		||||
  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
 | 
			
		||||
 | 
			
		||||
  NSError *imageApiCallError;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedImageApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"image mode. Current Running Mode: Video"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectFailsWithCallingWrongApiInLiveStreamMode {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
  options.handLandmarkerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  NSError *imageApiCallError;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedImageApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"image mode. Current Running Mode: Live Stream"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
 | 
			
		||||
 | 
			
		||||
  NSError *videoApiCallError;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectInVideoFrame:image
 | 
			
		||||
                            timestampInMilliseconds:0
 | 
			
		||||
                                              error:&videoApiCallError]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedVideoApiCallError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
 | 
			
		||||
                                                    @"video mode. Current Running Mode: Live Stream"
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithVideoModeSucceeds {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeVideo;
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  for (int i = 0; i < 3; i++) {
 | 
			
		||||
    MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInVideoFrame:image
 | 
			
		||||
                                                               timestampInMilliseconds:i
 | 
			
		||||
                                                                                 error:nil];
 | 
			
		||||
    [self assertHandLandmarkerResult:handLandmarkerResult
 | 
			
		||||
        isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithOutOfOrderTimestampsAndLiveStreamModeFails {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
  options.handLandmarkerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
  XCTestExpectation *expectation = [[XCTestExpectation alloc]
 | 
			
		||||
      initWithDescription:@"detectWiththOutOfOrderTimestampsAndLiveStream"];
 | 
			
		||||
 | 
			
		||||
  expectation.expectedFulfillmentCount = 1;
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  _outOfOrderTimestampTestDict = @{
 | 
			
		||||
    kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
 | 
			
		||||
    kLiveStreamTestsDictExpectationKey : expectation
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:1 error:nil]);
 | 
			
		||||
 | 
			
		||||
  NSError *error;
 | 
			
		||||
  XCTAssertFalse([handLandmarker detectAsyncInImage:image timestampInMilliseconds:0 error:&error]);
 | 
			
		||||
 | 
			
		||||
  NSError *expectedError =
 | 
			
		||||
      [NSError errorWithDomain:kExpectedErrorDomain
 | 
			
		||||
                          code:MPPTasksErrorCodeInvalidArgumentError
 | 
			
		||||
                      userInfo:@{
 | 
			
		||||
                        NSLocalizedDescriptionKey :
 | 
			
		||||
                            @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
 | 
			
		||||
                      }];
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
 | 
			
		||||
  NSTimeInterval timeout = 0.5f;
 | 
			
		||||
  [self waitForExpectations:@[ expectation ] timeout:timeout];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)testDetectWithLiveStreamModeSucceeds {
 | 
			
		||||
  MPPHandLandmarkerOptions *options =
 | 
			
		||||
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  options.runningMode = MPPRunningModeLiveStream;
 | 
			
		||||
  options.handLandmarkerLiveStreamDelegate = self;
 | 
			
		||||
 | 
			
		||||
  NSInteger iterationCount = 100;
 | 
			
		||||
 | 
			
		||||
  // Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
  // times. A normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since in our case we cannot predict how many times the expectation is supposed to be fulfilled,
  // setting `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that the test succeeds if the expectation is
  // fulfilled <= `iterationCount` times.
 | 
			
		||||
  XCTestExpectation *expectation =
 | 
			
		||||
      [[XCTestExpectation alloc] initWithDescription:@"detectWithLiveStream"];
 | 
			
		||||
 | 
			
		||||
  expectation.expectedFulfillmentCount = iterationCount + 1;
 | 
			
		||||
  expectation.inverted = YES;
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];
 | 
			
		||||
 | 
			
		||||
  _liveStreamSucceedsTestDict = @{
 | 
			
		||||
    kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
 | 
			
		||||
    kLiveStreamTestsDictExpectationKey : expectation
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
 | 
			
		||||
  // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
 | 
			
		||||
  // `CMSampleBuffer`.
 | 
			
		||||
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];
 | 
			
		||||
 | 
			
		||||
  for (int i = 0; i < iterationCount; i++) {
 | 
			
		||||
    XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:i error:nil]);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  NSTimeInterval timeout = 0.5f;
 | 
			
		||||
  [self waitForExpectations:@[ expectation ] timeout:timeout];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)handLandmarker:(MPPHandLandmarker *)handLandmarker
 | 
			
		||||
    didFinishDetectionWithResult:(MPPHandLandmarkerResult *)handLandmarkerResult
 | 
			
		||||
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
 | 
			
		||||
                           error:(NSError *)error {
 | 
			
		||||
  [self assertHandLandmarkerResult:handLandmarkerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
 | 
			
		||||
 | 
			
		||||
  if (handLandmarker == _outOfOrderTimestampTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
 | 
			
		||||
    [_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
 | 
			
		||||
  } else if (handLandmarker == _liveStreamSucceedsTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
 | 
			
		||||
    [_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
							
								
								
									
22 mediapipe/tasks/ios/test/vision/hand_landmarker/utils/BUILD Normal file
@@ -0,0 +1,22 @@
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
objc_library(
 | 
			
		||||
    name = "MPPHandLandmarkerResultProtobufHelpers",
 | 
			
		||||
    srcs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.mm"],
 | 
			
		||||
    hdrs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.h"],
 | 
			
		||||
    copts = [
 | 
			
		||||
        "-ObjC++",
 | 
			
		||||
        "-std=c++17",
 | 
			
		||||
        "-x objective-c++",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/framework/formats:classification_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
 | 
			
		||||
        "//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarkerResult",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
@@ -0,0 +1,26 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import <Foundation/Foundation.h>
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarkerResult.h"
 | 
			
		||||
 | 
			
		||||
NS_ASSUME_NONNULL_BEGIN
 | 
			
		||||
@interface MPPHandLandmarkerResult (ProtobufHelpers)
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
                                                    shouldRemoveZPosition:(BOOL)removeZPosition;
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
NS_ASSUME_NONNULL_END
 | 
			
		||||
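Not part of the diff: a minimal usage sketch for the category declared above, loading an expected result from a protobuf text file inside a test. The file name below is hypothetical and would have to exist in the test bundle.

  // Hypothetical test snippet; "thumb_up_landmarks.pbtxt" is an assumed asset name.
  MPPHandLandmarkerResult *expectedResult = [MPPHandLandmarkerResult
      handLandmarkerResultFromProtobufFileWithName:@"thumb_up_landmarks.pbtxt"
                             shouldRemoveZPosition:YES];
  XCTAssertNotNil(expectedResult);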
@@ -0,0 +1,58 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+Helpers.h"
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/formats/classification.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
using ClassificationListProto = ::mediapipe::ClassificationList;
 | 
			
		||||
using ClassificationProto = ::mediapipe::Classification;
 | 
			
		||||
using LandmarksDetectionResultProto =
 | 
			
		||||
    ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 | 
			
		||||
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
 | 
			
		||||
}  // anonymous namespace
 | 
			
		||||
 | 
			
		||||
@implementation MPPHandLandmarkerResult (ProtobufHelpers)
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
                                                    shouldRemoveZPosition:(BOOL)removeZPosition {
 | 
			
		||||
  LandmarksDetectionResultProto landmarkDetectionResultProto;
 | 
			
		||||
 | 
			
		||||
  if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
 | 
			
		||||
    return nil;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  if (removeZPosition) {
 | 
			
		||||
    // Remove z position of landmarks, because they are not used in correctness testing. For video
 | 
			
		||||
    // or live stream mode, the z positions vary a lot during tracking from frame to frame.
 | 
			
		||||
    for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
 | 
			
		||||
      auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
 | 
			
		||||
      landmark.clear_z();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return [MPPHandLandmarkerResult
 | 
			
		||||
      handLandmarkerResultWithLandmarksProto:{landmarkDetectionResultProto.landmarks()}
 | 
			
		||||
                         worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
 | 
			
		||||
                             handednessProto:{landmarkDetectionResultProto.classifications()}
 | 
			
		||||
                     timestampInMilliSeconds:0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
@@ -673,10 +673,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
  // If `expectation.isInverted = true`, the test will only succeed if
 | 
			
		||||
  // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
 | 
			
		||||
  // Since in our case we cannot predict how many times the expectation is
 | 
			
		||||
  // supposed to be fullfilled setting,
 | 
			
		||||
  // supposed to be fulfilled setting,
 | 
			
		||||
  // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
 | 
			
		||||
  // `expectation.isInverted = true` ensures that test succeeds if
 | 
			
		||||
  // expectation is fullfilled <= `iterationCount` times.
 | 
			
		||||
  // expectation is fulfilled <= `iterationCount` times.
 | 
			
		||||
  XCTestExpectation *expectation =
 | 
			
		||||
      [[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
 | 
			
		||||
 | 
			
		||||
@@ -64,3 +64,17 @@ objc_library(
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
objc_library(
 | 
			
		||||
    name = "MPPMask",
 | 
			
		||||
    srcs = ["sources/MPPMask.mm"],
 | 
			
		||||
    hdrs = ["sources/MPPMask.h"],
 | 
			
		||||
    copts = [
 | 
			
		||||
        "-ObjC++",
 | 
			
		||||
        "-std=c++17",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/ios/common:MPPCommon",
 | 
			
		||||
        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
118 mediapipe/tasks/ios/vision/core/sources/MPPMask.h Normal file
@@ -0,0 +1,118 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import <Foundation/Foundation.h>
 | 
			
		||||
 | 
			
		||||
NS_ASSUME_NONNULL_BEGIN
 | 
			
		||||
 | 
			
		||||
/** The underlying type of the segmentation mask. */
 | 
			
		||||
typedef NS_ENUM(NSUInteger, MPPMaskDataType) {
 | 
			
		||||
 | 
			
		||||
  /** Represents the native `UInt8 *` type. */
 | 
			
		||||
  MPPMaskDataTypeUInt8,
 | 
			
		||||
 | 
			
		||||
  /** Represents the native `float *` type. */
 | 
			
		||||
  MPPMaskDataTypeFloat32,
 | 
			
		||||
 | 
			
		||||
} NS_SWIFT_NAME(MaskDataType);
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The wrapper class for MediaPipe segmentation masks.
 | 
			
		||||
 *
 | 
			
		||||
 * Masks are stored as `UInt8 *` or `float *` objects.
 | 
			
		||||
 * Every mask has an underlying type which can be accessed using `dataType`. You can access the
 | 
			
		||||
 * mask as any other type using the appropriate properties. For example, if the underlying type is
 | 
			
		||||
 * `MPPMaskDataTypeUInt8`, in addition to accessing the mask using `uint8Data`, you can access
 | 
			
		||||
 * `float32Data` to get the 32 bit float data (with values ranging from 0.0 to 1.0). The first
 | 
			
		||||
 * time you access the data as a type different from the underlying type, an expensive type
 | 
			
		||||
 * conversion is performed. Subsequent accesses return a pointer to the memory location of the same
 | 
			
		||||
 * type converted array. As type conversions can be expensive, it is recommended to limit the
 | 
			
		||||
 * accesses to data of types different from the underlying type.
 | 
			
		||||
 *
 | 
			
		||||
 * Masks that are returned from MediaPipe Tasks are owned by the underlying C++ Task. If you
 | 
			
		||||
 * need to extend the lifetime of these objects, you can invoke the `[MPPMask copy:]` method.
 | 
			
		||||
 */
 | 
			
		||||
NS_SWIFT_NAME(Mask)
 | 
			
		||||
@interface MPPMask : NSObject <NSCopying>
 | 
			
		||||
 | 
			
		||||
/** The width of the mask. */
 | 
			
		||||
@property(nonatomic, readonly) NSInteger width;
 | 
			
		||||
 | 
			
		||||
/** The height of the mask. */
 | 
			
		||||
@property(nonatomic, readonly) NSInteger height;
 | 
			
		||||
 | 
			
		||||
/** The data type of the mask. */
 | 
			
		||||
@property(nonatomic, readonly) MPPMaskDataType dataType;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The pointer to the memory location where the underlying mask as a single channel `UInt8` array is
 | 
			
		||||
 * stored. `UInt8` values use the full value range, from 0 to 255.
 | 
			
		||||
 */
 | 
			
		||||
@property(nonatomic, readonly, assign) const UInt8 *uint8Data;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The pointer to the memory location where the underlying mask as a single channel float 32 array
 | 
			
		||||
 * is stored. Float values range from 0.0 to 1.0.
 | 
			
		||||
 */
 | 
			
		||||
@property(nonatomic, readonly, assign) const float *float32Data;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Initializes an `MPPMask` object of type `MPPMaskDataTypeUInt8` with the given `UInt8*` data,
 | 
			
		||||
 * width and height.
 | 
			
		||||
 *
 | 
			
		||||
 * If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
 | 
			
		||||
 * `uint8Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
 | 
			
		||||
 * the `MPPMask` must outlive the passed in `uint8Data`.
 | 
			
		||||
 *
 | 
			
		||||
 * @param uint8Data A pointer to the memory location of the `UInt8` data array.
 | 
			
		||||
 * @param width The width of the mask.
 | 
			
		||||
 * @param height The height of the mask.
 | 
			
		||||
 * @param shouldCopy A flag to indicate whether the newly created `MPPMask` should deep copy `uint8Data`.
 | 
			
		||||
 *
 | 
			
		||||
 * @return A new `MPPMask` instance with the given `UInt8*` data, width and height.
 | 
			
		||||
 */
 | 
			
		||||
- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
 | 
			
		||||
                                     width:(NSInteger)width
 | 
			
		||||
                                    height:(NSInteger)height
 | 
			
		||||
                                shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Initializes an `MPPMask` object of type `MPPMaskDataTypeFloat32` with the given `float*` data,
 | 
			
		||||
 * width and height.
 | 
			
		||||
 *
 | 
			
		||||
 * If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
 | 
			
		||||
 * `float32Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
 | 
			
		||||
 * the `MPPMask` must outlive the passed in `float32Data`.
 | 
			
		||||
 *
 | 
			
		||||
 * @param float32Data A pointer to the memory location of the `float` data array.
 | 
			
		||||
 * @param width The width of the mask.
 | 
			
		||||
 * @param height The height of the mask.
 * @param shouldCopy A flag to indicate whether the newly created `MPPMask` should deep copy `float32Data`.
 | 
			
		||||
 *
 | 
			
		||||
 * @return A new `MPPMask` instance with the given `float*` data, width and height.
 | 
			
		||||
 */
 | 
			
		||||
- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
 | 
			
		||||
                                       width:(NSInteger)width
 | 
			
		||||
                                      height:(NSInteger)height
 | 
			
		||||
                                  shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
 | 
			
		||||
 | 
			
		||||
// TODO: Add methods for CVPixelBuffer conversion.
 | 
			
		||||
 | 
			
		||||
/** Unavailable. */
 | 
			
		||||
- (instancetype)init NS_UNAVAILABLE;
 | 
			
		||||
 | 
			
		||||
+ (instancetype)new NS_UNAVAILABLE;
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
NS_ASSUME_NONNULL_END
 | 
			
		||||
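Not part of the diff: a brief sketch of how the interface above could be used, illustrating the lazy UInt8-to-float32 conversion described in the class documentation.

// Sketch only: build a 2x2 UInt8 mask and read it back as float32.
UInt8 pixels[] = {0, 64, 128, 255};
MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:pixels width:2 height:2 shouldCopy:YES];
// The first access performs the one-time conversion; values fall in [0.0, 1.0].
const float *floatPixels = mask.float32Data;
// Copying produces a mask that owns its own buffer, extending its lifetime beyond `pixels`.
MPPMask *ownedMask = [mask copy];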
							
								
								
									
135 mediapipe/tasks/ios/vision/core/sources/MPPMask.mm Normal file
@@ -0,0 +1,135 @@
// Copyright 2023 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
 | 
			
		||||
 | 
			
		||||
@interface MPPMask () {
 | 
			
		||||
  const UInt8 *_uint8Data;
 | 
			
		||||
  const float *_float32Data;
 | 
			
		||||
  std::unique_ptr<UInt8[]> _uint8DataPtr;
 | 
			
		||||
  std::unique_ptr<float[]> _float32DataPtr;
 | 
			
		||||
}
 | 
			
		||||
@end
 | 
			
		||||
 | 
			
		||||
@implementation MPPMask
 | 
			
		||||
 | 
			
		||||
- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
 | 
			
		||||
                                     width:(NSInteger)width
 | 
			
		||||
                                    height:(NSInteger)height
 | 
			
		||||
                                shouldCopy:(BOOL)shouldCopy {
 | 
			
		||||
 | 
			
		||||
  self = [super init];
 | 
			
		||||
  if (self) {
 | 
			
		||||
    _width = width;
 | 
			
		||||
    _height = height;
 | 
			
		||||
    _dataType = MPPMaskDataTypeUInt8;
 | 
			
		||||
 | 
			
		||||
    if (shouldCopy) {
 | 
			
		||||
      size_t length = _width * _height;
 | 
			
		||||
      _uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
 | 
			
		||||
      _uint8Data = _uint8DataPtr.get();
 | 
			
		||||
      memcpy((UInt8 *)_uint8Data, uint8Data, length * sizeof(UInt8));
 | 
			
		||||
    } else {
 | 
			
		||||
      _uint8Data = uint8Data;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return self;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
 | 
			
		||||
                                       width:(NSInteger)width
 | 
			
		||||
                                      height:(NSInteger)height
 | 
			
		||||
                                  shouldCopy:(BOOL)shouldCopy {
 | 
			
		||||
  self = [super init];
 | 
			
		||||
  if (self) {
 | 
			
		||||
    _width = width;
 | 
			
		||||
    _height = height;
 | 
			
		||||
    _dataType = MPPMaskDataTypeFloat32;
 | 
			
		||||
 | 
			
		||||
    if (shouldCopy) {
 | 
			
		||||
      size_t length = _width * _height;
 | 
			
		||||
      _float32DataPtr = std::unique_ptr<float[]>(new float[length]);
 | 
			
		||||
      _float32Data = _float32DataPtr.get();
 | 
			
		||||
      memcpy((float *)_float32Data, float32Data, length * sizeof(float));
 | 
			
		||||
    } else {
 | 
			
		||||
      _float32Data = float32Data;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return self;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
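// Returns the native buffer for `MPPMaskDataTypeUInt8` masks; for float32 masks, converts to UInt8
// on first access and caches the converted buffer in `_uint8DataPtr`.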
- (const UInt8 *)uint8Data {
 | 
			
		||||
  switch (_dataType) {
 | 
			
		||||
    case MPPMaskDataTypeUInt8: {
 | 
			
		||||
      return _uint8Data;
 | 
			
		||||
    }
 | 
			
		||||
    case MPPMaskDataTypeFloat32: {
 | 
			
		||||
      if (_uint8DataPtr) {
 | 
			
		||||
        return _uint8DataPtr.get();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      size_t length = _width * _height;
 | 
			
		||||
      _uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
 | 
			
		||||
      UInt8 *data = _uint8DataPtr.get();
 | 
			
		||||
      for (int i = 0; i < length; i++) {
 | 
			
		||||
        data[i] = _float32Data[i] * 255;
 | 
			
		||||
      }
 | 
			
		||||
      return data;
 | 
			
		||||
    }
 | 
			
		||||
    default:
 | 
			
		||||
      return NULL;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
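// Returns the native buffer for `MPPMaskDataTypeFloat32` masks; for UInt8 masks, converts to
// floats in [0.0, 1.0] on first access and caches the converted buffer in `_float32DataPtr`.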
- (const float *)float32Data {
 | 
			
		||||
  switch (_dataType) {
 | 
			
		||||
    case MPPMaskDataTypeUInt8: {
 | 
			
		||||
      if (_float32DataPtr) {
 | 
			
		||||
        return _float32DataPtr.get();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      size_t length = _width * _height;
 | 
			
		||||
      _float32DataPtr = std::unique_ptr<float[]>(new float[length]);
 | 
			
		||||
      float *data = _float32DataPtr.get();
 | 
			
		||||
      for (int i = 0; i < length; i++) {
 | 
			
		||||
        data[i] = (float)_uint8Data[i] / 255;
 | 
			
		||||
      }
 | 
			
		||||
      return data;
 | 
			
		||||
    }
 | 
			
		||||
    case MPPMaskDataTypeFloat32: {
 | 
			
		||||
      return _float32Data;
 | 
			
		||||
    }
 | 
			
		||||
    default:
 | 
			
		||||
      return NULL;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
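// NSCopying: always performs a deep copy (shouldCopy:YES) so the copied mask owns its pixel buffer.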
- (id)copyWithZone:(NSZone *)zone {
 | 
			
		||||
  switch (_dataType) {
 | 
			
		||||
    case MPPMaskDataTypeUInt8:
 | 
			
		||||
      return [[MPPMask alloc] initWithUInt8Data:self.uint8Data
 | 
			
		||||
                                          width:self.width
 | 
			
		||||
                                         height:self.height
 | 
			
		||||
                                     shouldCopy:YES];
 | 
			
		||||
    case MPPMaskDataTypeFloat32:
 | 
			
		||||
      return [[MPPMask alloc] initWithFloat32Data:self.float32Data
 | 
			
		||||
                                            width:self.width
 | 
			
		||||
                                           height:self.height
 | 
			
		||||
                                       shouldCopy:YES];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
@@ -165,7 +165,7 @@ static NSString *const kTaskPrefix = @"com.mediapipe.tasks.vision";
  // For 90° and 270° rotations, we need to swap width and height.
 | 
			
		||||
  // This is due to the internal behavior of ImageToTensorCalculator, which:
 | 
			
		||||
  // - first denormalizes the provided rect by multiplying the rect width or height by the image
 | 
			
		||||
  //   width or height, repectively.
 | 
			
		||||
  //   width or height, respectively.
 | 
			
		||||
  // - then rotates this by denormalized rect by the provided rotation, and uses this for cropping,
 | 
			
		||||
  // - then finally rotates this back.
 | 
			
		||||
  if (rotationDegrees % 180 == 0) {
 | 
			
		||||
@@ -28,6 +28,7 @@
- (id)copyWithZone:(NSZone *)zone {
 | 
			
		||||
  MPPFaceDetectorOptions *faceDetectorOptions = [super copyWithZone:zone];
 | 
			
		||||
 | 
			
		||||
  faceDetectorOptions.runningMode = self.runningMode;
 | 
			
		||||
  faceDetectorOptions.minDetectionConfidence = self.minDetectionConfidence;
 | 
			
		||||
  faceDetectorOptions.minSuppressionThreshold = self.minSuppressionThreshold;
 | 
			
		||||
  faceDetectorOptions.faceDetectorLiveStreamDelegate = self.faceDetectorLiveStreamDelegate;
 | 
			
		||||
@@ -43,9 +43,16 @@ objc_library(
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
objc_library(
 | 
			
		||||
    name = "MPPFaceLandmarksConnections",
 | 
			
		||||
    hdrs = ["sources/MPPFaceLandmarksConnections.h"],
 | 
			
		||||
    module_name = "MPPFaceLandmarksConnections",
 | 
			
		||||
    deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
objc_library(
 | 
			
		||||
    name = "MPPFaceLandmarker",
 | 
			
		||||
    srcs = ["sources/MPPFaceLandmarker.m"],
 | 
			
		||||
    srcs = ["sources/MPPFaceLandmarker.mm"],
 | 
			
		||||
    hdrs = ["sources/MPPFaceLandmarker.h"],
 | 
			
		||||
    copts = [
 | 
			
		||||
        "-ObjC++",
 | 
			
		||||
@@ -55,9 +62,11 @@ objc_library(
    deps = [
 | 
			
		||||
        ":MPPFaceLandmarkerOptions",
 | 
			
		||||
        ":MPPFaceLandmarkerResult",
 | 
			
		||||
        ":MPPFaceLandmarksConnections",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
 | 
			
		||||
        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
 | 
			
		||||
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
 | 
			
		||||
        "//mediapipe/tasks/ios/components/containers:MPPConnection",
 | 
			
		||||
        "//mediapipe/tasks/ios/core:MPPTaskInfo",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/core:MPPImage",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
 | 
			
		||||
@@ -14,6 +14,7 @@
 | 
			
		||||
#import <Foundation/Foundation.h>
 | 
			
		||||
 | 
			
		||||
#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h"
 | 
			
		||||
#import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerResult.h"
 | 
			
		||||
@@ -147,6 +148,83 @@ NS_SWIFT_NAME(FaceLandmarker)
                      error:(NSError **)error
 | 
			
		||||
    NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the lips.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the lips.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)lipsConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the left eye.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the left eye.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)leftEyeConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the left eyebrow.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the left eyebrow.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)leftEyebrowConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the left iris.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the left iris.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)leftIrisConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the right eye.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the right eye.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)rightEyeConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the right eyebrow.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the right eyebrow.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)rightEyebrowConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the right iris.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the right iris.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)rightIrisConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks of the face oval.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks of the face oval.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)faceOvalConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks making up the contours of the face.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the contours of the face.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)contoursConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks making up the tesselation of the face.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks making up the tesselation of the face.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)tesselationConnections;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns the connections between all the landmarks in the face.
 | 
			
		||||
 *
 | 
			
		||||
 * @return An array of connections between all the landmarks in the face.
 | 
			
		||||
 */
 | 
			
		||||
+ (NSArray<MPPConnection *> *)faceConnections;
 | 
			
		||||
 | 
			
		||||
- (instancetype)init NS_UNAVAILABLE;
 | 
			
		||||
 | 
			
		||||
+ (instancetype)new NS_UNAVAILABLE;
 | 
			
		||||
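Not part of the diff: a hedged sketch of how the connection lists above might be consumed. It assumes `MPPConnection` exposes `start` and `end` landmark indices; `faceLandmarkerResult` and `drawLineFrom:to:` are hypothetical.

// Sketch only: iterate the lip connections of the first detected face.
NSArray<MPPNormalizedLandmark *> *landmarks = faceLandmarkerResult.faceLandmarks[0];
for (MPPConnection *connection in [MPPFaceLandmarker lipsConnections]) {
  MPPNormalizedLandmark *start = landmarks[connection.start];
  MPPNormalizedLandmark *end = landmarks[connection.end];
  [self drawLineFrom:start to:end];  // Hypothetical rendering helper.
}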
Some files were not shown because too many files have changed in this diff.