Merge remote-tracking branch 'origin/master' into nguyencse/facemeshioslib

This commit is contained in:
commit 5093b7a2a9

Changed files include: WORKSPACE (16 changed lines)
@@ -45,12 +45,13 @@ http_archive(
 )
 
 http_archive(
     name = "rules_foreign_cc",
-    strip_prefix = "rules_foreign_cc-0.1.0",
-    url = "https://github.com/bazelbuild/rules_foreign_cc/archive/0.1.0.zip",
+    sha256 = "2a4d07cd64b0719b39a7c12218a3e507672b82a97b98c6a89d38565894cf7c51",
+    strip_prefix = "rules_foreign_cc-0.9.0",
+    url = "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.9.0.tar.gz",
 )
 
-load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies")
+load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies")
 
 rules_foreign_cc_dependencies()
 
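Note that the load() path changes together with the version bump: in rules_foreign_cc 0.9.0 the rules_foreign_cc_dependencies macro is loaded from @rules_foreign_cc//foreign_cc:repositories.bzl rather than the old //:workspace_definitions.bzl, as the hunk above shows, so both lines have to move in lockstep with the archive pin.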
@@ -492,9 +493,10 @@ http_archive(
 )
 
 # TensorFlow repo should always go after the other external dependencies.
-# TF on 2023-05-26.
-_TENSORFLOW_GIT_COMMIT = "67d5c561981edc45daf3f9d73ddd1a77963733ca"
-_TENSORFLOW_SHA256 = "0c8326285e9cb695313e194b97d388eea70bf8bf5b13e8f0962ca8eed5179ece"
+# TF on 2023-06-13.
+_TENSORFLOW_GIT_COMMIT = "491681a5620e41bf079a582ac39c585cc86878b9"
+# curl -L https://github.com/tensorflow/tensorflow/archive/<TENSORFLOW_GIT_COMMIT>.tar.gz | shasum -a 256
+_TENSORFLOW_SHA256 = "9f76389af7a2835e68413322c1eaabfadc912f02a76d71dc16be507f9ca3d3ac"
 http_archive(
     name = "org_tensorflow",
     urls = [
@@ -219,12 +219,10 @@ cc_library(
     deps = [
         ":time_series_framer_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:timestamp",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/framework/formats:time_series_header_cc_proto",
-        "//mediapipe/framework/port:integral_types",
-        "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
-        "//mediapipe/framework/port:status",
         "//mediapipe/util:time_series_util",
         "@com_google_audio_tools//audio/dsp:window_functions",
         "@eigen_archive//:eigen3",
@@ -319,6 +317,20 @@ cc_test(
     ],
 )
 
+cc_binary(
+    name = "time_series_framer_calculator_benchmark",
+    srcs = ["time_series_framer_calculator_benchmark.cc"],
+    deps = [
+        ":time_series_framer_calculator",
+        ":time_series_framer_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/formats:matrix",
+        "//mediapipe/framework/formats:time_series_header_cc_proto",
+        "@com_google_benchmark//:benchmark",
+    ],
+)
+
 cc_test(
     name = "time_series_framer_calculator_test",
     srcs = ["time_series_framer_calculator_test.cc"],
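The new cc_binary registers a Google Benchmark target for the framer; its source appears later in this diff. Assuming a standard MediaPipe checkout, it can presumably be run with something along the lines of bazel run -c opt //mediapipe/calculators/audio:time_series_framer_calculator_benchmark.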
@@ -15,9 +15,7 @@
 // Defines TimeSeriesFramerCalculator.
 #include <math.h>
 
-#include <deque>
-#include <memory>
-#include <string>
+#include <vector>
 
 #include "Eigen/Core"
 #include "audio/dsp/window_functions.h"

@@ -25,9 +23,8 @@
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/time_series_header.pb.h"
-#include "mediapipe/framework/port/integral_types.h"
-#include "mediapipe/framework/port/logging.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "mediapipe/framework/timestamp.h"
 #include "mediapipe/util/time_series_util.h"
 
 namespace mediapipe {

@@ -88,11 +85,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
   absl::Status Close(CalculatorContext* cc) override;
 
  private:
-  // Adds input data to the internal buffer.
-  void EnqueueInput(CalculatorContext* cc);
-  // Constructs and emits framed output packets.
-  void FrameOutput(CalculatorContext* cc);
-
   Timestamp CurrentOutputTimestamp() {
     if (use_local_timestamp_) {
       return current_timestamp_;

@@ -106,14 +98,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
                  Timestamp::kTimestampUnitsPerSecond);
   }
 
-  // Returns the timestamp of a sample on a base, which is usually the time
-  // stamp of a packet.
-  Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
-                                   int64_t number_of_samples) {
-    return timestamp_base + round(number_of_samples / sample_rate_ *
-                                  Timestamp::kTimestampUnitsPerSecond);
-  }
-
   // The number of input samples to advance after the current output frame is
   // emitted.
   int next_frame_step_samples() const {
@@ -142,61 +126,174 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
   Timestamp initial_input_timestamp_;
   // The current timestamp is updated along with the incoming packets.
   Timestamp current_timestamp_;
-  int num_channels_;
 
-  // Each entry in this deque consists of a single sample, i.e. a
-  // single column vector, and its timestamp.
-  std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
+  // Samples are buffered in a vector of sample blocks.
+  class SampleBlockBuffer {
+   public:
+    // Initializes the buffer.
+    void Init(double sample_rate, int num_channels) {
+      ts_units_per_sample_ = Timestamp::kTimestampUnitsPerSecond / sample_rate;
+      num_channels_ = num_channels;
+      num_samples_ = 0;
+      first_block_offset_ = 0;
+    }
+
+    // Number of channels, equal to the number of rows in each Matrix.
+    int num_channels() const { return num_channels_; }
+    // Total number of available samples over all blocks.
+    int num_samples() const { return num_samples_; }
+
+    // Pushes a new block of samples on the back of the buffer with `timestamp`
+    // being the input timestamp of the packet containing the Matrix.
+    void Push(const Matrix& samples, Timestamp timestamp);
+    // Copies `count` samples from the front of the buffer. If there are fewer
+    // samples than this, the result is zero padded to have `count` samples.
+    // The timestamp of the last copied sample is written to *last_timestamp.
+    // This output is used below to update `current_timestamp_`, which is only
+    // used when `use_local_timestamp` is true.
+    Matrix CopySamples(int count, Timestamp* last_timestamp) const;
+    // Drops `count` samples from the front of the buffer. If `count` exceeds
+    // `num_samples()`, the buffer is emptied.  Returns how many samples were
+    // dropped.
+    int DropSamples(int count);
+
+   private:
+    struct Block {
+      // Matrix of num_channels rows by num_samples columns, a block of possibly
+      // multiple samples.
+      Matrix samples;
+      // Timestamp of the first sample in the Block. This comes from the input
+      // packet's timestamp that contains this Matrix.
+      Timestamp timestamp;
+
+      Block() : timestamp(Timestamp::Unstarted()) {}
+      Block(const Matrix& samples, Timestamp timestamp)
+          : samples(samples), timestamp(timestamp) {}
+      int num_samples() const { return samples.cols(); }
+    };
+    std::vector<Block> blocks_;
+    // Number of timestamp units per sample. Used to compute timestamps as
+    // nth sample timestamp = base_timestamp + round(ts_units_per_sample_ * n).
+    double ts_units_per_sample_;
+    // Number of rows in each Matrix.
+    int num_channels_;
+    // The total number of samples over all blocks, equal to
+    // (sum_i blocks_[i].num_samples()) - first_block_offset_.
+    int num_samples_;
+    // The number of samples in the first block that have been discarded. This
+    // way we can cheaply represent "partially discarding" a block.
+    int first_block_offset_;
+  } sample_buffer_;
 
   bool use_window_;
-  Matrix window_;
+  Eigen::RowVectorXf window_;
 
   bool use_local_timestamp_;
 };
 REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
 
-void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
-  const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
-
-  for (int i = 0; i < input_frame.cols(); ++i) {
-    sample_buffer_.emplace_back(std::make_pair(
-        input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
-  }
+void TimeSeriesFramerCalculator::SampleBlockBuffer::Push(const Matrix& samples,
+                                                         Timestamp timestamp) {
+  num_samples_ += samples.cols();
+  blocks_.emplace_back(samples, timestamp);
 }
 
-void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
-  while (sample_buffer_.size() >=
+Matrix TimeSeriesFramerCalculator::SampleBlockBuffer::CopySamples(
+    int count, Timestamp* last_timestamp) const {
+  Matrix copied(num_channels_, count);
+
+  if (!blocks_.empty()) {
+    int num_copied = 0;
+    // First block has an offset for samples that have been discarded.
+    int offset = first_block_offset_;
+    int n;
+    Timestamp last_block_ts;
+    int last_sample_index;
+
+    for (auto it = blocks_.begin(); it != blocks_.end() && count > 0; ++it) {
+      n = std::min(it->num_samples() - offset, count);
+      // Copy `n` samples from the next block.
+      copied.middleCols(num_copied, n) = it->samples.middleCols(offset, n);
+      count -= n;
+      num_copied += n;
+      last_block_ts = it->timestamp;
+      last_sample_index = offset + n - 1;
+      offset = 0;  // No samples have been discarded in subsequent blocks.
+    }
+
+    // Compute the timestamp of the last copied sample.
+    *last_timestamp =
+        last_block_ts + std::round(ts_units_per_sample_ * last_sample_index);
+  }
+
+  if (count > 0) {
+    copied.rightCols(count).setZero();  // Zero pad if needed.
+  }
+
+  return copied;
+}
+
+int TimeSeriesFramerCalculator::SampleBlockBuffer::DropSamples(int count) {
+  if (blocks_.empty()) {
+    return 0;
+  }
+
+  auto block_it = blocks_.begin();
+  if (first_block_offset_ + count < block_it->num_samples()) {
+    // `count` is less than the remaining samples in the first block.
+    first_block_offset_ += count;
+    num_samples_ -= count;
+    return count;
+  }
+
+  int num_samples_dropped = block_it->num_samples() - first_block_offset_;
+  count -= num_samples_dropped;
+  first_block_offset_ = 0;
+
+  for (++block_it; block_it != blocks_.end(); ++block_it) {
+    if (block_it->num_samples() > count) {
+      break;
+    }
+    num_samples_dropped += block_it->num_samples();
+    count -= block_it->num_samples();
+  }
+
+  blocks_.erase(blocks_.begin(), block_it);  // Drop whole blocks.
+  if (!blocks_.empty()) {
+    first_block_offset_ = count;  // Drop part of the next block.
+    num_samples_dropped += count;
+  }
+
+  num_samples_ -= num_samples_dropped;
+  return num_samples_dropped;
+}
+
+absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
+  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
+    initial_input_timestamp_ = cc->InputTimestamp();
+    current_timestamp_ = initial_input_timestamp_;
+  }
+
+  // Add input data to the internal buffer.
+  sample_buffer_.Push(cc->Inputs().Index(0).Get<Matrix>(),
+                      cc->InputTimestamp());
+
+  // Construct and emit framed output packets.
+  while (sample_buffer_.num_samples() >=
          frame_duration_samples_ + samples_still_to_drop_) {
-    while (samples_still_to_drop_ > 0) {
-      sample_buffer_.pop_front();
-      --samples_still_to_drop_;
-    }
+    sample_buffer_.DropSamples(samples_still_to_drop_);
+    Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
+                                                     &current_timestamp_);
     const int frame_step_samples = next_frame_step_samples();
-    std::unique_ptr<Matrix> output_frame(
-        new Matrix(num_channels_, frame_duration_samples_));
-    for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
-         ++i) {
-      output_frame->col(i) = sample_buffer_.front().first;
-      current_timestamp_ = sample_buffer_.front().second;
-      sample_buffer_.pop_front();
-    }
-    const int frame_overlap_samples =
-        frame_duration_samples_ - frame_step_samples;
-    if (frame_overlap_samples > 0) {
-      for (int i = 0; i < frame_overlap_samples; ++i) {
-        output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
-        current_timestamp_ = sample_buffer_[i].second;
-      }
-    } else {
-      samples_still_to_drop_ = -frame_overlap_samples;
-    }
+    samples_still_to_drop_ = frame_step_samples;
 
     if (use_window_) {
-      *output_frame = (output_frame->array() * window_.array()).matrix();
+      // Apply the window to each row of output_frame.
+      output_frame.array().rowwise() *= window_.array();
     }
 
-    cc->Outputs().Index(0).Add(output_frame.release(),
-                               CurrentOutputTimestamp());
+    cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
                                          .At(CurrentOutputTimestamp()));
     ++cumulative_output_frames_;
     cumulative_completed_samples_ += frame_step_samples;
   }
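A quick worked example (not part of the commit) of the ts_units_per_sample_ bookkeeping introduced above, assuming MediaPipe's usual 1,000,000 timestamp units per second: at a 32 kHz sample rate,

    ts_units_per_sample_ = 1,000,000 / 32,000 = 31.25 units per sample
    timestamp of sample n = 100 in a block stamped t0
        = t0 + round(31.25 * 100) = t0 + 3125

so per-sample timestamps stay exact to the nearest unit without storing one (Matrix, Timestamp) pair per sample, which is the point of replacing the old per-sample deque with whole-block storage.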
@@ -206,35 +303,18 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
     // fact to enable packet queueing optimizations.
     cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp());
   }
-}
-
-absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
-  if (initial_input_timestamp_ == Timestamp::Unstarted()) {
-    initial_input_timestamp_ = cc->InputTimestamp();
-    current_timestamp_ = initial_input_timestamp_;
-  }
-
-  EnqueueInput(cc);
-  FrameOutput(cc);
 
   return absl::OkStatus();
 }
 
 absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) {
-  while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) {
-    sample_buffer_.pop_front();
-    --samples_still_to_drop_;
-  }
-  if (!sample_buffer_.empty() && pad_final_packet_) {
-    std::unique_ptr<Matrix> output_frame(new Matrix);
-    output_frame->setZero(num_channels_, frame_duration_samples_);
-    for (int i = 0; i < sample_buffer_.size(); ++i) {
-      output_frame->col(i) = sample_buffer_[i].first;
-      current_timestamp_ = sample_buffer_[i].second;
-    }
+  sample_buffer_.DropSamples(samples_still_to_drop_);
 
-    cc->Outputs().Index(0).Add(output_frame.release(),
-                               CurrentOutputTimestamp());
+  if (sample_buffer_.num_samples() > 0 && pad_final_packet_) {
+    Matrix output_frame = sample_buffer_.CopySamples(frame_duration_samples_,
+                                                     &current_timestamp_);
+    cc->Outputs().Index(0).AddPacket(MakePacket<Matrix>(std::move(output_frame))
+                                         .At(CurrentOutputTimestamp()));
   }
 
   return absl::OkStatus();
@@ -258,7 +338,7 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
       cc->Inputs().Index(0).Header(), &input_header));
 
   sample_rate_ = input_header.sample_rate();
-  num_channels_ = input_header.num_channels();
+  sample_buffer_.Init(sample_rate_, input_header.num_channels());
   frame_duration_samples_ = time_series_util::SecondsToSamples(
       framer_options.frame_duration_seconds(), sample_rate_);
   RET_CHECK_GT(frame_duration_samples_, 0)

@@ -312,9 +392,8 @@ absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) {
   }
 
   if (use_window_) {
-    window_ = Matrix::Ones(num_channels_, 1) *
-              Eigen::Map<Eigen::MatrixXd>(window_vector.data(), 1,
-                                          frame_duration_samples_)
+    window_ = Eigen::Map<Eigen::RowVectorXd>(window_vector.data(),
+                                             frame_duration_samples_)
                   .cast<float>();
   }
   use_local_timestamp_ = framer_options.use_local_timestamp();
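The window is now kept as a single Eigen::RowVectorXf and broadcast across the channel rows of each output frame, instead of being materialized as a num_channels by frame_duration matrix. A minimal standalone sketch of that broadcast (values and names are illustrative, not from the commit):

    #include <Eigen/Core>

    int main() {
      Eigen::MatrixXf frame(2, 4);   // 2 channels x 4 samples
      frame.setOnes();
      Eigen::RowVectorXf window(4);  // one weight per sample
      window << 0.0f, 0.5f, 0.5f, 0.0f;
      // Broadcasts the row vector over every row, i.e. windows each channel.
      frame.array().rowwise() *= window.array();
      return 0;
    }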
@@ -0,0 +1,92 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Benchmark for TimeSeriesFramerCalculator.
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+#include "mediapipe/calculators/audio/time_series_framer_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/formats/matrix.h"
+#include "mediapipe/framework/formats/time_series_header.pb.h"
+#include "mediapipe/framework/packet.h"
+
+using ::mediapipe::Matrix;
+
+void BM_TimeSeriesFramerCalculator(benchmark::State& state) {
+  constexpr float kSampleRate = 32000.0;
+  constexpr int kNumChannels = 2;
+  constexpr int kFrameDurationSeconds = 5.0;
+  std::mt19937 rng(0 /*seed*/);
+  // Input around a half second's worth of samples at a time.
+  std::uniform_int_distribution<int> input_size_dist(15000, 17000);
+  // Generate a pool of random blocks of samples up front.
+  std::vector<Matrix> sample_pool;
+  sample_pool.reserve(20);
+  for (int i = 0; i < 20; ++i) {
+    sample_pool.push_back(Matrix::Random(kNumChannels, input_size_dist(rng)));
+  }
+  std::uniform_int_distribution<int> pool_index_dist(0, sample_pool.size() - 1);
+
+  mediapipe::CalculatorGraphConfig config;
+  config.add_input_stream("input");
+  config.add_output_stream("output");
+  auto* node = config.add_node();
+  node->set_calculator("TimeSeriesFramerCalculator");
+  node->add_input_stream("input");
+  node->add_output_stream("output");
+  mediapipe::TimeSeriesFramerCalculatorOptions* options =
+      node->mutable_options()->MutableExtension(
+          mediapipe::TimeSeriesFramerCalculatorOptions::ext);
+  options->set_frame_duration_seconds(kFrameDurationSeconds);
+
+  for (auto _ : state) {
+    state.PauseTiming();  // Pause benchmark timing.
+
+    // Prepare input packets of random blocks of samples.
+    std::vector<mediapipe::Packet> input_packets;
+    input_packets.reserve(32);
+    float t = 0;
+    for (int i = 0; i < 32; ++i) {
+      auto samples =
+          std::make_unique<Matrix>(sample_pool[pool_index_dist(rng)]);
+      const int num_samples = samples->cols();
+      input_packets.push_back(mediapipe::Adopt(samples.release())
+                                  .At(mediapipe::Timestamp::FromSeconds(t)));
+      t += num_samples / kSampleRate;
+    }
+    // Initialize graph.
+    mediapipe::CalculatorGraph graph;
+    CHECK_OK(graph.Initialize(config));
+    // Prepare input header.
+    auto header = std::make_unique<mediapipe::TimeSeriesHeader>();
+    header->set_sample_rate(kSampleRate);
+    header->set_num_channels(kNumChannels);
+
+    state.ResumeTiming();  // Resume benchmark timing.
+
+    CHECK_OK(graph.StartRun({}, {{"input", Adopt(header.release())}}));
+    for (auto& packet : input_packets) {
+      CHECK_OK(graph.AddPacketToInputStream("input", packet));
+    }
+    CHECK(!graph.HasError());
+    CHECK_OK(graph.CloseAllInputStreams());
+    CHECK_OK(graph.WaitUntilIdle());
+  }
+}
+BENCHMARK(BM_TimeSeriesFramerCalculator);
+
+BENCHMARK_MAIN();
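Worth noting in the benchmark above: packet construction and graph initialization happen between PauseTiming() and ResumeTiming(), so the timed region covers only StartRun through WaitUntilIdle, i.e. the framer itself plus graph overhead rather than input preparation.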
@@ -117,6 +117,7 @@ mediapipe_proto_library(
         "//mediapipe/framework:calculator_proto",
         "//mediapipe/framework/formats:classification_proto",
         "//mediapipe/framework/formats:landmark_proto",
+        "//mediapipe/framework/formats:matrix_data_proto",
         "//mediapipe/framework/formats:time_series_header_proto",
     ],
 )

@@ -289,6 +290,7 @@ cc_library(
         "//mediapipe/framework/api2:node",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/port:integral_types",

@@ -1167,6 +1169,7 @@ cc_library(
         "//mediapipe/framework:collection_item_id",
         "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/formats:matrix_data_cc_proto",
         "//mediapipe/framework/formats:time_series_header_cc_proto",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:ret_check",
@@ -17,6 +17,7 @@
 #include <vector>
 
 #include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port/integral_types.h"

@@ -104,4 +105,7 @@ typedef ConcatenateVectorCalculator<mediapipe::RenderData>
     ConcatenateRenderDataVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator);
 
+typedef ConcatenateVectorCalculator<mediapipe::Image>
+    ConcatenateImageVectorCalculator;
+MEDIAPIPE_REGISTER_NODE(ConcatenateImageVectorCalculator);
 }  // namespace mediapipe
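A minimal sketch of wiring the newly registered ConcatenateImageVectorCalculator into a graph, following the same config-building pattern as the benchmark earlier in this diff (the stream names here are made up):

    mediapipe::CalculatorGraphConfig config;
    config.add_input_stream("images_a");
    config.add_input_stream("images_b");
    config.add_output_stream("images_out");
    auto* node = config.add_node();
    node->set_calculator("ConcatenateImageVectorCalculator");
    node->add_input_stream("images_a");  // streams carrying std::vector<mediapipe::Image>
    node->add_input_stream("images_b");
    node->add_output_stream("images_out");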
@@ -19,6 +19,7 @@
 #include "mediapipe/framework/collection_item_id.h"
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/matrix_data.pb.h"
 #include "mediapipe/framework/formats/time_series_header.pb.h"
 #include "mediapipe/framework/port/canonical_errors.h"
 #include "mediapipe/framework/port/integral_types.h"

@@ -85,8 +86,12 @@ class ConstantSidePacketCalculator : public CalculatorBase {
         packet.Set<LandmarkList>();
       } else if (packet_options.has_double_value()) {
         packet.Set<double>();
+      } else if (packet_options.has_matrix_data_value()) {
+        packet.Set<MatrixData>();
       } else if (packet_options.has_time_series_header_value()) {
         packet.Set<TimeSeriesHeader>();
+      } else if (packet_options.has_int64_value()) {
+        packet.Set<int64_t>();
       } else {
         return absl::InvalidArgumentError(
             "None of supported values were specified in options.");

@@ -121,9 +126,13 @@ class ConstantSidePacketCalculator : public CalculatorBase {
             MakePacket<LandmarkList>(packet_options.landmark_list_value()));
       } else if (packet_options.has_double_value()) {
         packet.Set(MakePacket<double>(packet_options.double_value()));
+      } else if (packet_options.has_matrix_data_value()) {
+        packet.Set(MakePacket<MatrixData>(packet_options.matrix_data_value()));
       } else if (packet_options.has_time_series_header_value()) {
         packet.Set(MakePacket<TimeSeriesHeader>(
             packet_options.time_series_header_value()));
+      } else if (packet_options.has_int64_value()) {
+        packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
       } else {
         return absl::InvalidArgumentError(
             "None of supported values were specified in options.");
@@ -19,6 +19,7 @@ package mediapipe;
 import "mediapipe/framework/calculator.proto";
 import "mediapipe/framework/formats/classification.proto";
 import "mediapipe/framework/formats/landmark.proto";
+import "mediapipe/framework/formats/matrix_data.proto";
 import "mediapipe/framework/formats/time_series_header.proto";
 
 message ConstantSidePacketCalculatorOptions {

@@ -29,14 +30,16 @@ message ConstantSidePacketCalculatorOptions {
   message ConstantSidePacket {
     oneof value {
       int32 int_value = 1;
+      uint64 uint64_value = 5;
+      int64 int64_value = 11;
       float float_value = 2;
+      double double_value = 9;
       bool bool_value = 3;
       string string_value = 4;
-      uint64 uint64_value = 5;
       ClassificationList classification_list_value = 6;
       LandmarkList landmark_list_value = 7;
-      double double_value = 9;
       TimeSeriesHeader time_series_header_value = 10;
+      MatrixData matrix_data_value = 12;
     }
   }
 
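A hedged sketch of exercising the new int64 option from C++, mirroring the test's text-proto form "{ int64_value: 63 }"; it assumes the options message's repeated field keeps the conventional name packet, which this diff does not show:

    mediapipe::ConstantSidePacketCalculatorOptions options;
    // Hypothetical accessor: `packet` is assumed to be the repeated field name.
    options.add_packet()->set_int64_value(63);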
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <cstdint>
 #include <string>
 
 #include "absl/strings/string_view.h"

@@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
   DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
   DoTestSingleSidePacket("{ bool_value: true }", true);
   DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
+  DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
 }
 
 TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
@@ -228,7 +228,6 @@ cc_library(
         "//mediapipe/tasks/metadata:metadata_schema_cc",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,

@@ -280,7 +279,6 @@ cc_library(
         "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
         "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
     ],
     alwayslink = 1,
 )
@@ -22,7 +22,6 @@
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/status/status.h"
-#include "absl/status/statusor.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/string_view.h"
 #include "absl/strings/substitute.h"

@@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
   input_tensors.reserve(kNumInputTensorsForBert);
   for (int i = 0; i < kNumInputTensorsForBert; ++i) {
     input_tensors.push_back(
-        {Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
+        {Tensor::ElementType::kInt32,
+         Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
   }
   std::memcpy(input_tensors[input_ids_tensor_index_]
                   .GetCpuWriteView()
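Two things change in the Bert preprocessor's tensor shape: it gains an explicit leading batch dimension ({1, tensor_size} rather than {tensor_size}), and the second Tensor::Shape argument flags the shape as dynamic when the model accepts variable-length input. The inference runner change below consumes that flag.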
@@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
     CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
   // Read CPU input into tensors.
   RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
+
+  // If the input tensors have dynamic shape, then the tensors need to be
+  // resized and reallocated before we can copy the tensor values.
+  bool resized_tensor_shapes = false;
+  for (int i = 0; i < input_tensors.size(); ++i) {
+    if (input_tensors[i].shape().is_dynamic) {
+      interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
+      resized_tensor_shapes = true;
+    }
+  }
+  // Reallocation is needed for memory sanity.
+  if (resized_tensor_shapes) interpreter_->AllocateTensors();
+
   for (int i = 0; i < input_tensors.size(); ++i) {
     const TfLiteType input_tensor_type =
         interpreter_->tensor(interpreter_->inputs()[i])->type;
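A small sketch of the producer side of this contract, modeled on the preprocessor change above (the names and length are illustrative):

    // A tensor whose sequence length varies per packet: the second Shape
    // argument marks it dynamic, so the runner above resizes the interpreter
    // input and calls AllocateTensors() before copying the values in.
    const int seq_len = 128;  // illustrative
    mediapipe::Tensor tensor(
        mediapipe::Tensor::ElementType::kInt32,
        mediapipe::Tensor::Shape({1, seq_len}, /*is_dynamic=*/true));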
@@ -20,7 +20,6 @@
 #include <vector>
 
 #include "absl/status/status.h"
-#include "absl/status/statusor.h"
 #include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/api2/port.h"

@@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
   // not found in the tokenizer vocab.
   std::vector<Tensor> result;
   result.push_back(
-      {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
+      {Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
   std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
               input_tokens.data(), input_tokens.size() * sizeof(int32_t));
   kTensorsOut(cc).Send(std::move(result));
@@ -1077,6 +1077,7 @@ cc_test(
     linkstatic = 1,
     deps = [
         ":tensor_to_image_frame_calculator",
+        ":tensor_to_image_frame_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/formats:image_frame",
@@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
 
  private:
   float scale_factor_;
+  bool scale_per_frame_min_max_;
 };
 
 REGISTER_CALCULATOR(TensorToImageFrameCalculator);

@@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
 absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
   scale_factor_ =
       cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
+  scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
+                                 .scale_per_frame_min_max();
   cc->SetOffset(TimestampDiff(0));
   return absl::OkStatus();
 }

@@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
   auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
   const int32_t total_size = height * width * depth;
+
+  if (scale_per_frame_min_max_) {
+    RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
+        << "Setting scale_per_frame_min_max requires FLOAT input tensors.";
+  }
   ::std::unique_ptr<const ImageFrame> output;
   if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
     // Allocate buffer with alignments.
     std::unique_ptr<uint8_t[]> buffer(
         new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
     auto data = input_tensor.flat<float>().data();
+    float min = 1e23;
+    float max = -1e23;
+    if (scale_per_frame_min_max_) {
+      for (int i = 0; i < total_size; ++i) {
+        float d = scale_factor_ * data[i];
+        if (d < min) {
+          min = d;
+        }
+        if (d > max) {
+          max = d;
+        }
+      }
+    }
     for (int i = 0; i < total_size; ++i) {
-      float d = scale_factor_ * data[i];
-      if (d < 0) d = 0;
-      if (d > 255) d = 255;
+      float d = data[i];
+      if (scale_per_frame_min_max_) {
+        d = 255 * (d - min) / (max - min + 1e-9);
+      } else {
+        d = scale_factor_ * d;
+        if (d < 0) d = 0;
+        if (d > 255) d = 255;
+      }
       buffer[i] = d;
     }
     output = ::absl::make_unique<ImageFrame>(

@@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
   // Multiples floating point tensor outputs by this value before converting to
   // uint8. This is useful for converting from range [0, 1] to [0, 255]
   optional float scale_factor = 1 [default = 1.0];
+
+  // If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
+  // per frame. This overrides any explicit scale_factor.
+  optional bool scale_per_frame_min_max = 2 [default = false];
 }

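For reference, a minimal sketch of how a graph could enable the new option
(hedged example, not part of this commit; the graph and stream names are
hypothetical):

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/port/parse_text_proto.h"

    // One-node graph: converts a float tensor to an ImageFrame, rescaling
    // each frame's observed [min, max] to [0, 255].
    mediapipe::CalculatorGraphConfig MakeConfig() {
      return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
          R"pb(
            input_stream: "input_tensor"
            node {
              calculator: "TensorToImageFrameCalculator"
              input_stream: "TENSOR:input_tensor"
              output_stream: "IMAGE:output_image"
              options {
                [mediapipe.TensorToImageFrameCalculatorOptions.ext] {
                  scale_per_frame_min_max: true
                }
              }
            }
          )pb");
    }

Note the 1e-9 term in the calculator's normalization, d = 255 * (d - min) /
(max - min + 1e-9): it keeps a constant-valued frame (max == min) from
dividing by zero, at the cost of mapping such a frame to all zeros.
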
@@ -11,7 +11,9 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include <type_traits>
+
+#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/formats/image_frame.h"

@@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
 template <class TypeParam>
 class TensorToImageFrameCalculatorTest : public ::testing::Test {
  protected:
-  void SetUpRunner() {
+  void SetUpRunner(bool scale_per_frame_min_max = false) {
     CalculatorGraphConfig::Node config;
     config.set_calculator("TensorToImageFrameCalculator");
     config.add_input_stream("TENSOR:input_tensor");
     config.add_output_stream("IMAGE:output_image");
+    config.mutable_options()
+        ->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
+        ->set_scale_per_frame_min_max(scale_per_frame_min_max);
     runner_ = absl::make_unique<CalculatorRunner>(config);
   }
 
@@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
   }
 }
 
+TYPED_TEST(TensorToImageFrameCalculatorTest,
+           Converts3DTensorToImageFrame2DGrayWithScaling) {
+  this->SetUpRunner(true);
+  auto& runner = this->runner_;
+  constexpr int kWidth = 16;
+  constexpr int kHeight = 8;
+  const tf::TensorShape tensor_shape{kHeight, kWidth};
+  auto tensor = absl::make_unique<tf::Tensor>(
+      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
+  auto tensor_vec = tensor->template flat<TypeParam>().data();
+
+  // Writing sequence of integers as floats which we want normalized.
+  tensor_vec[0] = 255;
+  for (int i = 1; i < kWidth * kHeight; ++i) {
+    tensor_vec[i] = 200;
+  }
+
+  const int64_t time = 1234;
+  runner->MutableInputs()->Tag(kTensor).packets.push_back(
+      Adopt(tensor.release()).At(Timestamp(time)));
+
+  if (!std::is_same<TypeParam, float>::value) {
+    EXPECT_FALSE(runner->Run().ok());
+    return;  // Short circuit because does not apply to other types.
+  } else {
+    EXPECT_TRUE(runner->Run().ok());
+    const std::vector<Packet>& output_packets =
+        runner->Outputs().Tag(kImage).packets;
+    EXPECT_EQ(1, output_packets.size());
+    EXPECT_EQ(time, output_packets[0].Timestamp().Value());
+    const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
+    EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
+    EXPECT_EQ(kWidth, output_image.Width());
+    EXPECT_EQ(kHeight, output_image.Height());
+
+    EXPECT_EQ(255, output_image.PixelData()[0]);
+    for (int i = 1; i < kWidth * kHeight; ++i) {
+      const uint8_t pixel_value = output_image.PixelData()[i];
+      ASSERT_EQ(0, pixel_value);
+    }
+  }
+}
+
 }  // namespace mediapipe

@@ -1355,6 +1355,23 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "calculator_graph_summary_packet_test",
+    srcs = ["calculator_graph_summary_packet_test.cc"],
+    deps = [
+        ":calculator_framework",
+        ":packet",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
+        "//mediapipe/framework/tool:sink",
+        "@com_google_absl//absl/status",
+    ],
+)
+
 cc_test(
     name = "calculator_runner_test",
     size = "medium",

@@ -32,7 +32,7 @@ template <class T>
 struct dependent_false : std::false_type {};
 
 template <typename T>
-T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, int index) {
+T& GetWithAutoGrow(std::vector<std::unique_ptr<T>>* vecp, size_t index) {
   auto& vec = *vecp;
   if (vec.size() <= index) {
     vec.resize(index + 1);

@@ -109,9 +109,20 @@ class CalculatorContext {
   // use OutputStream::SetOffset() directly.
   void SetOffset(TimestampDiff offset);
 
-  // Returns the status of the graph run.
+  // DEPRECATED: This was intended to get graph run status during
+  // `CalculatorBase::Close` call. However, `Close` can run simultaneously with
+  // other calculators `CalculatorBase::Process`, hence the actual graph
+  // status may change any time and returned graph status here does not
+  // necessarily reflect the actual graph status.
   //
-  // NOTE: This method should only be called during CalculatorBase::Close().
+  // As an alternative, instead of checking graph status in `Close` and doing
+  // work for "done" state, you can enable timestamp bound processing for your
+  // calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger
+  // `Process` on timestamp bound updates and handle "done" state there.
+  // Check examples in:
+  // mediapipe/framework/calculator_graph_summary_packet_test.cc.
+  //
+  ABSL_DEPRECATED("Does not reflect the actual graph status.")
   absl::Status GraphStatus() const { return graph_status_; }
 
   ProfilingContext* GetProfilingContext() const {

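As a sketch of the alternative this deprecation note points to (hedged,
api2-style; simplified from the test file referenced above):

    // Enable "done" handling inside Process() instead of querying
    // GraphStatus() from Close().
    static absl::Status UpdateContract(mediapipe::CalculatorContract* cc) {
      // Trigger Process() on timestamp bound updates, including the "done"
      // bound update once all input packet sources are closed.
      cc->SetProcessTimestampBounds(true);
      return absl::OkStatus();
    }
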
@@ -839,6 +839,13 @@ absl::Status CalculatorGraph::PrepareForRun(
 }
 
 absl::Status CalculatorGraph::WaitUntilIdle() {
+  if (has_sources_) {
+    LOG_FIRST_N(WARNING, 1)
+        << "WaitUntilIdle called on a graph with source nodes, which "
+           "is not fully supported at the moment. Source nodes: "
+        << ListSourceNodes();
+  }
+
   MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle());
   VLOG(2) << "Scheduler idle.";
   absl::Status status = absl::OkStatus();

@@ -1368,6 +1375,16 @@ const OutputStreamManager* CalculatorGraph::FindOutputStreamManager(
               .get()[validated_graph_->OutputStreamIndex(name)];
 }
 
+std::string CalculatorGraph::ListSourceNodes() const {
+  std::vector<std::string> sources;
+  for (auto& node : nodes_) {
+    if (node->IsSource()) {
+      sources.push_back(node->DebugName());
+    }
+  }
+  return absl::StrJoin(sources, ", ");
+}
+
 namespace {
 void PrintTimingToInfo(const std::string& label, int64_t timer_value) {
   const int64_t total_seconds = timer_value / 1000000ll;

@@ -229,8 +229,11 @@ class CalculatorGraph {
   // Wait until the running graph is in the idle mode, which is when nothing can
   // be scheduled and nothing is running in the worker threads. This function
   // can be called only after StartRun().
+  //
   // NOTE: The graph must not have any source nodes because source nodes prevent
   // the running graph from becoming idle until the source nodes are done.
+  // Currently, `WaitUntilIdle` cannot be used reliably on graphs with any
+  // source nodes.
   absl::Status WaitUntilIdle();
 
   // Wait until a packet is emitted on one of the observed output streams.

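A minimal usage sketch consistent with this contract (hedged; assumes a
source-free graph whose input stream is named "input"):

    // Feed one packet into a running graph, then block until nothing more
    // can be scheduled before inspecting observed outputs.
    absl::Status FeedAndSettle(mediapipe::CalculatorGraph& graph) {
      MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
          "input", mediapipe::MakePacket<int>(42).At(mediapipe::Timestamp(0))));
      // Reliable only because the graph has no source nodes; with sources,
      // the LOG_FIRST_N warning above fires and idleness may never be reached.
      return graph.WaitUntilIdle();
    }
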
@@ -594,6 +597,9 @@ class CalculatorGraph {
   // status before taking any action.
   void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full);
 
+  // Returns a comma-separated list of source nodes.
+  std::string ListSourceNodes() const;
+
 #if !MEDIAPIPE_DISABLE_GPU
   // Owns the legacy GpuSharedData if we need to create one for backwards
   // compatibility.

mediapipe/framework/calculator_graph_summary_packet_test.cc (new file, 430 lines)
@@ -0,0 +1,430 @@
+#include "absl/status/status.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/packet.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe {
+
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Node;
+using ::mediapipe::api2::Output;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::Value;
+
+namespace {
+
+MATCHER_P2(IntPacket, value, timestamp, "") {
+  *result_listener << "where object is (value: " << arg.template Get<int>()
+                   << ", timestamp: " << arg.Timestamp() << ")";
+  return Value(arg.template Get<int>(), Eq(value)) &&
+         Value(arg.Timestamp(), Eq(timestamp));
+}
+
+// Calculates and produces sum of all passed inputs when no more packets can be
+// expected on the input stream.
+class SummaryPacketCalculator : public Node {
+ public:
+  static constexpr Input<int> kIn{"IN"};
+  static constexpr Output<int> kOut{"SUMMARY"};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  static absl::Status UpdateContract(CalculatorContract* cc) {
+    // Makes sure there are no automatic timestamp bound updates when Process
+    // is called.
+    cc->SetTimestampOffset(TimestampDiff::Unset());
+    // Currently, only ImmediateInputStreamHandler supports "done" timestamp
+    // bound update. (ImmediateInputStreamhandler handles multiple input
+    // streams differently, so, in that case, calculator adjustments may be
+    // required.)
+    // TODO: update all input stream handlers to support "done"
+    // timestamp bound update.
+    cc->SetInputStreamHandler("ImmediateInputStreamHandler");
+    // Enables processing timestamp bound updates. For this use case we are
+    // specifically interested in "done" timestamp bound update. (E.g. when
+    // all input packet sources are closed.)
+    cc->SetProcessTimestampBounds(true);
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(CalculatorContext* cc) final {
+    if (!kIn(cc).IsEmpty()) {
+      value_ += kIn(cc).Get();
+      value_set_ = true;
+    }
+
+    if (kOut(cc).IsClosed()) {
+      // This can happen:
+      // 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
+      //    source calculator finished generating packets sent to kIn) and
+      //    HasNextAllowedInStream() == true (which is an often case).
+      // 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
+      //    invoke Process() with Timestamp::Max to indicate "Done" timestamp
+      //    bound update.
+      return absl::OkStatus();
+    }
+
+    // TODO: input stream holding a packet with timestamp that has
+    // no next timestamp allowed in stream should always result in
+    // InputStream::IsDone() == true.
+    if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
+      // `Process` may or may not be invoked for "done" timestamp bound when
+      // upstream calculator fails in `Close`. Hence, extra care is needed to
+      // identify whether the calculator needs to send output.
+      // TODO: remove when "done" timestamp bound flakiness fixed.
+      if (value_set_) {
+        // kOut(cc).Send(value_) can be used here as well, however in the case
+        // of source calculator sending inputs into kIn the resulting timestamp
+        // is not well defined (e.g. it can be the last packet timestamp or
+        // Timestamp::Max())
+        // TODO: last packet from source should always result in
+        // InputStream::IsDone() == true.
+        kOut(cc).Send(value_, Timestamp::Max());
+      }
+      kOut(cc).Close();
+    }
+    return absl::OkStatus();
+  }
+
+ private:
+  int value_ = 0;
+  bool value_set_ = false;
+};
+MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     ProducesSummaryPacketOnClosingAllPacketSources) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: 'input'
+    node {
+      calculator: "SummaryPacketCalculator"
+      input_stream: 'IN:input'
+      output_stream: 'SUMMARY:output'
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  auto send_packet = [&graph](int value, Timestamp timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "input", MakePacket<int>(value).At(timestamp)));
+  };
+
+  send_packet(10, Timestamp(10));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  send_packet(20, Timestamp(11));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
+}
+
+TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: 'input'
+    node {
+      calculator: "SummaryPacketCalculator"
+      input_stream: 'IN:input'
+      output_stream: 'SUMMARY:output'
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  auto send_packet = [&graph](int value, Timestamp timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "input", MakePacket<int>(value).At(timestamp)));
+  };
+
+  send_packet(10, Timestamp(10));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  send_packet(20, Timestamp::Max());
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
+
+  output_packets.clear();
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     ProducesSummaryPacketOnPreStreamTimestamp) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: 'input'
+    node {
+      calculator: "SummaryPacketCalculator"
+      input_stream: 'IN:input'
+      output_stream: 'SUMMARY:output'
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  auto send_packet = [&graph](int value, Timestamp timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "input", MakePacket<int>(value).At(timestamp)));
+  };
+
+  send_packet(10, Timestamp::PreStream());
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
+
+  output_packets.clear();
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     ProducesSummaryPacketOnPostStreamTimestamp) {
+  std::vector<Packet> output_packets;
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: 'input'
+        node {
+          calculator: "SummaryPacketCalculator"
+          input_stream: 'IN:input'
+          output_stream: 'SUMMARY:output'
+        }
+      )pb");
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  auto send_packet = [&graph](int value, Timestamp timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "input", MakePacket<int>(value).At(timestamp)));
+  };
+
+  send_packet(10, Timestamp::PostStream());
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
+
+  output_packets.clear();
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+class IntGeneratorCalculator : public Node {
+ public:
+  static constexpr Output<int> kOut{"INT"};
+
+  MEDIAPIPE_NODE_CONTRACT(kOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    kOut(cc).Send(20, Timestamp(0));
+    kOut(cc).Send(10, Timestamp(1000));
+    return tool::StatusStop();
+  }
+};
+MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     ProducesSummaryPacketOnSourceCalculatorCompletion) {
+  std::vector<Packet> output_packets;
+  CalculatorGraphConfig graph_config =
+      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        node {
+          calculator: "IntGeneratorCalculator"
+          output_stream: "INT:int_value"
+        }
+        node {
+          calculator: "SummaryPacketCalculator"
+          input_stream: "IN:int_value"
+          output_stream: "SUMMARY:output"
+        }
+      )pb");
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_EXPECT_OK(graph.WaitUntilDone());
+  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
+}
+
+class EmitOnCloseCalculator : public Node {
+ public:
+  static constexpr Input<int> kIn{"IN"};
+  static constexpr Output<int> kOut{"INT"};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
+
+  absl::Status Close(CalculatorContext* cc) final {
+    kOut(cc).Send(20, Timestamp(0));
+    kOut(cc).Send(10, Timestamp(1000));
+    return absl::OkStatus();
+  }
+};
+MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     ProducesSummaryPacketOnAnotherCalculatorClosure) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: "input"
+    node {
+      calculator: "EmitOnCloseCalculator"
+      input_stream: "IN:input"
+      output_stream: "INT:int_value"
+    }
+    node {
+      calculator: "SummaryPacketCalculator"
+      input_stream: "IN:int_value"
+      output_stream: "SUMMARY:output"
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  MP_ASSERT_OK(graph.CloseInputStream("input"));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
+
+  output_packets.clear();
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+class FailureInCloseCalculator : public Node {
+ public:
+  static constexpr Input<int> kIn{"IN"};
+  static constexpr Output<int> kOut{"INT"};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
+
+  absl::Status Close(CalculatorContext* cc) final {
+    return absl::InternalError("error");
+  }
+};
+MEDIAPIPE_REGISTER_NODE(FailureInCloseCalculator);
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInClose) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: "input"
+    node {
+      calculator: "FailureInCloseCalculator"
+      input_stream: "IN:input"
+      output_stream: "INT:int_value"
+    }
+    node {
+      calculator: "SummaryPacketCalculator"
+      input_stream: "IN:int_value"
+      output_stream: "SUMMARY:output"
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  MP_ASSERT_OK(graph.CloseInputStream("input"));
+  EXPECT_THAT(graph.WaitUntilIdle(),
+              StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+class FailureInProcessCalculator : public Node {
+ public:
+  static constexpr Input<int> kIn{"IN"};
+  static constexpr Output<int> kOut{"INT"};
+
+  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    return absl::InternalError("error");
+  }
+};
+MEDIAPIPE_REGISTER_NODE(FailureInProcessCalculator);
+
+TEST(SummaryPacketCalculatorUseCaseTest,
+     DoesNotProduceSummaryPacketWhenUpstreamCalculatorFailsInProcess) {
+  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+    input_stream: "input"
+    node {
+      calculator: "FailureInProcessCalculator"
+      input_stream: "IN:input"
+      output_stream: "INT:int_value"
+    }
+    node {
+      calculator: "SummaryPacketCalculator"
+      input_stream: "IN:int_value"
+      output_stream: "SUMMARY:output"
+    }
+  )pb");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("output", &graph_config, &output_packets);
+
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  EXPECT_THAT(output_packets, IsEmpty());
+
+  auto send_packet = [&graph](int value, Timestamp timestamp) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        "input", MakePacket<int>(value).At(timestamp)));
+  };
+
+  send_packet(10, Timestamp::PostStream());
+  EXPECT_THAT(graph.WaitUntilIdle(),
+              StatusIs(absl::StatusCode::kInternal, HasSubstr("error")));
+  EXPECT_THAT(output_packets, IsEmpty());
+}
+
+}  // namespace
+}  // namespace mediapipe

@@ -117,11 +117,18 @@ class Tensor {
     Shape() = default;
     Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
     Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
+    Shape(std::initializer_list<int> dimensions, bool is_dynamic)
+        : dims(dimensions), is_dynamic(is_dynamic) {}
+    Shape(const std::vector<int>& dimensions, bool is_dynamic)
+        : dims(dimensions), is_dynamic(is_dynamic) {}
     int num_elements() const {
       return std::accumulate(dims.begin(), dims.end(), 1,
                              std::multiplies<int>());
     }
     std::vector<int> dims;
+    // The Tensor has dynamic rather than static shape so the TFLite interpreter
+    // needs to be reallocated. Only relevant for CPU.
+    bool is_dynamic = false;
   };
   // Quantization parameters corresponding to the zero_point and scale value
   // made available by TfLite quantized (uint8/int8) tensors.

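A short sketch of the new constructors (hedged; mirrors the test added below,
with an arbitrary element type and shape, and a hypothetical function name):

    #include "mediapipe/framework/formats/tensor.h"

    // A CPU tensor whose dimensions may change between inferences, so the
    // TFLite interpreter knows it must be reallocated.
    mediapipe::Tensor MakeDynamicInput() {
      return mediapipe::Tensor(
          mediapipe::Tensor::ElementType::kFloat32,
          mediapipe::Tensor::Shape({1, 128}, /*is_dynamic=*/true));
    }
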
@@ -2,6 +2,7 @@
 
 #include <cstring>
 #include <string>
+#include <vector>
 
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"

@@ -34,6 +35,17 @@ TEST(General, TestDataTypes) {
   EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
 }
 
+TEST(General, TestDynamic) {
+  Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true));
+  EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4);
+  EXPECT_TRUE(t1.shape().is_dynamic);
+
+  std::vector<int> t2_dims = {4, 3, 2, 3};
+  Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true));
+  EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3);
+  EXPECT_TRUE(t2.shape().is_dynamic);
+}
+
 TEST(Cpu, TestMemoryAllocation) {
   Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
   auto v1 = t1.GetCpuWriteView();

@@ -273,8 +273,8 @@ absl::Status Scheduler::WaitForObservedOutput() {
 // Idleness requires:
 // 1. either the graph has no source nodes or all source nodes are closed, and
 // 2. no packets are added to graph input streams.
-// For simplicity, we only allow WaitUntilIdle() to be called on a graph with
-// no source nodes. (This is enforced by CalculatorGraph::WaitUntilIdle().)
+// For simplicity, we only fully support WaitUntilIdle() to be called on a graph
+// with no source nodes.
 // The application must ensure no other threads are adding packets to graph
 // input streams while a WaitUntilIdle() call is in progress.
 absl::Status Scheduler::WaitUntilIdle() {

@@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
   return *this + 1;
 }
 
+bool Timestamp::HasNextAllowedInStream() const {
+  if (*this >= Max() || *this == PreStream()) {
+    return false;
+  }
+  return true;
+}
+
 Timestamp Timestamp::PreviousAllowedInStream() const {
   if (*this <= Min() || *this == PostStream()) {
     // Indicates that no previous timestamps may occur.

@@ -186,6 +186,10 @@ class Timestamp {
   // CHECKs that this->IsAllowedInStream().
   Timestamp NextAllowedInStream() const;
 
+  // Returns true if there's a next timestamp in the range [Min .. Max] after
+  // this one.
+  bool HasNextAllowedInStream() const;
+
   // Returns the previous timestamp in the range [Min .. Max], or
   // Unstarted() if no Packets may preceed one with this timestamp.
   Timestamp PreviousAllowedInStream() const;

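A small sketch of the intended pairing with NextAllowedInStream() (hedged
example; the fallback to Max() is one possible policy, matching how the
summary-packet test above emits its final packet):

    #include "mediapipe/framework/timestamp.h"

    // Compute the next in-stream timestamp bound, falling back to Max() when
    // no later packet may follow (PreStream(), Max(), PostStream() and beyond
    // report false).
    mediapipe::Timestamp NextBoundOrMax(mediapipe::Timestamp ts) {
      return ts.HasNextAllowedInStream() ? ts.NextAllowedInStream()
                                         : mediapipe::Timestamp::Max();
    }
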
@@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
             Timestamp::PostStream().NextAllowedInStream());
 }
 
+TEST(TimestampTest, HasNextAllowedInStream) {
+  EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
+  EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
+  EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
+  EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
+  EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
+  EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
+  EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
+
+  EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
+  EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
+  EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
+  EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
+  EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
+}
+
 TEST(TimestampTest, SpecialValueDifferences) {
   {  // Lower range
     const std::vector<Timestamp> timestamps = {

@@ -530,6 +530,7 @@ cc_library(
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
         "@com_google_absl//absl/base:core_headers",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],

@@ -14,7 +14,7 @@
 
 """MediaPipe Task Library Helper Rules for iOS"""
 
-MPP_TASK_MINIMUM_OS_VERSION = "11.0"
+MPP_TASK_MINIMUM_OS_VERSION = "12.0"
 
 # When the static framework is built with bazel, the all header files are moved
 # to the "Headers" directory with no header path prefixes. This auxiliary rule

| 
						 | 
@@ -50,6 +50,7 @@ def mediapipe_proto_library_impl(
         def_cc_proto = True,
         def_py_proto = True,
         def_java_lite_proto = True,
+        def_kt_lite_proto = True,
         def_objc_proto = True,
         def_java_proto = True,
         def_jspb_proto = True,
@@ -72,6 +73,7 @@ def mediapipe_proto_library_impl(
       def_cc_proto: define the cc_proto_library target
       def_py_proto: define the py_proto_library target
       def_java_lite_proto: define the java_lite_proto_library target
+      def_kt_lite_proto: define the kt_lite_proto_library target
       def_objc_proto: define the objc_proto_library target
       def_java_proto: define the java_proto_library target
       def_jspb_proto: define the jspb_proto_library target
@@ -255,6 +257,7 @@ def mediapipe_proto_library(
         def_cc_proto = True,
         def_py_proto = True,
         def_java_lite_proto = True,
+        def_kt_lite_proto = True,
         def_portable_proto = True,  # @unused
         def_objc_proto = True,
         def_java_proto = True,
@@ -281,6 +284,7 @@ def mediapipe_proto_library(
       def_cc_proto: define the cc_proto_library target
       def_py_proto: define the py_proto_library target
       def_java_lite_proto: define the java_lite_proto_library target
+      def_kt_lite_proto: define the kt_lite_proto_library target
       def_portable_proto: ignored since portable protos are gone
       def_objc_proto: define the objc_proto_library target
       def_java_proto: define the java_proto_library target
@@ -304,6 +308,7 @@ def mediapipe_proto_library(
         def_cc_proto = def_cc_proto,
         def_py_proto = def_py_proto,
         def_java_lite_proto = def_java_lite_proto,
+        def_kt_lite_proto = def_kt_lite_proto,
         def_objc_proto = def_objc_proto,
         def_java_proto = def_java_proto,
         def_jspb_proto = def_jspb_proto,
@@ -334,6 +339,7 @@ def mediapipe_proto_library(
             def_cc_proto = def_cc_proto,
             def_py_proto = def_py_proto,
             def_java_lite_proto = def_java_lite_proto,
+            def_kt_lite_proto = def_kt_lite_proto,
             def_objc_proto = def_objc_proto,
             def_java_proto = def_java_proto,
             def_jspb_proto = def_jspb_proto,
@@ -20,6 +20,7 @@
 #include <utility>
 #include <vector>

+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/ascii.h"
 #include "absl/strings/numbers.h"
@@ -1430,10 +1431,10 @@ std::vector<const FieldDescriptor*> GetFields(const Message* src) {

 // Orders map entries in dst to match src.
 void OrderMapEntries(const Message* src, Message* dst,
-                     std::set<const Message*>* seen = nullptr) {
-  std::unique_ptr<std::set<const Message*>> seen_owner;
+                     absl::flat_hash_set<const Message*>* seen = nullptr) {
+  std::unique_ptr<absl::flat_hash_set<const Message*>> seen_owner;
   if (!seen) {
-    seen_owner = std::make_unique<std::set<const Message*>>();
+    seen_owner = std::make_unique<absl::flat_hash_set<const Message*>>();
     seen = seen_owner.get();
   }
   if (seen->count(src) > 0) {
@@ -34,6 +34,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.Nullable;
 import javax.microedition.khronos.egl.EGLConfig;
 import javax.microedition.khronos.opengles.GL10;

@@ -303,7 +304,7 @@ public class GlSurfaceViewRenderer implements GLSurfaceView.Renderer {
   }

   // Use this when the texture is not a SurfaceTexture.
-  public void setNextFrame(TextureFrame frame) {
+  public void setNextFrame(@Nullable TextureFrame frame) {
     if (surfaceTexture != null) {
       Matrix.setIdentityM(textureTransformMatrix, 0 /* offset */);
     }
@@ -50,7 +50,6 @@ android_library(
         "MediaPipeRunner.java",
     ],
     visibility = [
-        "//java/com/google/android/libraries/camera/effects:__subpackages__",
         "//mediapipe/java/com/google/mediapipe:__subpackages__",
     ],
     exports = [
@@ -67,6 +67,7 @@ public class ExternalTextureRenderer {
   private float[] textureTransformMatrix = new float[16];
   private boolean flipY;
   private int rotation = Surface.ROTATION_0;
+  private boolean doExplicitCpuSync = true;

   /** Call this to setup the shader program before rendering. */
   public void setup() {
@@ -101,6 +102,14 @@ public class ExternalTextureRenderer {
     this.rotation = rotation;
   }

+  /**
+   * Configures whether the renderer should do an explicit CPU synchronization using glFinish upon
+   * each {@link #render} call. Defaults to true.
+   */
+  public void setDoExplicitCpuSync(boolean doExplicitCpuSync) {
+    this.doExplicitCpuSync = doExplicitCpuSync;
+  }
+
   /**
    * Renders the surfaceTexture to the framebuffer with optional vertical flip.
    *
@@ -150,8 +159,11 @@ public class ExternalTextureRenderer {
     GLES20.glBindTexture(GLES11Ext.GL_TEXTURE_EXTERNAL_OES, 0);
     ShaderUtil.checkGlError("glBindTexture");

-    // TODO: add sync and go back to glFlush()
-    GLES20.glFinish();
+    if (doExplicitCpuSync) {
+
+      // TODO: add sync and go back to glFlush()
+      GLES20.glFinish();
+    }
   }

   /**
@@ -14,7 +14,10 @@

 # Placeholder for internal Python strict library and test compatibility macro.

-package(default_visibility = ["//mediapipe:__subpackages__"])
+package(default_visibility = [
+    "//cloud/ml/applications/vision/model_garden/model_oss/mediapipe:__subpackages__",
+    "//mediapipe:__subpackages__",
+])

 licenses(["notice"])

@@ -43,7 +43,7 @@ class Classifier(custom_model.CustomModel):
     self._model: tf.keras.Model = None
     self._optimizer: Union[str, tf.keras.optimizers.Optimizer] = None
     self._loss_function: Union[str, tf.keras.losses.Loss] = None
-    self._metric_function: Union[str, tf.keras.metrics.Metric] = None
+    self._metric_functions: Sequence[Union[str, tf.keras.metrics.Metric]] = None
     self._callbacks: Sequence[tf.keras.callbacks.Callback] = None
     self._hparams: hp.BaseHParams = None
     self._history: tf.keras.callbacks.History = None
@@ -92,7 +92,8 @@ class Classifier(custom_model.CustomModel):
     self._model.compile(
         optimizer=self._optimizer,
         loss=self._loss_function,
-        metrics=[self._metric_function])
+        metrics=self._metric_functions,
+    )

     latest_checkpoint = (
         tf.train.latest_checkpoint(checkpoint_path)
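
Aside (not part of this commit): with the metric field now holding a sequence, a Keras model can be compiled with several metrics at once. A minimal sketch, assuming TF 2.x and the metrics module added below:

import tensorflow as tf
from mediapipe.model_maker.python.core.utils import metrics

# A list such as _metric_functions now holds, passed straight to compile():
metric_functions = [
    'accuracy',
    metrics.SparsePrecision(name='precision', dtype=tf.float32),
    metrics.SparseRecall(name='recall', dtype=tf.float32),
]

model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax')])
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=metric_functions,
)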
@@ -80,10 +80,30 @@ py_test(
     deps = [":loss_functions"],
 )

+######################################################################
+# Public target of the MediaPipe Model Maker Quantization Config.
+
+# Quantization Config is used to export a quantized model. Please refer
+# to the specific task documentations such as:
+# https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize
+# for usage information.
+######################################################################
+py_library(
+    name = "metrics",
+    srcs = ["metrics.py"],
+)
+
+py_test(
+    name = "metrics_test",
+    srcs = ["metrics_test.py"],
+    deps = [":metrics"],
+)
+
 py_library(
     name = "quantization",
     srcs = ["quantization.py"],
     srcs_version = "PY3",
+    visibility = ["//visibility:public"],
     deps = ["//mediapipe/model_maker/python/core/data:dataset"],
 )
							
								
								
									
mediapipe/model_maker/python/core/utils/metrics.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Metrics utility library."""
+
+import tensorflow as tf
+
+
+def _get_binary_sparse_metric(metric: tf.metrics.Metric):
+  """Helper method to create a BinarySparse version of a tf.keras.Metric.
+
+  BinarySparse is an implementation where the update_state(y_true, y_pred) takes
+  in shapes y_true=(batch_size, 1) y_pred=(batch_size, 2). Note that this only
+  supports the binary classification case, and that class_id=0 is the negative
+  class and class_id=1 is the positive class.
+
+  Currently supported tf.metric.Metric classes
+    1. BinarySparseRecallAtPrecision
+    2. BinarySparsePrecisionAtRecall
+
+  Args:
+    metric: A tf.metric.Metric class for which we want to generate a
+      BinarySparse version of this metric.
+
+  Returns:
+    A class for the BinarySparse version of the specified tf.metrics.Metric
+  """
+
+  class BinarySparseMetric(metric):
+    """A BinarySparse wrapper class for a tf.keras.Metric.
+
+    This class has the same parameters and functions as the underlying
+    metric class. For example, the parameters for BinarySparseRecallAtPrecision
+    is the same as tf.keras.metrics.RecallAtPrecision. The only new constraint
+    is that class_id must be set to 1 (or not specified) for the Binary metric.
+    """
+
+    def __init__(self, *args, **kwargs):
+      if 'class_id' in kwargs and kwargs['class_id'] != 1:
+        raise ValueError(
+            f'Custom BinarySparseMetric for class:{metric.__name__} is '
+            'only supported for class_id=1, got class_id='
+            f'{kwargs["class_id"]} instead'
+        )
+      else:
+        kwargs['class_id'] = 1
+      super().__init__(*args, **kwargs)
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+      y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
+      y_true_one_hot = tf.one_hot(y_true, 2)
+      return super().update_state(
+          y_true_one_hot, y_pred, sample_weight=sample_weight
+      )
+
+  return BinarySparseMetric
+
+
+def _get_sparse_metric(metric: tf.metrics.Metric):
+  """Helper method to create a Sparse version of a tf.keras.Metric.
+
+  Sparse is an implementation where the update_state(y_true, y_pred) takes in
+  shapes y_true=(batch_size, 1) and y_pred=(batch_size, num_classes).
+
+  Currently supported tf.metrics.Metric classes:
+    1. tf.metrics.Recall
+    2. tf.metrics.Precision
+
+  Args:
+    metric: A tf.metric.Metric class for which we want to generate a Sparse
+      version of this metric.
+
+  Returns:
+    A class for the Sparse version of the specified tf.keras.Metric.
+  """
+
+  class SparseMetric(metric):
+    """A Sparse wrapper class for a tf.keras.Metric."""
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+      y_pred = tf.math.argmax(y_pred, axis=-1)
+      return super().update_state(y_true, y_pred, sample_weight=sample_weight)
+
+  return SparseMetric
+
+
+SparseRecall = _get_sparse_metric(tf.metrics.Recall)
+SparsePrecision = _get_sparse_metric(tf.metrics.Precision)
+BinarySparseRecallAtPrecision = _get_binary_sparse_metric(
+    tf.metrics.RecallAtPrecision
+)
+BinarySparsePrecisionAtRecall = _get_binary_sparse_metric(
+    tf.metrics.PrecisionAtRecall
+)
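
Usage sketch (illustrative, not from the commit): the wrappers accept sparse integer labels of shape (batch_size,) together with per-class scores, arg-maxing the scores before delegating to the wrapped metric:

import tensorflow as tf  # assumes TF 2.x
from mediapipe.model_maker.python.core.utils import metrics

y_true = [0, 1, 1]                             # sparse integer labels
y_pred = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]  # per-class scores

precision = metrics.SparsePrecision()
precision.update_state(y_true, y_pred)  # argmax(y_pred) -> [0, 1, 0]
print(float(precision.result()))  # 1.0: the one predicted positive is correct

recall = metrics.SparseRecall()
recall.update_state(y_true, y_pred)
print(float(recall.result()))  # 0.5: one of the two true positives is found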
							
								
								
									
mediapipe/model_maker/python/core/utils/metrics_test.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from absl.testing import parameterized
+import tensorflow as tf
+
+from mediapipe.model_maker.python.core.utils import metrics
+
+
+class SparseMetricTest(tf.test.TestCase, parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.y_true = [0, 0, 1, 1, 0, 1]
+    self.y_pred = [
+        [0.9, 0.1],  # 0, 0 y
+        [0.8, 0.2],  # 0, 0 y
+        [0.7, 0.3],  # 0, 1 n
+        [0.6, 0.4],  # 0, 1 n
+        [0.3, 0.7],  # 1, 0 y
+        [0.3, 0.7],  # 1, 1 y
+    ]
+    self.num_classes = 3
+
+  def _assert_metric_equals(self, metric, value):
+    metric.update_state(self.y_true, self.y_pred)
+    self.assertEqual(metric.result(), value)
+
+  def test_sparse_recall(self):
+    metric = metrics.SparseRecall()
+    self._assert_metric_equals(metric, 1 / 3)
+
+  def test_sparse_precision(self):
+    metric = metrics.SparsePrecision()
+    self._assert_metric_equals(metric, 1 / 2)
+
+  def test_binary_sparse_recall_at_precision(self):
+    metric = metrics.BinarySparseRecallAtPrecision(1.0)
+    self._assert_metric_equals(metric, 0.0)  # impossible to achieve precision=1
+    metric = metrics.BinarySparseRecallAtPrecision(0.4)
+    self._assert_metric_equals(metric, 1.0)
+
+  def test_binary_sparse_precision_at_recall(self):
+    metric = metrics.BinarySparsePrecisionAtRecall(1.0)
+    self._assert_metric_equals(metric, 3 / 4)
+    metric = metrics.BinarySparsePrecisionAtRecall(0.7)
+    self._assert_metric_equals(metric, 3 / 4)
+
+  def test_binary_sparse_precision_at_recall_class_id_error(self):
+    # class_id=1 case should not error
+    _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=1)
+    # class_id=2 case should error
+    with self.assertRaisesRegex(
+        ValueError,
+        'Custom BinarySparseMetric for class:PrecisionAtRecall is only'
+        ' supported for class_id=1, got class_id=2 instead',
+    ):
+      _ = metrics.BinarySparsePrecisionAtRecall(1.0, class_id=2)
+
+
+if __name__ == '__main__':
+  tf.test.main()
@@ -31,11 +31,11 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         ":dataset",
+        ":hyperparameters",
         ":model_options",
         ":model_spec",
         ":text_classifier",
         ":text_classifier_options",
-        "//mediapipe/model_maker/python/core:hyperparameters",
     ],
 )

@@ -45,12 +45,18 @@ py_library(
     deps = ["//mediapipe/model_maker/python/text/core:bert_model_options"],
 )

+py_library(
+    name = "hyperparameters",
+    srcs = ["hyperparameters.py"],
+    deps = ["//mediapipe/model_maker/python/core:hyperparameters"],
+)
+
 py_library(
     name = "model_spec",
     srcs = ["model_spec.py"],
     deps = [
+        ":hyperparameters",
         ":model_options",
-        "//mediapipe/model_maker/python/core:hyperparameters",
         "//mediapipe/model_maker/python/core/utils:file_util",
         "//mediapipe/model_maker/python/text/core:bert_model_spec",
     ],
@@ -61,9 +67,9 @@ py_test(
     srcs = ["model_spec_test.py"],
     tags = ["requires-net:external"],
     deps = [
+        ":hyperparameters",
         ":model_options",
         ":model_spec",
-        "//mediapipe/model_maker/python/core:hyperparameters",
     ],
 )

@@ -100,9 +106,9 @@ py_library(
     name = "text_classifier_options",
     srcs = ["text_classifier_options.py"],
     deps = [
+        ":hyperparameters",
         ":model_options",
         ":model_spec",
-        "//mediapipe/model_maker/python/core:hyperparameters",
     ],
 )

@@ -111,13 +117,14 @@ py_library(
     srcs = ["text_classifier.py"],
     deps = [
         ":dataset",
+        ":hyperparameters",
         ":model_options",
         ":model_spec",
         ":preprocessor",
         ":text_classifier_options",
-        "//mediapipe/model_maker/python/core:hyperparameters",
         "//mediapipe/model_maker/python/core/data:dataset",
         "//mediapipe/model_maker/python/core/tasks:classifier",
+        "//mediapipe/model_maker/python/core/utils:metrics",
         "//mediapipe/model_maker/python/core/utils:model_util",
         "//mediapipe/model_maker/python/core/utils:quantization",
         "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
@@ -13,19 +13,23 @@
 # limitations under the License.
 """MediaPipe Public Python API for Text Classifier."""

-from mediapipe.model_maker.python.core import hyperparameters
 from mediapipe.model_maker.python.text.text_classifier import dataset
+from mediapipe.model_maker.python.text.text_classifier import hyperparameters
 from mediapipe.model_maker.python.text.text_classifier import model_options
 from mediapipe.model_maker.python.text.text_classifier import model_spec
 from mediapipe.model_maker.python.text.text_classifier import text_classifier
 from mediapipe.model_maker.python.text.text_classifier import text_classifier_options

-HParams = hyperparameters.BaseHParams
+
+AverageWordEmbeddingHParams = hyperparameters.AverageWordEmbeddingHParams
+AverageWordEmbeddingModelOptions = (
+    model_options.AverageWordEmbeddingModelOptions
+)
+BertOptimizer = hyperparameters.BertOptimizer
+BertHParams = hyperparameters.BertHParams
+BertModelOptions = model_options.BertModelOptions
 CSVParams = dataset.CSVParameters
 Dataset = dataset.Dataset
-AverageWordEmbeddingModelOptions = (
-    model_options.AverageWordEmbeddingModelOptions)
-BertModelOptions = model_options.BertModelOptions
 SupportedModels = model_spec.SupportedModels
 TextClassifier = text_classifier.TextClassifier
 TextClassifierOptions = text_classifier_options.TextClassifierOptions
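
A sketch of the reshuffled public surface from a caller's point of view (illustrative only; the alias names come from the block above, and the option fields shown are the ones read by text_classifier.py):

from mediapipe.model_maker.python.text import text_classifier

hparams = text_classifier.BertHParams(epochs=3, batch_size=48)
options = text_classifier.TextClassifierOptions(
    supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
    hparams=hparams,
)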
mediapipe/model_maker/python/text/text_classifier/hyperparameters.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Hyperparameters for training object detection models."""
+
+import dataclasses
+import enum
+from typing import Union
+
+from mediapipe.model_maker.python.core import hyperparameters as hp
+
+
+@dataclasses.dataclass
+class AverageWordEmbeddingHParams(hp.BaseHParams):
+  """The hyperparameters for an AverageWordEmbeddingClassifier."""
+
+
+@enum.unique
+class BertOptimizer(enum.Enum):
+  """Supported Optimizers for Bert Text Classifier."""
+
+  ADAMW = "adamw"
+  LAMB = "lamb"
+
+
+@dataclasses.dataclass
+class BertHParams(hp.BaseHParams):
+  """The hyperparameters for a Bert Classifier.
+
+  Attributes:
+    learning_rate: Learning rate to use for gradient descent training.
+    batch_size: Batch size for training.
+    epochs: Number of training iterations over the dataset.
+    optimizer: Optimizer to use for training. Only supported values are "adamw"
+      and "lamb".
+  """
+
+  learning_rate: float = 3e-5
+  batch_size: int = 48
+  epochs: int = 2
+  optimizer: BertOptimizer = BertOptimizer.ADAMW
+
+
+HParams = Union[BertHParams, AverageWordEmbeddingHParams]
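
Illustrative only: constructing the new task-specific hyperparameters. The field values below are arbitrary; unspecified fields keep the defaults defined in the file above:

from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp

bert_hparams = hp.BertHParams(
    learning_rate=3e-5,
    batch_size=32,
    epochs=3,
    optimizer=hp.BertOptimizer.LAMB,  # or the default, BertOptimizer.ADAMW
)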
@@ -17,13 +17,11 @@ import dataclasses
 import enum
 import functools

-from mediapipe.model_maker.python.core import hyperparameters as hp
 from mediapipe.model_maker.python.core.utils import file_util
 from mediapipe.model_maker.python.text.core import bert_model_spec
+from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.text.text_classifier import model_options as mo

-# BERT-based text classifier spec inherited from BertModelSpec
-BertClassifierSpec = bert_model_spec.BertModelSpec
-
+
 MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
     'text_classifier/mobilebert_tiny',
@@ -31,6 +29,12 @@ MOBILEBERT_TINY_FILES = file_util.DownloadedFiles(
     is_folder=True,
 )

+EXBERT_FILES = file_util.DownloadedFiles(
+    'text_classifier/exbert',
+    'https://storage.googleapis.com/mediapipe-assets/exbert.tar.gz',
+    is_folder=True,
+)
+

 @dataclasses.dataclass
 class AverageWordEmbeddingClassifierSpec:
@@ -43,27 +47,53 @@ class AverageWordEmbeddingClassifierSpec:
   """

   # `learning_rate` is unused for the average word embedding model
-  hparams: hp.BaseHParams = hp.BaseHParams(
-      epochs=10, batch_size=32, learning_rate=0)
+  hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams(
+      epochs=10, batch_size=32, learning_rate=0
+  )
   model_options: mo.AverageWordEmbeddingModelOptions = (
       mo.AverageWordEmbeddingModelOptions())
   name: str = 'AverageWordEmbedding'

-
 average_word_embedding_classifier_spec = functools.partial(
     AverageWordEmbeddingClassifierSpec)


+@dataclasses.dataclass
+class BertClassifierSpec(bert_model_spec.BertModelSpec):
+  """Specification for a Bert classifier model.
+
+  Only overrides the hparams attribute since the rest of the attributes are
+  inherited from the BertModelSpec.
+  """
+
+  hparams: hp.BertHParams = hp.BertHParams()
+
+
 mobilebert_classifier_spec = functools.partial(
     BertClassifierSpec,
     downloaded_files=MOBILEBERT_TINY_FILES,
-    hparams=hp.BaseHParams(
+    hparams=hp.BertHParams(
         epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
     ),
     name='MobileBert',
     tflite_input_name={
         'ids': 'serving_default_input_1:0',
-        'mask': 'serving_default_input_3:0',
         'segment_ids': 'serving_default_input_2:0',
+        'mask': 'serving_default_input_3:0',
+    },
+)
+
+exbert_classifier_spec = functools.partial(
+    BertClassifierSpec,
+    downloaded_files=EXBERT_FILES,
+    hparams=hp.BertHParams(
+        epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off'
+    ),
+    name='ExBert',
+    tflite_input_name={
+        'ids': 'serving_default_input_1:0',
+        'segment_ids': 'serving_default_input_2:0',
+        'mask': 'serving_default_input_3:0',
     },
 )

@@ -73,3 +103,4 @@ class SupportedModels(enum.Enum):
   """Predefined text classifier model specs supported by Model Maker."""
   AVERAGE_WORD_EMBEDDING_CLASSIFIER = average_word_embedding_classifier_spec
   MOBILEBERT_CLASSIFIER = mobilebert_classifier_spec
+  EXBERT_CLASSIFIER = exbert_classifier_spec
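
For reference (not part of the diff): each SupportedModels entry wraps a functools.partial, so the new EXBERT spec is instantiated the same way the existing entries are, e.g. inside create():

from mediapipe.model_maker.python.text.text_classifier import model_spec as ms

spec = ms.SupportedModels.EXBERT_CLASSIFIER.value()  # calls the partial
print(spec.name)            # 'ExBert'
print(spec.hparams.epochs)  # 3, from the BertHParams set in the partial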
@@ -19,7 +19,7 @@ from unittest import mock as unittest_mock

 import tensorflow as tf

-from mediapipe.model_maker.python.core import hyperparameters as hp
+from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.text.text_classifier import model_options as classifier_model_options
 from mediapipe.model_maker.python.text.text_classifier import model_spec as ms

@@ -57,11 +57,13 @@ class ModelSpecTest(tf.test.TestCase):
             seq_len=128, do_fine_tuning=True, dropout_rate=0.1))
     self.assertEqual(
         model_spec_obj.hparams,
-        hp.BaseHParams(
+        hp.BertHParams(
             epochs=3,
             batch_size=48,
             learning_rate=3e-5,
-            distribution_strategy='off'))
+            distribution_strategy='off',
+        ),
+    )

   def test_predefined_average_word_embedding_spec(self):
     model_spec_obj = (
@@ -78,7 +80,7 @@ class ModelSpecTest(tf.test.TestCase):
             dropout_rate=0.2))
     self.assertEqual(
         model_spec_obj.hparams,
-        hp.BaseHParams(
+        hp.AverageWordEmbeddingHParams(
             epochs=10,
             batch_size=32,
             learning_rate=0,
@@ -101,7 +103,7 @@ class ModelSpecTest(tf.test.TestCase):
                      custom_bert_classifier_options)

   def test_custom_average_word_embedding_spec(self):
-    custom_hparams = hp.BaseHParams(
+    custom_hparams = hp.AverageWordEmbeddingHParams(
         learning_rate=0.4,
         batch_size=64,
         epochs=10,
@@ -110,7 +112,8 @@ class ModelSpecTest(tf.test.TestCase):
         export_dir='foo/bar',
         distribution_strategy='mirrored',
         num_gpus=3,
-        tpu='tpu/address')
+        tpu='tpu/address',
+    )
     custom_average_word_embedding_model_options = (
         classifier_model_options.AverageWordEmbeddingModelOptions(
             seq_len=512,
					@ -19,14 +19,16 @@ import tempfile
 | 
				
			||||||
from typing import Any, Optional, Sequence, Tuple
 | 
					from typing import Any, Optional, Sequence, Tuple
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import tensorflow as tf
 | 
					import tensorflow as tf
 | 
				
			||||||
 | 
					from tensorflow_addons import optimizers as tfa_optimizers
 | 
				
			||||||
import tensorflow_hub as hub
 | 
					import tensorflow_hub as hub
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from mediapipe.model_maker.python.core import hyperparameters as hp
 | 
					 | 
				
			||||||
from mediapipe.model_maker.python.core.data import dataset as ds
 | 
					from mediapipe.model_maker.python.core.data import dataset as ds
 | 
				
			||||||
from mediapipe.model_maker.python.core.tasks import classifier
 | 
					from mediapipe.model_maker.python.core.tasks import classifier
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.core.utils import metrics
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import model_util
 | 
					from mediapipe.model_maker.python.core.utils import model_util
 | 
				
			||||||
from mediapipe.model_maker.python.core.utils import quantization
 | 
					from mediapipe.model_maker.python.core.utils import quantization
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
					from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
 | 
				
			||||||
 | 
					from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
					from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 | 
				
			||||||
from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
					from mediapipe.model_maker.python.text.text_classifier import preprocessor
 | 
				
			||||||
| 
						 | 
					@ -54,22 +56,26 @@ def _validate(options: text_classifier_options.TextClassifierOptions):
 | 
				
			||||||
       ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
 | 
					       ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER)):
 | 
				
			||||||
    raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
 | 
					    raise ValueError("Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER,"
 | 
				
			||||||
                     f" got {options.supported_model}")
 | 
					                     f" got {options.supported_model}")
 | 
				
			||||||
  if (isinstance(options.model_options, mo.BertModelOptions) and
 | 
					  if isinstance(options.model_options, mo.BertModelOptions) and (
 | 
				
			||||||
      (options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER)):
 | 
					      options.supported_model != ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
				
			||||||
 | 
					      and options.supported_model != ms.SupportedModels.EXBERT_CLASSIFIER
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
    raise ValueError(
 | 
					    raise ValueError(
 | 
				
			||||||
        f"Expected MOBILEBERT_CLASSIFIER, got {options.supported_model}")
 | 
					        "Expected a Bert Classifier(MobileBERT or EXBERT), got "
 | 
				
			||||||
 | 
					        f"{options.supported_model}"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TextClassifier(classifier.Classifier):
 | 
					class TextClassifier(classifier.Classifier):
 | 
				
			||||||
  """API for creating and training a text classification model."""
 | 
					  """API for creating and training a text classification model."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def __init__(self, model_spec: Any, hparams: hp.BaseHParams,
 | 
					  def __init__(
 | 
				
			||||||
               label_names: Sequence[str]):
 | 
					      self, model_spec: Any, label_names: Sequence[str], shuffle: bool
 | 
				
			||||||
 | 
					  ):
 | 
				
			||||||
    super().__init__(
 | 
					    super().__init__(
 | 
				
			||||||
        model_spec=model_spec, label_names=label_names, shuffle=hparams.shuffle)
 | 
					        model_spec=model_spec, label_names=label_names, shuffle=shuffle
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    self._model_spec = model_spec
 | 
					    self._model_spec = model_spec
 | 
				
			||||||
    self._hparams = hparams
 | 
					 | 
				
			||||||
    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
 | 
					 | 
				
			||||||
    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
 | 
					    self._text_preprocessor: preprocessor.TextClassifierPreprocessor = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
| 
						 | 
					@ -106,7 +112,10 @@ class TextClassifier(classifier.Classifier):
 | 
				
			||||||
    if options.hparams is None:
 | 
					    if options.hparams is None:
 | 
				
			||||||
      options.hparams = options.supported_model.value().hparams
 | 
					      options.hparams = options.supported_model.value().hparams
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER:
 | 
					    if (
 | 
				
			||||||
 | 
					        options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER
 | 
				
			||||||
 | 
					        or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
      text_classifier = (
 | 
					      text_classifier = (
 | 
				
			||||||
          _BertClassifier.create_bert_classifier(train_data, validation_data,
 | 
					          _BertClassifier.create_bert_classifier(train_data, validation_data,
 | 
				
			||||||
                                                 options,
 | 
					                                                 options,
 | 
				
			||||||
| 
						 | 
@@ -123,12 +132,24 @@ class TextClassifier(classifier.Classifier):
 
     return text_classifier
 
-  def evaluate(self, data: ds.Dataset, batch_size: int = 32) -> Any:
+  def evaluate(
+      self,
+      data: ds.Dataset,
+      batch_size: int = 32,
+      desired_precisions: Optional[Sequence[float]] = None,
+      desired_recalls: Optional[Sequence[float]] = None,
+  ) -> Any:
     """Overrides Classifier.evaluate().
 
     Args:
       data: Evaluation dataset. Must be a TextClassifier Dataset.
       batch_size: Number of samples per evaluation step.
+      desired_precisions: If specified, adds a RecallAtPrecision metric per
+        desired_precisions[i] entry which tracks the recall given the constraint
+        on precision. Only supported for binary classification.
+      desired_recalls: If specified, adds a PrecisionAtRecall metric per
+        desired_recalls[i] entry which tracks the precision given the constraint
+        on recall. Only supported for binary classification.
 
     Returns:
       The loss value and accuracy.
@@ -144,6 +165,28 @@ class TextClassifier(classifier.Classifier):
 
     processed_data = self._text_preprocessor.preprocess(data)
     dataset = processed_data.gen_tf_dataset(batch_size, is_training=False)
 
+    additional_metrics = []
+    if desired_precisions and len(data.label_names) == 2:
+      for precision in desired_precisions:
+        additional_metrics.append(
+            metrics.BinarySparseRecallAtPrecision(
+                precision, name=f"recall_at_precision_{precision}"
+            )
+        )
+    if desired_recalls and len(data.label_names) == 2:
+      for recall in desired_recalls:
+        additional_metrics.append(
+            metrics.BinarySparsePrecisionAtRecall(
+                recall, name=f"precision_at_recall_{recall}"
+            )
+        )
+    metric_functions = self._metric_functions + additional_metrics
+    self._model.compile(
+        optimizer=self._optimizer,
+        loss=self._loss_function,
+        metrics=metric_functions,
+    )
     return self._model.evaluate(dataset)
 
   def export_model(
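The extended evaluate() re-compiles the model with any requested threshold metrics before running Keras evaluation. A minimal caller-side sketch (assuming a trained classifier `model` from `TextClassifier.create(...)` and a two-label `validation_data`, both hypothetical names here):

    # Sketch: exercise the new evaluate() arguments added above.
    results = model.evaluate(
        validation_data,
        batch_size=32,
        desired_precisions=[0.9],  # adds a recall_at_precision_0.9 metric
        desired_recalls=[0.8],     # adds a precision_at_recall_0.8 metric
    )
    # Keras returns [loss, metric_1, metric_2, ...] in compile() order.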
@@ -161,9 +204,8 @@ class TextClassifier(classifier.Classifier):
         path is {self._hparams.export_dir}/{model_name}.
       quantization_config: The configuration for model quantization.
     """
-    if not tf.io.gfile.exists(self._hparams.export_dir):
-      tf.io.gfile.makedirs(self._hparams.export_dir)
     tflite_file = os.path.join(self._hparams.export_dir, model_name)
+    tf.io.gfile.makedirs(os.path.dirname(tflite_file))
     metadata_file = os.path.join(self._hparams.export_dir, "metadata.json")
 
     tflite_model = model_util.convert_to_tflite(
@@ -174,7 +216,7 @@ class TextClassifier(classifier.Classifier):
     writer = self._get_metadata_writer(tflite_model, vocab_filepath)
     tflite_model_with_metadata, metadata_json = writer.populate()
     model_util.save_tflite(tflite_model_with_metadata, tflite_file)
-    with open(metadata_file, "w") as f:
+    with tf.io.gfile.GFile(metadata_file, "w") as f:
       f.write(metadata_json)
 
   @abc.abstractmethod
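Replacing the built-in open() with tf.io.gfile.GFile lets `export_dir` point at any filesystem TensorFlow supports, not just local disk. A small sketch of the same pattern (the bucket path is hypothetical):

    import os
    import tensorflow as tf

    export_dir = "gs://some-bucket/text_classifier"  # hypothetical path
    tflite_file = os.path.join(export_dir, "model.tflite")
    tf.io.gfile.makedirs(os.path.dirname(tflite_file))  # handles gs:// too
    with tf.io.gfile.GFile(os.path.join(export_dir, "metadata.json"), "w") as f:
      f.write("{}")  # plain open() would fail on non-local paths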
@@ -191,13 +233,23 @@ class _AverageWordEmbeddingClassifier(TextClassifier):
 
   _DELIM_REGEX_PATTERN = r"[^\w\']+"
 
-  def __init__(self, model_spec: ms.AverageWordEmbeddingClassifierSpec,
-               model_options: mo.AverageWordEmbeddingModelOptions,
-               hparams: hp.BaseHParams, label_names: Sequence[str]):
-    super().__init__(model_spec, hparams, label_names)
+  def __init__(
+      self,
+      model_spec: ms.AverageWordEmbeddingClassifierSpec,
+      model_options: mo.AverageWordEmbeddingModelOptions,
+      hparams: hp.AverageWordEmbeddingHParams,
+      label_names: Sequence[str],
+  ):
+    super().__init__(model_spec, label_names, hparams.shuffle)
     self._model_options = model_options
+    self._hparams = hparams
+    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._loss_function = "sparse_categorical_crossentropy"
-    self._metric_function = "accuracy"
+    self._metric_functions = [
+        "accuracy",
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
     self._text_preprocessor: (
         preprocessor.AverageWordEmbeddingClassifierPreprocessor) = None
 
@@ -306,16 +358,26 @@ class _BertClassifier(TextClassifier):
 
   _INITIALIZER_RANGE = 0.02
 
-  def __init__(self, model_spec: ms.BertClassifierSpec,
-               model_options: mo.BertModelOptions, hparams: hp.BaseHParams,
-               label_names: Sequence[str]):
-    super().__init__(model_spec, hparams, label_names)
+  def __init__(
+      self,
+      model_spec: ms.BertClassifierSpec,
+      model_options: mo.BertModelOptions,
+      hparams: hp.BertHParams,
+      label_names: Sequence[str],
+  ):
+    super().__init__(model_spec, label_names, hparams.shuffle)
+    self._hparams = hparams
+    self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._model_options = model_options
     with self._hparams.get_strategy().scope():
       self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
-      self._metric_function = tf.keras.metrics.SparseCategoricalAccuracy(
-          "test_accuracy", dtype=tf.float32
-      )
+      self._metric_functions = [
+          tf.keras.metrics.SparseCategoricalAccuracy(
+              "test_accuracy", dtype=tf.float32
+          ),
+          metrics.SparsePrecision(name="precision", dtype=tf.float32),
+          metrics.SparseRecall(name="recall", dtype=tf.float32),
+      ]
     self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None
 
   @classmethod
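Building the loss and metric objects inside get_strategy().scope() matters because stateful Keras metrics create variables, and those variables must be created under the active distribution strategy. A standalone sketch of the pattern (illustrative only, not the diff's exact code):

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()  # default no-op strategy
    with strategy.scope():
      # Variables backing these metrics are created in-scope, mirroring
      # how _BertClassifier builds _metric_functions above.
      metric_fns = [
          tf.keras.metrics.SparseCategoricalAccuracy(
              "test_accuracy", dtype=tf.float32
          ),
      ]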
@@ -438,11 +500,26 @@ class _BertClassifier(TextClassifier):
           initial_learning_rate=initial_lr,
           decay_schedule_fn=lr_schedule,
           warmup_steps=warmup_steps)
-    self._optimizer = tf.keras.optimizers.experimental.AdamW(
-        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
-    self._optimizer.exclude_from_weight_decay(
-        var_names=["LayerNorm", "layer_norm", "bias"])
+    if self._hparams.optimizer == hp.BertOptimizer.ADAMW:
+      self._optimizer = tf.keras.optimizers.experimental.AdamW(
+          lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
+      )
+      self._optimizer.exclude_from_weight_decay(
+          var_names=["LayerNorm", "layer_norm", "bias"]
+      )
+    elif self._hparams.optimizer == hp.BertOptimizer.LAMB:
+      self._optimizer = tfa_optimizers.LAMB(
+          lr_schedule,
+          weight_decay_rate=0.01,
+          epsilon=1e-6,
+          exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+          global_clipnorm=1.0,
+      )
+    else:
+      raise ValueError(
+          "BertHParams.optimizer must be set to ADAM or "
+          f"LAMB. Got {self._hparams.optimizer}."
+      )
 
   def _save_vocab(self, vocab_filepath: str):
     tf.io.gfile.copy(
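The new branch implies BertHParams exposes an `optimizer` field of the hp.BertOptimizer enum. A hedged sketch of opting into LAMB (field and enum names are inferred from the code above, not verified against a released API):

    # Assumes BertHParams accepts `optimizer` and BertOptimizer.LAMB exists,
    # as the self._hparams.optimizer branches above suggest.
    hparams = text_classifier.BertHParams(
        epochs=3,
        batch_size=48,
        learning_rate=3e-5,
        optimizer=text_classifier.BertOptimizer.LAMB,  # default path is ADAMW
        export_dir="exported_bert",
    )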
@@ -66,14 +66,16 @@ def run(data_dir,
   quantization_config = None
   if (supported_model ==
       text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER):
-    hparams = text_classifier.HParams(
-        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir)
+    hparams = text_classifier.AverageWordEmbeddingHParams(
+        epochs=10, batch_size=32, learning_rate=0, export_dir=export_dir
+    )
   # Warning: This takes extremely long to run on CPU
   elif (
       supported_model == text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER):
     quantization_config = quantization.QuantizationConfig.for_dynamic()
-    hparams = text_classifier.HParams(
-        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir)
+    hparams = text_classifier.BertHParams(
+        epochs=3, batch_size=48, learning_rate=3e-5, export_dir=export_dir
+    )
 
   # Fine-tunes the model.
   options = text_classifier.TextClassifierOptions(
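Assembled from the hunks above, the demo's average-word-embedding path now reads roughly as follows (dataset loading elided; `train_data`/`validation_data` are assumed to exist):

    hparams = text_classifier.AverageWordEmbeddingHParams(
        epochs=10, batch_size=32, learning_rate=0, export_dir="exported_model"
    )
    options = text_classifier.TextClassifierOptions(
        supported_model=(
            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
        ),
        hparams=hparams,
    )
    model = text_classifier.TextClassifier.create(
        train_data, validation_data, options
    )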
@@ -16,7 +16,7 @@
 import dataclasses
 from typing import Optional
 
-from mediapipe.model_maker.python.core import hyperparameters as hp
+from mediapipe.model_maker.python.text.text_classifier import hyperparameters as hp
 from mediapipe.model_maker.python.text.text_classifier import model_options as mo
 from mediapipe.model_maker.python.text.text_classifier import model_spec as ms
 
@@ -34,5 +34,5 @@ class TextClassifierOptions:
       architecture of the `supported_model`.
   """
   supported_model: ms.SupportedModels
-  hparams: Optional[hp.BaseHParams] = None
+  hparams: Optional[hp.HParams] = None
   model_options: Optional[mo.TextClassifierModelOptions] = None
@@ -66,12 +66,14 @@ class TextClassifierTest(tf.test.TestCase):
 
   def test_create_and_train_average_word_embedding_model(self):
     train_data, validation_data = self._get_data()
-    options = (
-        text_classifier.TextClassifierOptions(
-            supported_model=(text_classifier.SupportedModels
-                             .AVERAGE_WORD_EMBEDDING_CLASSIFIER),
-            hparams=text_classifier.HParams(
-                epochs=1, batch_size=1, learning_rate=0)))
+    options = text_classifier.TextClassifierOptions(
+        supported_model=(
+            text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER
+        ),
+        hparams=text_classifier.AverageWordEmbeddingHParams(
+            epochs=1, batch_size=1, learning_rate=0
+        ),
+    )
     average_word_embedding_classifier = (
         text_classifier.TextClassifier.create(train_data, validation_data,
                                               options))
@@ -103,12 +105,15 @@ class TextClassifierTest(tf.test.TestCase):
     options = text_classifier.TextClassifierOptions(
         supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER,
         model_options=text_classifier.BertModelOptions(
-            do_fine_tuning=False, seq_len=2),
-        hparams=text_classifier.HParams(
+            do_fine_tuning=False, seq_len=2
+        ),
+        hparams=text_classifier.BertHParams(
             epochs=1,
             batch_size=1,
             learning_rate=3e-5,
-            distribution_strategy='off'))
+            distribution_strategy='off',
+        ),
+    )
     bert_classifier = text_classifier.TextClassifier.create(
         train_data, validation_data, options)
 
@@ -20,13 +20,6 @@ licenses(["notice"])
 
 package(default_visibility = ["//mediapipe:__subpackages__"])
 
-filegroup(
-    name = "testdata",
-    srcs = glob([
-        "testdata/**",
-    ]),
-)
-
 py_library(
     name = "constants",
     srcs = ["constants.py"],
@@ -72,18 +65,11 @@ py_library(
     name = "dataset",
     srcs = ["dataset.py"],
     deps = [
+        ":constants",
         "//mediapipe/model_maker/python/core/data:classification_dataset",
-        "//mediapipe/model_maker/python/vision/core:image_utils",
-    ],
-)
-
-py_test(
-    name = "dataset_test",
-    srcs = ["dataset_test.py"],
-    data = [":testdata"],
-    deps = [
-        ":dataset",
-        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/vision:face_aligner",
     ],
 )
 
@@ -41,5 +41,11 @@ FACE_STYLIZER_W_FILES = file_util.DownloadedFiles(
     'https://storage.googleapis.com/mediapipe-assets/face_stylizer_w_avg.npy',
 )
 
+FACE_ALIGNER_TASK_FILES = file_util.DownloadedFiles(
+    'face_stylizer/face_landmarker_v2.task',
+    'https://storage.googleapis.com/mediapipe-assets/face_landmarker_v2.task',
+    is_folder=False,
+)
+
 # Dimension of the input style vector to the decoder
 STYLE_DIM = 512
@@ -13,13 +13,37 @@
 # limitations under the License.
 """Face stylizer dataset library."""
 
+from typing import Sequence
 import logging
 import os
 
 import tensorflow as tf
 
 from mediapipe.model_maker.python.core.data import classification_dataset
-from mediapipe.model_maker.python.vision.core import image_utils
+from mediapipe.model_maker.python.vision.face_stylizer import constants
+from mediapipe.python._framework_bindings import image as image_module
+from mediapipe.tasks.python.core import base_options as base_options_module
+from mediapipe.tasks.python.vision import face_aligner
+
+
+def _preprocess_face_dataset(
+    all_image_paths: Sequence[str],
+) -> Sequence[tf.Tensor]:
+  """Preprocess face image dataset by aligning the face."""
+  path = constants.FACE_ALIGNER_TASK_FILES.get_path()
+  base_options = base_options_module.BaseOptions(model_asset_path=path)
+  options = face_aligner.FaceAlignerOptions(base_options=base_options)
+  aligner = face_aligner.FaceAligner.create_from_options(options)
+
+  preprocessed_images = []
+  for path in all_image_paths:
+    tf.compat.v1.logging.info('Preprocess image %s', path)
+    image = image_module.Image.create_from_file(path)
+    aligned_image = aligner.align(image)
+    aligned_image_tensor = tf.convert_to_tensor(aligned_image.numpy_view())
+    preprocessed_images.append(aligned_image_tensor)
+
+  return preprocessed_images
+
+
 # TODO: Change to a unlabeled dataset if it makes sense.
@@ -58,6 +82,7 @@ class Dataset(classification_dataset.ClassificationDataset):
     ):
       raise ValueError('No images found under given directory')
 
+    image_data = _preprocess_face_dataset(all_image_paths)
     label_names = sorted(
         name
         for name in os.listdir(data_root)
@@ -73,11 +98,7 @@ class Dataset(classification_dataset.ClassificationDataset):
         for path in all_image_paths
     ]
 
-    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
-
-    image_ds = path_ds.map(
-        image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
-    )
+    image_ds = tf.data.Dataset.from_tensor_slices(image_data)
 
     # Load label
     label_ds = tf.data.Dataset.from_tensor_slices(
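The alignment step in _preprocess_face_dataset can be exercised on a single image using only calls that appear in the hunk; a sketch (the image filename is hypothetical):

    from mediapipe.model_maker.python.vision.face_stylizer import constants
    from mediapipe.python._framework_bindings import image as image_module
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.vision import face_aligner

    task_path = constants.FACE_ALIGNER_TASK_FILES.get_path()
    options = face_aligner.FaceAlignerOptions(
        base_options=base_options_module.BaseOptions(model_asset_path=task_path)
    )
    aligner = face_aligner.FaceAligner.create_from_options(options)
    image = image_module.Image.create_from_file("face.jpg")  # hypothetical file
    aligned = aligner.align(image)  # use aligned.numpy_view() for a tensor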
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import tensorflow as tf
 
+from mediapipe.model_maker.python.vision.core import image_utils
 from mediapipe.model_maker.python.vision.face_stylizer import dataset
 from mediapipe.tasks.python.test import test_utils
 
@@ -22,10 +24,10 @@ class DatasetTest(tf.test.TestCase):
 
   def setUp(self):
     super().setUp()
-    self._test_data_dirname = 'input/style'
 
   def test_from_folder(self):
-    input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
+    test_data_dirname = 'input/style'
+    input_data_dir = test_utils.get_test_data_path(test_data_dirname)
     data = dataset.Dataset.from_folder(dirname=input_data_dir)
     self.assertEqual(data.num_classes, 2)
     self.assertEqual(data.label_names, ['cartoon', 'sketch'])
@@ -14,7 +14,7 @@
 """APIs to train face stylization model."""
 
 import os
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import numpy as np
 import tensorflow as tf
@@ -54,7 +54,6 @@ class FaceStylizer(object):
     self._model_spec = model_spec
     self._model_options = model_options
     self._hparams = hparams
-    # TODO: Support face alignment in image preprocessor.
     self._preprocessor = image_preprocessing.Preprocessor(
         input_shape=self._model_spec.input_image_shape,
         num_classes=1,
@@ -128,7 +127,7 @@ class FaceStylizer(object):
   def _train_model(
       self,
       train_data: classification_ds.ClassificationDataset,
-      preprocessor: Optional[Callable[..., bool]] = None,
+      preprocessor: Optional[Callable[..., Any]] = None,
   ):
     """Trains the face stylizer model.
 
@@ -54,7 +54,7 @@ class GestureRecognizer(classifier.Classifier):
     self._model_options = model_options
     self._hparams = hparams
     self._loss_function = loss_functions.FocalLoss(gamma=self._hparams.gamma)
-    self._metric_function = 'categorical_accuracy'
+    self._metric_functions = ['categorical_accuracy']
     self._optimizer = 'adam'
     self._callbacks = self._get_callbacks()
     self._history = None
@@ -59,7 +59,7 @@ class ImageClassifier(classifier.Classifier):
     self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir)
     self._loss_function = tf.keras.losses.CategoricalCrossentropy(
         label_smoothing=self._hparams.label_smoothing)
-    self._metric_function = 'accuracy'
+    self._metric_functions = ['accuracy']
     self._history = None  # Training history returned from `keras_model.fit`.
 
   @classmethod
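The _metric_function to _metric_functions rename across classifiers matches Keras's compile(metrics=...) taking a list, which is what lets evaluate() append extra metrics at evaluation time. A minimal standalone illustration:

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(2, activation="softmax")])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],  # a list, so callers can append more metrics
    )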
@@ -101,14 +101,17 @@ class ObjectDetectorModel(tf.keras.Model):
     )
     return model_config
 
-  def _build_model(self) -> tf.keras.Model:
+  def _build_model(self, omit_l2=False) -> tf.keras.Model:
     """Builds a RetinaNet object detector model."""
     input_specs = tf.keras.layers.InputSpec(
         shape=[None] + self._model_spec.input_image_shape
     )
-    l2_regularizer = tf.keras.regularizers.l2(
-        self._model_options.l2_weight_decay / 2.0
-    )
+    if omit_l2:
+      l2_regularizer = None
+    else:
+      l2_regularizer = tf.keras.regularizers.l2(
+          self._model_options.l2_weight_decay / 2.0
+      )
     model_config = self._get_model_config()
 
     return factory.build_retinanet(input_specs, model_config, l2_regularizer)
@@ -167,7 +170,7 @@ class ObjectDetectorModel(tf.keras.Model):
 
   def convert_to_qat(self) -> None:
     """Converts the model to a QAT RetinaNet model."""
-    model = self._build_model()
+    model = self._build_model(omit_l2=True)
     dummy_input = tf.zeros([1] + self._model_spec.input_image_shape)
     model(dummy_input, training=True)
     model.set_weights(self._model.get_weights())
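convert_to_qat now rebuilds the architecture with omit_l2=True and then copies in the trained weights, presumably so regularization losses aren't re-attached to the rebuilt graph. The flag reduces to this choice (a sketch of the branch above):

    import tensorflow as tf

    def make_regularizer(l2_weight_decay, omit_l2=False):
      # Mirrors the omit_l2 branch in _build_model: the QAT conversion path
      # builds the model without an L2 kernel regularizer.
      if omit_l2:
        return None
      return tf.keras.regularizers.l2(l2_weight_decay / 2.0)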
@@ -43,6 +43,7 @@ cc_library(
         ":base_audio_task_api",
         "//mediapipe/calculators/core:flow_limiter_calculator",
         "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/tasks/cc/core:task_api_factory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
+#include "mediapipe/tasks/cc/core/task_api_factory.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
 
 namespace mediapipe {
@@ -60,13 +61,8 @@ class AudioTaskApiFactory {
             "Task graph config should only contain one task subgraph node.",
             MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
       } else {
-        if (!node.options().HasExtension(Options::ext)) {
-          return CreateStatusWithPayload(
-              absl::StatusCode::kInvalidArgument,
-              absl::StrCat(node.calculator(),
-                           " is missing the required task options field."),
-              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
-        }
+        MP_RETURN_IF_ERROR(
+            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
         found_task_subgraph = true;
       }
     }
@@ -29,6 +29,7 @@ cc_library(
         "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/memory",
         "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
@@ -17,15 +17,56 @@ limitations under the License.
 
 #include <memory>
 #include <string>
+#include <variant>
 
+#include "absl/log/log.h"
 #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
+#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 
 namespace mediapipe {
 namespace tasks {
 namespace core {
 
+proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
+    const BaseOptions::CpuOptions& options) {
+  proto::Acceleration acceleration_proto = proto::Acceleration();
+  acceleration_proto.mutable_tflite();
+  return acceleration_proto;
+}
+
+proto::Acceleration ConvertDelegateOptionsToAccelerationProto(
+    const BaseOptions::GpuOptions& options) {
+  proto::Acceleration acceleration_proto = proto::Acceleration();
+  auto* gpu = acceleration_proto.mutable_gpu();
+  gpu->set_use_advanced_gpu_api(true);
+  gpu->set_cached_kernel_path(options.cached_kernel_path);
+  gpu->set_serialized_model_dir(options.serialized_model_dir);
+  gpu->set_model_token(options.model_token);
+  return acceleration_proto;
+}
+
+template <typename T>
+void SetDelegateOptionsOrDie(const BaseOptions* base_options,
+                             proto::BaseOptions& base_options_proto) {
+  if (base_options->delegate_options.has_value()) {
+    if (!std::holds_alternative<T>(*base_options->delegate_options)) {
+      LOG(FATAL) << "Specified Delegate type does not match the provided "
+                    "delegate options.";
+    } else {
+      std::visit(
+          [&base_options_proto](const auto& delegate_options) {
+            proto::Acceleration acceleration_proto =
+                ConvertDelegateOptionsToAccelerationProto(delegate_options);
+            base_options_proto.mutable_acceleration()->Swap(
+                &acceleration_proto);
+          },
+          *base_options->delegate_options);
+    }
+  }
+}
+
 proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
   proto::BaseOptions base_options_proto;
   if (!base_options->model_asset_path.empty()) {
@@ -53,11 +94,15 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
   switch (base_options->delegate) {
     case BaseOptions::Delegate::CPU:
       base_options_proto.mutable_acceleration()->mutable_tflite();
+      SetDelegateOptionsOrDie<BaseOptions::CpuOptions>(base_options,
+                                                       base_options_proto);
       break;
     case BaseOptions::Delegate::GPU:
       base_options_proto.mutable_acceleration()
           ->mutable_gpu()
           ->set_use_advanced_gpu_api(true);
+      SetDelegateOptionsOrDie<BaseOptions::GpuOptions>(base_options,
+                                                       base_options_proto);
       break;
     case BaseOptions::Delegate::EDGETPU_NNAPI:
       base_options_proto.mutable_acceleration()
@@ -65,7 +110,6 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) {
           ->set_accelerator_name("google-edgetpu");
       break;
   }
-
   return base_options_proto;
 }
 }  // namespace core
@@ -17,7 +17,9 @@ limitations under the License.
 #define MEDIAPIPE_TASKS_CC_CORE_BASE_OPTIONS_H_
 
 #include <memory>
+#include <optional>
 #include <string>
+#include <variant>
 
 #include "absl/memory/memory.h"
 #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
@@ -38,7 +40,8 @@ struct BaseOptions {
   std::string model_asset_path = "";
 
   // The delegate to run MediaPipe. If the delegate is not set, the default
-  // delegate CPU is used.
+  // delegate CPU is used. Use `delegate_options` to configure advanced
+  // features of the selected delegate."
   enum Delegate {
     CPU = 0,
     GPU = 1,
@@ -48,6 +51,30 @@ struct BaseOptions {
 
   Delegate delegate = CPU;
 
+  // Options for CPU.
+  struct CpuOptions {};
+
+  // Options for GPU.
+  struct GpuOptions {
+    // Load pre-compiled serialized binary cache to accelerate init process.
+    // Only available on Android. Kernel caching will only be enabled if this
+    // path is set. NOTE: binary cache usage may be skipped if valid serialized
+    // model, specified by "serialized_model_dir", exists.
+    std::string cached_kernel_path;
+
+    // A dir to load from and save to a pre-compiled serialized model used to
+    // accelerate init process.
+    // NOTE: serialized model takes precedence over binary cache
+    // specified by "cached_kernel_path", which still can be used if
+    // serialized model is invalid or missing.
+    std::string serialized_model_dir;
+
+    // Unique token identifying the model. Used in conjunction with
+    // "serialized_model_dir". It is the caller's responsibility to ensure
+    // there is no clash of the tokens.
+    std::string model_token;
+  };
+
   // The file descriptor to a file opened with open(2), with optional additional
   // offset and length information.
   struct FileDescriptorMeta {
@@ -67,6 +94,10 @@ struct BaseOptions {
   // built-in Ops.
   std::unique_ptr<tflite::OpResolver> op_resolver =
       absl::make_unique<MediaPipeBuiltinOpResolver>();
+
+  // Options for the chosen delegate. If not set, the default delegate options
+  // is used.
+  std::optional<std::variant<CpuOptions, GpuOptions>> delegate_options;
 };
 
 // Converts a BaseOptions to a BaseOptionsProto.
@@ -1,6 +1,9 @@
 #include "mediapipe/tasks/cc/core/base_options.h"
 
+#include <memory>
+#include <optional>
 #include <string>
+#include <variant>
 
 #include "mediapipe/calculators/tensor/inference_calculator.pb.h"
 #include "mediapipe/framework/port/gmock.h"
@@ -11,6 +14,8 @@
 
 constexpr char kTestModelBundlePath[] =
     "mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task";
+constexpr char kCachedModelDir[] = "/data/local/tmp";
+constexpr char kModelToken[] = "dummy_model_token";
 
 namespace mediapipe {
 namespace tasks {
@@ -40,6 +45,44 @@ TEST(BaseOptionsTest, ConvertBaseOptionsToProtoWithAcceleration) {
   EXPECT_EQ(proto.acceleration().nnapi().accelerator_name(), "google-edgetpu");
 }
 
+TEST(DelegateOptionsTest, SucceedCpuOptions) {
+  BaseOptions base_options;
+  base_options.delegate = BaseOptions::Delegate::CPU;
+  BaseOptions::CpuOptions cpu_options;
+  base_options.delegate_options = cpu_options;
+  proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
+  EXPECT_TRUE(proto.acceleration().has_tflite());
+  ASSERT_FALSE(proto.acceleration().has_gpu());
+}
+
+TEST(DelegateOptionsTest, SucceedGpuOptions) {
+  BaseOptions base_options;
+  base_options.delegate = BaseOptions::Delegate::GPU;
+  BaseOptions::GpuOptions gpu_options;
+  gpu_options.cached_kernel_path = kCachedModelDir;
+  gpu_options.model_token = kModelToken;
+  base_options.delegate_options = gpu_options;
+  proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options);
+  ASSERT_TRUE(proto.acceleration().has_gpu());
+  ASSERT_FALSE(proto.acceleration().has_tflite());
+  EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api());
+  EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir);
+  EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken);
+}
+
+TEST(DelegateOptionsDeathTest, FailWrongDelegateOptionsType) {
+  BaseOptions base_options;
+  base_options.delegate = BaseOptions::Delegate::CPU;
+  BaseOptions::GpuOptions gpu_options;
+  gpu_options.cached_kernel_path = kCachedModelDir;
+  gpu_options.model_token = kModelToken;
+  base_options.delegate_options = gpu_options;
+  ASSERT_DEATH(
+      { proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); },
+      "Specified Delegate type does not match the provided "
+      "delegate options.");
+}
+
 }  // namespace
 }  // namespace core
 }  // namespace tasks
@@ -81,7 +81,6 @@ class TaskApiFactory {
     return std::make_unique<T>(std::move(runner));
   }
 
- private:
   template <typename Options>
   static absl::Status CheckHasValidOptions(
       const CalculatorGraphConfig::Node& node) {
@@ -86,10 +86,9 @@ cc_test(
         "//mediapipe/tasks/cc/components/containers:classification_result",
         "@com_google_absl//absl/flags:flag",
         "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:cord",
-        "@com_google_sentencepiece//src:sentencepiece_processor",
+        "@com_google_sentencepiece//src:sentencepiece_processor",  # fixdeps: keep
         "@org_tensorflow//tensorflow/lite:test_util",
     ],
 )
@@ -15,8 +15,6 @@ limitations under the License.
 
 #include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
 
-#include <cmath>
-#include <cstdlib>
 #include <memory>
 #include <sstream>
 #include <string>
@@ -24,7 +22,6 @@ limitations under the License.
 
 #include "absl/flags/flag.h"
 #include "absl/status/status.h"
-#include "absl/status/statusor.h"
 #include "absl/strings/cord.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
@@ -45,7 +45,7 @@ constexpr char kUniversalSentenceEncoderModel[] =
 // Tolerance for embedding vector coordinate values.
 constexpr float kEpsilon = 1e-4;
 // Tolerancy for cosine similarity evaluation.
-constexpr double kSimilarityTolerancy = 1e-6;
+constexpr double kSimilarityTolerancy = 2e-2;
 
 using ::mediapipe::file::JoinPath;
 using ::testing::HasSubstr;
@@ -79,6 +79,8 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
   ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 512);
 #ifdef _WIN32
   ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.2148f, kEpsilon);
+#elif defined(__FMA__)
+  ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 21.3605f, kEpsilon);
 #else
   ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 19.9016f, kEpsilon);
 #endif  // _WIN32
@@ -87,7 +89,11 @@ TEST_F(EmbedderTest, SucceedsWithMobileBert) {
       auto result1, text_embedder->Embed("what a great and fantastic trip"));
   ASSERT_EQ(result1.embeddings.size(), 1);
   ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 512);
+#ifdef __FMA__
+  ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 21.254150f, kEpsilon);
+#else
   ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 22.626251f, kEpsilon);
+#endif
 
   // Check cosine similarity.
   MP_ASSERT_OK_AND_ASSIGN(
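Aside on the widened kSimilarityTolerancy (1e-6 to 2e-2) and the new __FMA__ branches above: a fused multiply-add computes a*b + c with a single rounding, so builds that emit FMA instructions accumulate dot products slightly differently from builds that round the product first, and a 512-dimensional embedding can drift well past 1e-6. A tiny self-contained illustration of the effect (values chosen to expose the single- vs. double-rounding difference):

#include <cmath>
#include <cstdio>

int main() {
  // Exact product a*b = 1 - 1e-16, which is not representable in double.
  double a = 1.0 + 1e-8, b = 1.0 - 1e-8, c = -1.0;
  double separate = a * b + c;       // product rounded, then the add
  double fused = std::fma(a, b, c);  // one rounding for the whole expression
  // Prints two different small negatives on conforming implementations,
  // e.g. -1.1102230246251565e-16 vs. -1e-16.
  std::printf("separate=%.17g fused=%.17g\n", separate, fused);
}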
@@ -43,6 +43,7 @@ cc_library(
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc/components/containers:rect",
         "//mediapipe/tasks/cc/core:base_task_api",
+        "//mediapipe/tasks/cc/core:task_api_factory",
         "//mediapipe/tasks/cc/core:task_runner",
         "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "@com_google_absl//absl/status",
@@ -58,6 +59,7 @@ cc_library(
         ":base_vision_task_api",
         "//mediapipe/calculators/core:flow_limiter_calculator",
         "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/tasks/cc/core:task_api_factory",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
@@ -26,6 +26,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
 #include "mediapipe/framework/calculator.pb.h"
+#include "mediapipe/tasks/cc/core/task_api_factory.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "tensorflow/lite/core/api/op_resolver.h"
 
@@ -60,13 +61,8 @@ class VisionTaskApiFactory {
             "Task graph config should only contain one task subgraph node.",
             MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
       } else {
-        if (!node.options().HasExtension(Options::ext)) {
-          return CreateStatusWithPayload(
-              absl::StatusCode::kInvalidArgument,
-              absl::StrCat(node.calculator(),
-                           " is missing the required task options field."),
-              MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
-        }
+        MP_RETURN_IF_ERROR(
+            tasks::core::TaskApiFactory::CheckHasValidOptions<Options>(node));
         found_task_subgraph = true;
       }
     }
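The block removed above duplicated validation that already lives in tasks::core::TaskApiFactory; together with the earlier hunk that drops `private:` in task_api_factory.h, the vision factory can now call the shared helper instead. Its shape, assembled from the inline code this hunk deletes (the actual declaration in task_api_factory.h may differ in detail):

// Sketch of the shared check, mirroring the deleted inline logic.
template <typename Options>
static absl::Status CheckHasValidOptions(
    const CalculatorGraphConfig::Node& node) {
  if (node.options().HasExtension(Options::ext)) {
    return absl::OkStatus();
  }
  return CreateStatusWithPayload(
      absl::StatusCode::kInvalidArgument,
      absl::StrCat(node.calculator(),
                   " is missing the required task options field."),
      MediaPipeTasksStatus::kInvalidTaskGraphConfigError);
}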
@@ -153,6 +153,8 @@ cc_library(
     alwayslink = 1,
 )
 
+# TODO: open source hand joints graph
+
 cc_library(
     name = "hand_landmarker_result",
     srcs = ["hand_landmarker_result.cc"],
@@ -41,3 +41,5 @@ mediapipe_proto_library(
         "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto",
     ],
 )
+
+# TODO: open source hand joints graph
@@ -52,6 +52,7 @@ cc_library(
     name = "interactive_segmenter_graph",
     srcs = ["interactive_segmenter_graph.cc"],
    deps = [
+        "//mediapipe/calculators/image:image_transformation_calculator",
         "//mediapipe/calculators/image:set_alpha_calculator",
         "//mediapipe/calculators/util:annotation_overlay_calculator",
         "//mediapipe/calculators/util:flat_color_image_calculator",
@@ -60,6 +61,7 @@ cc_library(
         "//mediapipe/calculators/util:to_image_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:node",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
 #include <vector>
 
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
 #include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image.h"
@@ -35,6 +37,51 @@ namespace mediapipe {
 namespace tasks {
 namespace vision {
 namespace interactive_segmenter {
+namespace internal {
+
+// A calculator to add thickness to the render data according to the image size,
+// so that the render data is scale invariant to the image size. If the render
+// data already has thickness, it will be kept as is.
+class AddThicknessToRenderDataCalculator : public api2::Node {
+ public:
+  static constexpr api2::Input<Image> kImageIn{"IMAGE"};
+  static constexpr api2::Input<mediapipe::RenderData> kRenderDataIn{
+      "RENDER_DATA"};
+  static constexpr api2::Output<mediapipe::RenderData> kRenderDataOut{
+      "RENDER_DATA"};
+
+  static constexpr int kModelInputTensorWidth = 512;
+  static constexpr int kModelInputTensorHeight = 512;
+
+  MEDIAPIPE_NODE_CONTRACT(kImageIn, kRenderDataIn, kRenderDataOut);
+
+  absl::Status Process(CalculatorContext* cc) final {
+    mediapipe::RenderData render_data = kRenderDataIn(cc).Get();
+    Image image = kImageIn(cc).Get();
+    double thickness = std::max(
+        std::max(image.width() / static_cast<double>(kModelInputTensorWidth),
+                 image.height() / static_cast<double>(kModelInputTensorHeight)),
+        1.0);
+
+    for (auto& annotation : *render_data.mutable_render_annotations()) {
+      if (!annotation.has_thickness()) {
+        annotation.set_thickness(thickness);
+      }
+    }
+    kRenderDataOut(cc).Send(render_data);
+    return absl::OkStatus();
+  }
+};
+
+// NOLINTBEGIN: Node registration doesn't work when part of calculator name is
+// moved to next line.
+// clang-format off
+MEDIAPIPE_REGISTER_NODE(
+    ::mediapipe::tasks::vision::interactive_segmenter::internal::AddThicknessToRenderDataCalculator);
+// clang-format on
+// NOLINTEND
+
+}  // namespace internal
+
 namespace {
 
@@ -59,6 +106,7 @@ constexpr absl::string_view kAlphaGpuTag{"ALPHA_GPU"};
 constexpr absl::string_view kNormRectTag{"NORM_RECT"};
 constexpr absl::string_view kRoiTag{"ROI"};
 constexpr absl::string_view kQualityScoresTag{"QUALITY_SCORES"};
+constexpr absl::string_view kRenderDataTag{"RENDER_DATA"};
 
 // Updates the graph to return `roi` stream which has same dimension as
 // `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is
@@ -69,14 +117,23 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
   const absl::string_view image_tag_with_suffix =
       use_gpu ? kImageGpuTag : kImageCpuTag;
 
+  // Adds thickness to the render data so that the render data is scale
+  // invariant to the input image size.
+  auto& add_thickness = graph.AddNode(
+      "mediapipe::tasks::vision::interactive_segmenter::internal::"
+      "AddThicknessToRenderDataCalculator");
+  image >> add_thickness.In(kImageTag);
+  roi >> add_thickness.In(kRenderDataTag);
+  auto roi_with_thickness = add_thickness.Out(kRenderDataTag);
+
   // Generates a blank canvas with same size as input image.
   auto& flat_color = graph.AddNode("FlatColorImageCalculator");
   auto& flat_color_options =
       flat_color.GetOptions<FlatColorImageCalculatorOptions>();
   // SetAlphaCalculator only takes 1st channel.
   flat_color_options.mutable_color()->set_r(0);
-  image >> flat_color.In(kImageTag)[0];
-  auto blank_canvas = flat_color.Out(kImageTag)[0];
+  image >> flat_color.In(kImageTag);
+  auto blank_canvas = flat_color.Out(kImageTag);
 
   auto& from_mp_image = graph.AddNode("FromImageCalculator");
   blank_canvas >> from_mp_image.In(kImageTag);
@@ -85,7 +142,7 @@ Source<> RoiToAlpha(Source<Image> image, Source<RenderData> roi, bool use_gpu,
   auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator");
   blank_canvas_in_cpu_or_gpu >>
       roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag);
-  roi >> roi_to_alpha.In(0);
+  roi_with_thickness >> roi_to_alpha.In(0);
   auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag);
 
   return alpha;
@@ -163,6 +220,7 @@ class InteractiveSegmenterGraph : public core::ModelTaskGraph {
     image >> from_mp_image.In(kImageTag);
     auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix);
 
+    // Creates an RGBA image with model input tensor size.
     auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph);
 
     auto& set_alpha = graph.AddNode("SetAlphaCalculator");
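A quick worked example of the thickness rule introduced above: the calculator scales the default annotation thickness by how far the input image exceeds the 512x512 model input tensor, clamped below at 1.0, so an ROI drawn on a large photo still covers roughly the same fraction of the downscaled tensor as one drawn on a small photo. Self-contained check of the arithmetic:

#include <algorithm>
#include <cstdio>

// Same formula as AddThicknessToRenderDataCalculator::Process() above.
double ScaledThickness(int width, int height) {
  constexpr double kTensorW = 512.0, kTensorH = 512.0;
  return std::max(std::max(width / kTensorW, height / kTensorH), 1.0);
}

int main() {
  std::printf("1920x1080 -> %.2f\n", ScaledThickness(1920, 1080));  // 3.75
  std::printf(" 256x 256 -> %.2f\n", ScaledThickness(256, 256));    // 1.00 (clamped)
}

This is also why the test changes below add the penguins_small / penguins_large image pair: the same keypoint on the same scene at two resolutions should now produce comparable masks.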
@@ -34,6 +34,7 @@ limitations under the License.
 #include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 #include "mediapipe/tasks/cc/components/containers/keypoint.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@@ -70,6 +71,10 @@ constexpr absl::string_view kCatsAndDogsJpg{"cats_and_dogs.jpg"};
 // Golden mask for the dogs in cats_and_dogs.jpg.
 constexpr absl::string_view kCatsAndDogsMaskDog1{"cats_and_dogs_mask_dog1.png"};
 constexpr absl::string_view kCatsAndDogsMaskDog2{"cats_and_dogs_mask_dog2.png"};
+constexpr absl::string_view kPenguinsLarge{"penguins_large.jpg"};
+constexpr absl::string_view kPenguinsSmall{"penguins_small.jpg"};
+constexpr absl::string_view kPenguinsSmallMask{"penguins_small_mask.png"};
+constexpr absl::string_view kPenguinsLargeMask{"penguins_large_mask.png"};
 
 constexpr float kGoldenMaskSimilarity = 0.97;
 
@@ -183,6 +188,7 @@ struct InteractiveSegmenterTestParams {
   std::string test_name;
   RegionOfInterest::Format format;
   std::variant<NormalizedKeypoint, std::vector<NormalizedKeypoint>> roi;
+  absl::string_view input_image_file;
   absl::string_view golden_mask_file;
   float similarity_threshold;
 };
@@ -220,8 +226,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
   const InteractiveSegmenterTestParams& params = GetParam();
 
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                params.input_image_file)));
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -244,6 +250,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) {
   EXPECT_THAT(actual_mask,
               SimilarToUint8Mask(expected_mask, params.similarity_threshold,
                                  kGoldenMaskMagnificationFactor));
+
+  cv::Mat visualized_mask;
+  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
+  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
+                              visualized_mask.cols, visualized_mask.rows,
+                              visualized_mask.step, visualized_mask.data,
+                              [visualized_mask](uint8_t[]) {});
+  MP_EXPECT_OK(SavePngTestOutput(
+      visualized_image, absl::StrFormat("%s_category_mask", params.test_name)));
 }
 
 TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
@@ -252,8 +267,8 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
   const InteractiveSegmenterTestParams& params = GetParam();
 
   MP_ASSERT_OK_AND_ASSIGN(
-      Image image,
-      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg)));
+      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                params.input_image_file)));
   auto options = std::make_unique<InteractiveSegmenterOptions>();
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kPtmModel);
@@ -275,6 +290,15 @@ TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) {
       result.confidence_masks->at(1).GetImageFrameSharedPtr().get());
   EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float,
                                               params.similarity_threshold));
+  cv::Mat visualized_mask;
+  actual_mask.convertTo(visualized_mask, CV_8UC1, /*alpha=*/255);
+  ImageFrame visualized_image(mediapipe::ImageFormat::GRAY8,
+                              visualized_mask.cols, visualized_mask.rows,
+                              visualized_mask.step, visualized_mask.data,
+                              [visualized_mask](uint8_t[]) {});
+  MP_EXPECT_OK(SavePngTestOutput(
+      visualized_image,
+      absl::StrFormat("%s_confidence_mask", params.test_name)));
 }
 
 INSTANTIATE_TEST_SUITE_P(
@@ -282,21 +306,28 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::ValuesIn<InteractiveSegmenterTestParams>(
         {// Keypoint input.
          {"PointToDog1", RegionOfInterest::Format::kKeyPoint,
-          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f},
+          NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsJpg, kCatsAndDogsMaskDog1,
+          0.84f},
          {"PointToDog2", RegionOfInterest::Format::kKeyPoint,
-          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2,
+          NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsJpg, kCatsAndDogsMaskDog2,
           kGoldenMaskSimilarity},
+         {"PenguinsSmall", RegionOfInterest::Format::kKeyPoint,
+          NormalizedKeypoint{0.329, 0.545}, kPenguinsSmall, kPenguinsSmallMask,
+          0.9f},
+         {"PenguinsLarge", RegionOfInterest::Format::kKeyPoint,
+          NormalizedKeypoint{0.329, 0.545}, kPenguinsLarge, kPenguinsLargeMask,
+          0.9f},
         // Scribble input.
          {"ScribbleToDog1", RegionOfInterest::Format::kScribble,
           std::vector{NormalizedKeypoint{0.44, 0.70},
                       NormalizedKeypoint{0.44, 0.71},
                       NormalizedKeypoint{0.44, 0.72}},
-          kCatsAndDogsMaskDog1, 0.84f},
+          kCatsAndDogsJpg, kCatsAndDogsMaskDog1, 0.84f},
         {"ScribbleToDog2", RegionOfInterest::Format::kScribble,
           std::vector{NormalizedKeypoint{0.66, 0.66},
                       NormalizedKeypoint{0.66, 0.67},
                       NormalizedKeypoint{0.66, 0.68}},
-          kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
+          kCatsAndDogsJpg, kCatsAndDogsMaskDog2, kGoldenMaskSimilarity}}),
     [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>&
            info) { return info.param.test_name; });
@@ -60,3 +60,9 @@ objc_library(
     srcs = ["sources/MPPLandmark.m"],
     hdrs = ["sources/MPPLandmark.h"],
 )
+
+objc_library(
+    name = "MPPConnection",
+    srcs = ["sources/MPPConnection.m"],
+    hdrs = ["sources/MPPConnection.h"],
+)
mediapipe/tasks/ios/components/containers/sources/MPPConnection.h (new file)
@@ -0,0 +1,44 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/** The value class representing a landmark connection. */
+NS_SWIFT_NAME(Connection)
+@interface MPPConnection : NSObject
+
+@property(nonatomic, readonly) NSUInteger start;
+
+@property(nonatomic, readonly) NSUInteger end;
+
+/**
+ * Initializes a new `MPPConnection` with the start and end landmarks integer constants.
+ *
+ * @param start The integer representing the starting landmark of the connection.
+ * @param end The integer representing the ending landmark of the connection.
+ *
+ * @return An instance of `MPPConnection` initialized with the given start and end landmarks integer
+ * constants.
+ */
+- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end NS_DESIGNATED_INITIALIZER;
+
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
mediapipe/tasks/ios/components/containers/sources/MPPConnection.m (new file)
@@ -0,0 +1,28 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
+
+@implementation MPPConnection
+
+- (instancetype)initWithStart:(NSUInteger)start end:(NSUInteger)end {
+  self = [super init];
+  if (self) {
+    _start = start;
+    _end = end;
+  }
+  return self;
+}
+
+@end
@@ -54,3 +54,20 @@ ios_unit_test(
         ":MPPImageObjcTestLibrary",
     ],
 )
+
+objc_library(
+    name = "MPPMaskObjcTestLibrary",
+    testonly = 1,
+    srcs = ["MPPMaskTests.m"],
+    deps = ["//mediapipe/tasks/ios/vision/core:MPPMask"],
+)
+
+ios_unit_test(
+    name = "MPPMaskObjcTest",
+    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
+    runner = tflite_ios_lab_runner("IOS_LATEST"),
+    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+    deps = [
+        ":MPPMaskObjcTestLibrary",
+    ],
+)

mediapipe/tasks/ios/test/vision/core/MPPMaskTests.m (new file, 127 lines)
@@ -0,0 +1,127 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
+
+#import <XCTest/XCTest.h>
+
+/** Unit tests for `MPPMask`. */
+@interface MPPMaskTests : XCTestCase
+
+@end
+
+@implementation MPPMaskTests
+
+#pragma mark - Tests
+
+- (void)testInitWithUInt8ArrayNoCopySucceeds {
+
+  NSInteger width = 2;
+  NSInteger height = 3;
+
+  UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
+  float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};
+
+  MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:NO];
+
+  XCTAssertEqual(mask.width, width);
+  XCTAssertEqual(mask.height, height);
+
+  // Test if UInt8 mask is not copied.
+  XCTAssertEqual(mask.uint8Data, (const UInt8*)uint8Data);
+  XCTAssertNotEqual(mask.float32Data, NULL);
+
+  for (int i = 0 ; i < width * height ; i ++) {
+    XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f, @"index i = %d", i);
+  }
+
+  // Test if repeated Float32 mask accesses return the same array in memory.
+  XCTAssertEqual(mask.float32Data, mask.float32Data);
+}
+
+- (void)testInitWithUInt8ArrayCopySucceeds {
+
+  NSInteger width = 2;
+  NSInteger height = 3;
+
+  UInt8 uint8Data[] = {128, 128, 128, 128, 128, 128};
+  float float32Data[] = {0.501f, 0.501f, 0.501f, 0.501f, 0.501f, 0.501f};
+
+  MPPMask *mask = [[MPPMask alloc] initWithUInt8Data:uint8Data width:width height:height shouldCopy:YES];
+
+  XCTAssertEqual(mask.width, width);
+  XCTAssertEqual(mask.height, height);
+
+  // Test if UInt8 mask is copied.
+  XCTAssertNotEqual(mask.uint8Data, (const UInt8*)uint8Data);
+  XCTAssertNotEqual(mask.float32Data, NULL);
+
+  for (int i = 0 ; i < width * height ; i ++) {
+    XCTAssertEqualWithAccuracy(mask.float32Data[i], float32Data[i], 1e-3f);
+  }
+
+  // Test if repeated Float32 mask accesses return the same array in memory.
+  XCTAssertEqual(mask.float32Data, mask.float32Data);
+}
+
+- (void)testInitWithFloat32ArrayNoCopySucceeds {
+
+  NSInteger width = 2;
+  NSInteger height = 3;
+
+  UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
+  float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
+  MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:NO];
+
+  XCTAssertEqual(mask.width, width);
+  XCTAssertEqual(mask.height, height);
+
+  // Test if Float32 mask is not copied.
+  XCTAssertEqual(mask.float32Data, (const float*)float32Data);
+  XCTAssertNotEqual(mask.uint8Data, NULL);
+
+  for (int i = 0 ; i < width * height ; i ++) {
+    XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
+  }
+
+  // Test if repeated UInt8 mask accesses return the same array in memory.
+  XCTAssertEqual(mask.uint8Data, mask.uint8Data);
+}
+
+- (void)testInitWithFloat32ArrayCopySucceeds {
+
+  NSInteger width = 2;
+  NSInteger height = 3;
+
+  UInt8 uint8Data[] = {132, 132, 132, 132, 132, 132};
+  float float32Data[] = {0.52f, 0.52f, 0.52f, 0.52f, 0.52f, 0.52f};
+
+  MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:float32Data width:width height:height shouldCopy:YES];
+
+  XCTAssertEqual(mask.width, width);
+  XCTAssertEqual(mask.height, height);
+
+  // Test if Float32 mask is copied.
+  XCTAssertNotEqual(mask.float32Data, (const float*)float32Data);
+  XCTAssertNotEqual(mask.uint8Data, NULL);
+
+  for (int i = 0 ; i < width * height ; i ++) {
+    XCTAssertEqual(mask.uint8Data[i], uint8Data[i]);
+  }
+
+  // Test if repeated UInt8 mask accesses return the same array in memory.
+  XCTAssertEqual(mask.uint8Data, mask.uint8Data);
+}
+
+@end
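The expected values in the MPPMask tests above follow from plain 8-bit to float scaling: 128/255 ≈ 0.50196, which matches 0.501f within the 1e-3 accuracy the tests use, and 0.52f * 255 = 132.6, which matches the expected 132 if the conversion truncates (an assumption inferred from the test data, not from MPPMask's documented behavior). Quick arithmetic check:

#include <cstdio>

int main() {
  std::printf("128 / 255.0        = %f\n", 128 / 255.0);        // 0.501961
  std::printf("(int)(0.52f * 255) = %d\n", (int)(0.52f * 255)); // 132
}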
@@ -155,12 +155,12 @@ static const float kKeypointErrorThreshold = 1e-2;
   NSInteger iterationCount = 100;
 
   // Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
-  // normal expectation will fail if expectation.fullfill() is not called
+  // normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
-  // only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since it is not possible to predict how many times the expectation is supposed to be
-  // fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
-  // `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
+  // fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
+  // `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
   // `iterationCount` times.
   XCTestExpectation *expectation = [[XCTestExpectation alloc]
       initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
@@ -385,13 +385,13 @@ static const float kKeypointErrorThreshold = 1e-2;
   NSInteger iterationCount = 100;
 
   // Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
-  // normal expectation will fail if expectation.fullfill() is not called times. An normal
-  // expectation will fail if expectation.fullfill() is not called
+  // normal expectation will fail if expectation.fulfill() is not called times. An normal
+  // expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
-  // only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since it it not possible to determine how many times the expectation is supposed to be
-  // fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
-  // `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
+  // fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
+  // `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
   // `iterationCount` times.
   XCTestExpectation *expectation = [[XCTestExpectation alloc]
       initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];
@@ -174,12 +174,12 @@ constexpr float kFacialTransformationMatrixErrorThreshold = 0.2f;
   NSInteger iterationCount = 100;
 
   // Because of flow limiting, the callback might be invoked fewer than `iterationCount` times. An
-  // normal expectation will fail if expectation.fullfill() is not called
+  // normal expectation will fail if expectation.fulfill() is not called
   // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
-  // only succeed if expectation is not fullfilled for the specified `expectedFulfillmentCount`.
+  // only succeed if expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since it is not possible to predict how many times the expectation is supposed to be
-  // fullfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
-  // `expectation.isInverted = true` ensures that test succeeds if expectation is fullfilled <=
+  // fulfilled, `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
+  // `expectation.isInverted = true` ensures that test succeeds if expectation is fulfilled <=
   // `iterationCount` times.
   XCTestExpectation *expectation = [[XCTestExpectation alloc]
       initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];

mediapipe/tasks/ios/test/vision/gesture_recognizer/BUILD (new file, 62 lines)
					@ -0,0 +1,62 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load(
    "//mediapipe/framework/tool:ios.bzl",
    "MPP_TASK_MINIMUM_OS_VERSION",
)
load(
    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
    "tflite_ios_lab_runner",
)

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
    "apple",
]

# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
    "noasan",
    "nomsan",
    "notsan",
]

objc_library(
    name = "MPPGestureRecognizerObjcTestLibrary",
    testonly = 1,
    srcs = ["MPPGestureRecognizerTests.m"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    data = [
        "//mediapipe/tasks/testdata/vision:gesture_recognizer.task",
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_protos",
    ],
    deps = [
        "//mediapipe/tasks/ios/common:MPPCommon",
        "//mediapipe/tasks/ios/test/vision/gesture_recognizer/utils:MPPGestureRecognizerResultProtobufHelpers",
        "//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
        "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizer",
    ] + select({
        "//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//conditions:default": ["@ios_opencv//:OpencvFramework"],
    }),
)

ios_unit_test(
    name = "MPPGestureRecognizerObjcTest",
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
    runner = tflite_ios_lab_runner("IOS_LATEST"),
    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
    deps = [
        ":MPPGestureRecognizerObjcTestLibrary",
    ],
)
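# Note (not part of this commit): with a locally configured iOS toolchain and simulator,
# this suite can be run like any other Bazel test target, e.g.
#   bazel test //mediapipe/tasks/ios/test/vision/gesture_recognizer:MPPGestureRecognizerObjcTest
# (exact config flags depend on the local setup); the runner is the one returned by
# tflite_ios_lab_runner("IOS_LATEST").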
mediapipe/tasks/ios/test/vision/gesture_recognizer/MPPGestureRecognizerTests.m
@@ -0,0 +1,706 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <XCTest/XCTest.h>

#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizer.h"

static NSString *const kPbFileExtension = @"pbtxt";

typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;

static ResourceFileInfo *const kGestureRecognizerBundleAssetFile =
    @{@"name" : @"gesture_recognizer", @"type" : @"task"};

static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
static ResourceFileInfo *const kFistImage = @{@"name" : @"fist", @"type" : @"jpg"};
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
static ResourceFileInfo *const kPointingUpRotatedImage =
    @{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};

static ResourceFileInfo *const kExpectedFistLandmarksFile =
    @{@"name" : @"fist_landmarks", @"type" : kPbFileExtension};
static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
    @{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};

static NSString *const kFistLabel = @"Closed_Fist";
static NSString *const kExpectedThumbUpLabel = @"Thumb_Up";
static NSString *const kExpectedPointingUpLabel = @"Pointing_Up";
static NSString *const kRockLabel = @"Rock";

static const NSInteger kGestureExpectedIndex = -1;

static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static const float kLandmarksErrorTolerance = 0.03f;

static NSString *const kLiveStreamTestsDictGestureRecognizerKey = @"gesture_recognizer";
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";

#define AssertEqualErrors(error, expectedError)              \
  XCTAssertNotNil(error);                                    \
  XCTAssertEqualObjects(error.domain, expectedError.domain); \
  XCTAssertEqual(error.code, expectedError.code);            \
  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)

#define AssertEqualGestures(gesture, expectedGesture, handIndex, gestureIndex)                  \
  XCTAssertEqual(gesture.index, kGestureExpectedIndex, @"hand index = %d gesture index j = %d", \
                 handIndex, gestureIndex);                                                      \
  XCTAssertEqualObjects(gesture.categoryName, expectedGesture.categoryName,                     \
                        @"hand index = %d gesture index j = %d", handIndex, gestureIndex);

#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex)   \
  XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance,            \
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
  XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance,            \
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex);

#define AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult) \
  XCTAssertTrue(gestureRecognizerResult.gestures.count == 0);         \
  XCTAssertTrue(gestureRecognizerResult.handedness.count == 0);       \
  XCTAssertTrue(gestureRecognizerResult.landmarks.count == 0);        \
  XCTAssertTrue(gestureRecognizerResult.worldLandmarks.count == 0);

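// The test case acts as its own live stream delegate. The dictionaries below pair each
// recognizer under test with the XCTestExpectation it must fulfill, so the shared delegate
// callback can route a result to the expectation of the recognizer that produced it.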
@interface MPPGestureRecognizerTests : XCTestCase <MPPGestureRecognizerLiveStreamDelegate> {
  NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
  NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
}
@end

@implementation MPPGestureRecognizerTests

#pragma mark Expected Results

+ (MPPGestureRecognizerResult *)emptyGestureRecognizerResult {
  return [[MPPGestureRecognizerResult alloc] initWithGestures:@[]
                                                   handedness:@[]
                                                    landmarks:@[]
                                               worldLandmarks:@[]
                                      timestampInMilliseconds:0];
}

+ (MPPGestureRecognizerResult *)thumbUpGestureRecognizerResult {
  NSString *filePath =
      [MPPGestureRecognizerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];

  return [MPPGestureRecognizerResult
      gestureRecognizerResultsFromProtobufFileWithName:filePath
                                          gestureLabel:kExpectedThumbUpLabel
                                 shouldRemoveZPosition:YES];
}

+ (MPPGestureRecognizerResult *)fistGestureRecognizerResultWithLabel:(NSString *)gestureLabel {
  NSString *filePath = [MPPGestureRecognizerTests filePathWithFileInfo:kExpectedFistLandmarksFile];

  return [MPPGestureRecognizerResult gestureRecognizerResultsFromProtobufFileWithName:filePath
                                                                         gestureLabel:gestureLabel
                                                                shouldRemoveZPosition:YES];
}

#pragma mark Assert Gesture Recognizer Results

- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
    areApproximatelyEqualToExpectedMultiHandLandmarks:
        (NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
  XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
  if (multiHandLandmarks.count == 0) {
    return;
  }

  NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
  NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];

  XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
  for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
    MPPNormalizedLandmark *landmark = topHandLandmarks[i];
    XCTAssertNotNil(landmark);
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
  }
}

- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
    areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
        (NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
  XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
  if (expectedMultiHandWorldLandmarks.count == 0) {
    return;
  }

  NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
  NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];

  XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
  for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
    MPPLandmark *landmark = topHandWorldLandmarks[i];
    XCTAssertNotNil(landmark);
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
  }
}

- (void)assertMultiHandGestures:(NSArray<NSArray<MPPCategory *> *> *)multiHandGestures
    areApproximatelyEqualToExpectedMultiHandGestures:
        (NSArray<NSArray<MPPCategory *> *> *)expectedMultiHandGestures {
  XCTAssertEqual(multiHandGestures.count, expectedMultiHandGestures.count);
  if (multiHandGestures.count == 0) {
    return;
  }

  NSArray<MPPCategory *> *topHandGestures = multiHandGestures[0];
  NSArray<MPPCategory *> *expectedTopHandGestures = expectedMultiHandGestures[0];

  XCTAssertEqual(topHandGestures.count, expectedTopHandGestures.count);
  for (int i = 0; i < expectedTopHandGestures.count; i++) {
    MPPCategory *gesture = topHandGestures[i];
    XCTAssertNotNil(gesture);
    AssertEqualGestures(gesture, expectedTopHandGestures[i], 0, i);
  }
}

- (void)assertGestureRecognizerResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
    isApproximatelyEqualToExpectedResult:
        (MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
  [self assertMultiHandLandmarks:gestureRecognizerResult.landmarks
      areApproximatelyEqualToExpectedMultiHandLandmarks:expectedGestureRecognizerResult.landmarks];
  [self assertMultiHandWorldLandmarks:gestureRecognizerResult.worldLandmarks
      areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedGestureRecognizerResult
                                                                 .worldLandmarks];
  [self assertMultiHandGestures:gestureRecognizerResult.gestures
      areApproximatelyEqualToExpectedMultiHandGestures:expectedGestureRecognizerResult.gestures];
}

- (void)assertResultsOfRecognizeImageWithFileInfo:(ResourceFileInfo *)fileInfo
                           usingGestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
       approximatelyEqualsGestureRecognizerResult:
           (MPPGestureRecognizerResult *)expectedGestureRecognizerResult {
  MPPGestureRecognizerResult *gestureRecognizerResult =
      [self recognizeImageWithFileInfo:fileInfo usingGestureRecognizer:gestureRecognizer];
  [self assertGestureRecognizerResult:gestureRecognizerResult
      isApproximatelyEqualToExpectedResult:expectedGestureRecognizerResult];
}

#pragma mark File

+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
  NSString *filePath = [MPPGestureRecognizerTests filePathWithName:fileInfo[@"name"]
                                                         extension:fileInfo[@"type"]];
  return filePath;
}

+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
                                                                      ofType:extension];
  return filePath;
}

#pragma mark Gesture Recognizer Initializers

- (MPPGestureRecognizerOptions *)gestureRecognizerOptionsWithModelFileInfo:
    (ResourceFileInfo *)modelFileInfo {
  NSString *modelPath = [MPPGestureRecognizerTests filePathWithFileInfo:modelFileInfo];
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [[MPPGestureRecognizerOptions alloc] init];
  gestureRecognizerOptions.baseOptions.modelAssetPath = modelPath;

  return gestureRecognizerOptions;
}

- (MPPGestureRecognizer *)createGestureRecognizerWithOptionsSucceeds:
    (MPPGestureRecognizerOptions *)gestureRecognizerOptions {
  MPPGestureRecognizer *gestureRecognizer =
      [[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:nil];
  XCTAssertNotNil(gestureRecognizer);

  return gestureRecognizer;
}

- (void)assertCreateGestureRecognizerWithOptions:
            (MPPGestureRecognizerOptions *)gestureRecognizerOptions
                          failsWithExpectedError:(NSError *)expectedError {
  NSError *error = nil;
  MPPGestureRecognizer *gestureRecognizer =
      [[MPPGestureRecognizer alloc] initWithOptions:gestureRecognizerOptions error:&error];

  XCTAssertNil(gestureRecognizer);
  AssertEqualErrors(error, expectedError);
}

#pragma mark Recognize Helpers

- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
                                              fileName:fileInfo[@"name"]
                                                ofType:fileInfo[@"type"]];
  XCTAssertNotNil(image);

  return image;
}

- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
                    orientation:(UIImageOrientation)orientation {
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPGestureRecognizerTests class]
                                              fileName:fileInfo[@"name"]
                                                ofType:fileInfo[@"type"]
                                           orientation:orientation];
  XCTAssertNotNil(image);

  return image;
}

- (MPPGestureRecognizerResult *)recognizeImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
                                    usingGestureRecognizer:
                                        (MPPGestureRecognizer *)gestureRecognizer {
  MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
  MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
                                                                                    error:nil];
  XCTAssertNotNil(gestureRecognizerResult);

  return gestureRecognizerResult;
}

#pragma mark General Tests

- (void)testRecognizeWithModelPathSucceeds {
  NSString *modelPath =
      [MPPGestureRecognizerTests filePathWithFileInfo:kGestureRecognizerBundleAssetFile];
  MPPGestureRecognizer *gestureRecognizer =
      [[MPPGestureRecognizer alloc] initWithModelPath:modelPath error:nil];
  XCTAssertNotNil(gestureRecognizer);

  [self assertResultsOfRecognizeImageWithFileInfo:kThumbUpImage
                           usingGestureRecognizer:gestureRecognizer
       approximatelyEqualsGestureRecognizerResult:[MPPGestureRecognizerTests
                                                      thumbUpGestureRecognizerResult]];
}

- (void)testRecognizeWithEmptyResultsSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];

  MPPGestureRecognizerResult *gestureRecognizerResult =
      [self recognizeImageWithFileInfo:kNoHandsImage usingGestureRecognizer:gestureRecognizer];
  AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
}

- (void)testRecognizeWithScoreThresholdSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];

  MPPGestureRecognizerResult *gestureRecognizerResult =
      [self recognizeImageWithFileInfo:kThumbUpImage usingGestureRecognizer:gestureRecognizer];

  MPPGestureRecognizerResult *expectedGestureRecognizerResult =
      [MPPGestureRecognizerTests thumbUpGestureRecognizerResult];

  XCTAssertTrue(gestureRecognizerResult.gestures.count == 1);
  AssertEqualGestures(gestureRecognizerResult.gestures[0][0],
                      expectedGestureRecognizerResult.gestures[0][0], 0, 0);
}

- (void)testRecognizeWithNumHandsSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

  const NSInteger numHands = 2;
  gestureRecognizerOptions.numHands = numHands;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];

  MPPGestureRecognizerResult *gestureRecognizerResult =
      [self recognizeImageWithFileInfo:kTwoHandsImage usingGestureRecognizer:gestureRecognizer];

  XCTAssertTrue(gestureRecognizerResult.handedness.count == numHands);
}

- (void)testRecognizeWithRotationSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

  gestureRecognizerOptions.numHands = 1;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
  MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
                                   orientation:UIImageOrientationRight];

  MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage
                                                                                    error:nil];

  XCTAssertNotNil(gestureRecognizerResult);

  XCTAssertEqual(gestureRecognizerResult.gestures.count, 1);
  XCTAssertEqualObjects(gestureRecognizerResult.gestures[0][0].categoryName,
                        kExpectedPointingUpLabel);
}

- (void)testRecognizeWithCannedGestureFistSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

  gestureRecognizerOptions.numHands = 1;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];

  [self assertResultsOfRecognizeImageWithFileInfo:kFistImage
                           usingGestureRecognizer:gestureRecognizer
       approximatelyEqualsGestureRecognizerResult:
           [MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
}

- (void)testRecognizeWithAllowGestureFistSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];

  gestureRecognizerOptions.numHands = 1;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];

  [self assertResultsOfRecognizeImageWithFileInfo:kFistImage
                           usingGestureRecognizer:gestureRecognizer
       approximatelyEqualsGestureRecognizerResult:
           [MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
}

- (void)testRecognizeWithDenyGestureFistSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];

  gestureRecognizerOptions.numHands = 1;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];
  MPPGestureRecognizerResult *gestureRecognizerResult =
      [self recognizeImageWithFileInfo:kFistImage usingGestureRecognizer:gestureRecognizer];
  AssertGestureRecognizerResultIsEmpty(gestureRecognizerResult);
}

- (void)testRecognizeWithPreferAllowlistOverDenylistSucceeds {
  MPPGestureRecognizerOptions *gestureRecognizerOptions =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  gestureRecognizerOptions.cannedGesturesClassifierOptions = [[MPPClassifierOptions alloc] init];
  gestureRecognizerOptions.cannedGesturesClassifierOptions.scoreThreshold = 0.5f;
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryAllowlist = @[ kFistLabel ];
  gestureRecognizerOptions.cannedGesturesClassifierOptions.categoryDenylist = @[ kFistLabel ];

  gestureRecognizerOptions.numHands = 1;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions];

  [self assertResultsOfRecognizeImageWithFileInfo:kFistImage
                           usingGestureRecognizer:gestureRecognizer
       approximatelyEqualsGestureRecognizerResult:
           [MPPGestureRecognizerTests fistGestureRecognizerResultWithLabel:kFistLabel]];
}

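// The canned-gesture tests above all exercise `cannedGesturesClassifierOptions`: score
// thresholding on the thumbs-up image, and category allowlisting, denylisting, and
// allowlist-over-denylist precedence on the fist image.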
#pragma mark Running Mode Tests

- (void)testCreateGestureRecognizerFailsWithDelegateInNonLiveStreamMode {
  MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
  for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
    MPPGestureRecognizerOptions *options =
        [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

    options.runningMode = runningModesToTest[i];
    options.gestureRecognizerLiveStreamDelegate = self;

    [self assertCreateGestureRecognizerWithOptions:options
                            failsWithExpectedError:
                                [NSError
                                    errorWithDomain:kExpectedErrorDomain
                                               code:MPPTasksErrorCodeInvalidArgumentError
                                           userInfo:@{
                                             NSLocalizedDescriptionKey :
                                                 @"The vision task is in image or video mode. The "
                                                 @"delegate must not be set in the task's options."
                                           }]];
  }
}

- (void)testCreateGestureRecognizerFailsWithMissingDelegateInLiveStreamMode {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

  options.runningMode = MPPRunningModeLiveStream;

  [self
      assertCreateGestureRecognizerWithOptions:options
                        failsWithExpectedError:
                            [NSError errorWithDomain:kExpectedErrorDomain
                                                code:MPPTasksErrorCodeInvalidArgumentError
                                            userInfo:@{
                                              NSLocalizedDescriptionKey :
                                                  @"The vision task is in live stream mode. An "
                                                  @"object must be set as the delegate of the task "
                                                  @"in its options to ensure asynchronous delivery "
                                                  @"of results."
                                            }]];
}

- (void)testRecognizeFailsWithCallingWrongApiInImageMode {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kFistImage];

  NSError *liveStreamApiCallError;
  XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
                                timestampInMilliseconds:0
                                                  error:&liveStreamApiCallError]);

  NSError *expectedLiveStreamApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
                                                    @"stream mode. Current Running Mode: Image"
                      }];

  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);

  NSError *videoApiCallError;
  XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
                                timestampInMilliseconds:0
                                                  error:&videoApiCallError]);

  NSError *expectedVideoApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"video mode. Current Running Mode: Image"
                      }];
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}

- (void)testRecognizeFailsWithCallingWrongApiInVideoMode {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  options.runningMode = MPPRunningModeVideo;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kFistImage];

  NSError *liveStreamApiCallError;
  XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
                                timestampInMilliseconds:0
                                                  error:&liveStreamApiCallError]);

  NSError *expectedLiveStreamApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
                                                    @"stream mode. Current Running Mode: Video"
                      }];

  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);

  NSError *imageApiCallError;
  XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);

  NSError *expectedImageApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"image mode. Current Running Mode: Video"
                      }];
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
}

- (void)testRecognizeFailsWithCallingWrongApiInLiveStreamMode {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  options.runningMode = MPPRunningModeLiveStream;
  options.gestureRecognizerLiveStreamDelegate = self;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kFistImage];

  NSError *imageApiCallError;
  XCTAssertFalse([gestureRecognizer recognizeImage:image error:&imageApiCallError]);

  NSError *expectedImageApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"image mode. Current Running Mode: Live Stream"
                      }];
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);

  NSError *videoApiCallError;
  XCTAssertFalse([gestureRecognizer recognizeVideoFrame:image
                                timestampInMilliseconds:0
                                                  error:&videoApiCallError]);

  NSError *expectedVideoApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"video mode. Current Running Mode: Live Stream"
                      }];
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}

- (void)testRecognizeWithVideoModeSucceeds {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  options.runningMode = MPPRunningModeVideo;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  for (int i = 0; i < 3; i++) {
    MPPGestureRecognizerResult *gestureRecognizerResult =
        [gestureRecognizer recognizeVideoFrame:image timestampInMilliseconds:i error:nil];
    [self assertGestureRecognizerResult:gestureRecognizerResult
        isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
                                                 thumbUpGestureRecognizerResult]];
  }
}

- (void)testRecognizeWithOutOfOrderTimestampsAndLiveStreamModeFails {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  options.runningMode = MPPRunningModeLiveStream;
  options.gestureRecognizerLiveStreamDelegate = self;

  XCTestExpectation *expectation = [[XCTestExpectation alloc]
      initWithDescription:@"recognizeWithOutOfOrderTimestampsAndLiveStream"];

  expectation.expectedFulfillmentCount = 1;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:options];

  _outOfOrderTimestampTestDict = @{
    kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
    kLiveStreamTestsDictExpectationKey : expectation
  };

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image timestampInMilliseconds:1 error:nil]);

  NSError *error;
  XCTAssertFalse([gestureRecognizer recognizeAsyncImage:image
                                timestampInMilliseconds:0
                                                  error:&error]);

  NSError *expectedError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey :
                            @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
                      }];
  AssertEqualErrors(error, expectedError);

  NSTimeInterval timeout = 0.5f;
  [self waitForExpectations:@[ expectation ] timeout:timeout];
}

- (void)testRecognizeWithLiveStreamModeSucceeds {
  MPPGestureRecognizerOptions *options =
      [self gestureRecognizerOptionsWithModelFileInfo:kGestureRecognizerBundleAssetFile];
  options.runningMode = MPPRunningModeLiveStream;
  options.gestureRecognizerLiveStreamDelegate = self;

  NSInteger iterationCount = 100;

  // Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
  // times. A normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test will
  // only succeed if the expectation is not fulfilled for the specified `expectedFulfillmentCount`.
  // Since in our case we cannot predict how many times the expectation is supposed to be fulfilled,
  // setting `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that the test succeeds if the expectation is
  // fulfilled <= `iterationCount` times.
  XCTestExpectation *expectation =
      [[XCTestExpectation alloc] initWithDescription:@"recognizeWithLiveStream"];

  expectation.expectedFulfillmentCount = iterationCount + 1;
  expectation.inverted = YES;

  MPPGestureRecognizer *gestureRecognizer =
      [self createGestureRecognizerWithOptionsSucceeds:options];

  _liveStreamSucceedsTestDict = @{
    kLiveStreamTestsDictGestureRecognizerKey : gestureRecognizer,
    kLiveStreamTestsDictExpectationKey : expectation
  };

  // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
  // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
  // `CMSampleBuffer`.
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  for (int i = 0; i < iterationCount; i++) {
    XCTAssertTrue([gestureRecognizer recognizeAsyncImage:image
                                 timestampInMilliseconds:i
                                                   error:nil]);
  }

  NSTimeInterval timeout = 0.5f;
  [self waitForExpectations:@[ expectation ] timeout:timeout];
}

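#pragma mark MPPGestureRecognizerLiveStreamDelegate

// Shared live stream callback: every delivered result is checked against the expected
// thumbs-up result, and the expectation paired with the recognizer that produced the result
// is fulfilled.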
- (void)gestureRecognizer:(MPPGestureRecognizer *)gestureRecognizer
    didFinishRecognitionWithResult:(MPPGestureRecognizerResult *)gestureRecognizerResult
           timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                             error:(NSError *)error {
  [self assertGestureRecognizerResult:gestureRecognizerResult
      isApproximatelyEqualToExpectedResult:[MPPGestureRecognizerTests
                                               thumbUpGestureRecognizerResult]];

  if (gestureRecognizer == _outOfOrderTimestampTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
    [_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
  } else if (gestureRecognizer ==
             _liveStreamSucceedsTestDict[kLiveStreamTestsDictGestureRecognizerKey]) {
    [_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
  }
}

@end
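A possible follow-up to the TODO in `testRecognizeWithLiveStreamModeSucceeds`, sketched here and not part of this commit: feeding camera frames to the recognizer in live stream mode. The method below is the standard `AVCaptureVideoDataOutput` sample buffer delegate callback; `self.gestureRecognizer` is a hypothetical property, and the `CMSampleBuffer`-based `MPPImage` initializer is assumed to be available in the SDK under test.

- (void)captureOutput:(AVCaptureOutput *)output
    didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer
           fromConnection:(AVCaptureConnection *)connection {
  // Wrap the camera frame; `initWithSampleBuffer:error:` is assumed to exist here.
  MPPImage *image = [[MPPImage alloc] initWithSampleBuffer:sampleBuffer error:nil];

  // Timestamps must be monotonically increasing, as the out-of-order test above verifies.
  NSInteger timestampInMilliseconds =
      (NSInteger)(CMTimeGetSeconds(CMSampleBufferGetPresentationTimeStamp(sampleBuffer)) * 1000);

  [self.gestureRecognizer recognizeAsyncImage:image
                      timestampInMilliseconds:timestampInMilliseconds
                                        error:nil];
}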
mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/BUILD
@@ -0,0 +1,22 @@
package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

objc_library(
    name = "MPPGestureRecognizerResultProtobufHelpers",
    srcs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.mm"],
    hdrs = ["sources/MPPGestureRecognizerResult+ProtobufHelpers.h"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
        "//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
        "//mediapipe/tasks/ios/vision/gesture_recognizer:MPPGestureRecognizerResult",
        "//mediapipe/tasks/ios/vision/gesture_recognizer/utils:MPPGestureRecognizerResultHelpers",
    ],
)
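# Note: this helper is compiled as Objective-C++ ("-ObjC++", "-std=c++17") because its
# implementation parses C++ protobufs via parse_proto_utils, as the .mm file below shows.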
mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h
@@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.h"

NS_ASSUME_NONNULL_BEGIN
@interface MPPGestureRecognizerResult (ProtobufHelpers)

+ (MPPGestureRecognizerResult *)
    gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
                                        gestureLabel:(NSString *)gestureLabel
                               shouldRemoveZPosition:(BOOL)removeZPosition;

@end

NS_ASSUME_NONNULL_END
mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.mm (new file)
@@ -0,0 +1,65 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import "mediapipe/tasks/ios/test/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+ProtobufHelpers.h"

#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
#import "mediapipe/tasks/ios/vision/gesture_recognizer/utils/sources/MPPGestureRecognizerResult+Helpers.h"

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"

namespace {
using ClassificationListProto = ::mediapipe::ClassificationList;
using ClassificationProto = ::mediapipe::Classification;
using LandmarksDetectionResultProto =
    ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
}  // anonymous namespace

@implementation MPPGestureRecognizerResult (ProtobufHelpers)

+ (MPPGestureRecognizerResult *)
    gestureRecognizerResultsFromProtobufFileWithName:(NSString *)fileName
                                        gestureLabel:(NSString *)gestureLabel
                               shouldRemoveZPosition:(BOOL)removeZPosition {
  LandmarksDetectionResultProto landmarkDetectionResultProto;

  if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
    return nil;
  }

  if (removeZPosition) {
    // Remove z position of landmarks, because they are not used in correctness testing. For video
    // or live stream mode, the z positions vary a lot during tracking from frame to frame.
    for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
      auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
      landmark.clear_z();
    }
  }

  ClassificationListProto gesturesProto;
  ClassificationProto *classificationProto = gesturesProto.add_classification();
  classificationProto->set_label([gestureLabel UTF8String]);

  return [MPPGestureRecognizerResult
      gestureRecognizerResultWithHandGesturesProto:{gesturesProto}
                                   handednessProto:{landmarkDetectionResultProto.classifications()}
                                handLandmarksProto:{landmarkDetectionResultProto.landmarks()}
                               worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
                           timestampInMilliseconds:0];
}

@end
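For context, a test typically uses this category to build the expected result it compares detection output against. A minimal usage sketch (illustrative, not part of this commit; the file name and gesture label are assumptions):

// Hypothetical: `thumb_up_landmarks.pbtxt` is assumed to be bundled with the test target.
NSString *filePath = [[NSBundle bundleForClass:[self class]] pathForResource:@"thumb_up_landmarks"
                                                                      ofType:@"pbtxt"];
MPPGestureRecognizerResult *expectedResult = [MPPGestureRecognizerResult
    gestureRecognizerResultsFromProtobufFileWithName:filePath
                                        gestureLabel:@"Thumb_Up"
                               shouldRemoveZPosition:YES];
// The helper returns nil if the .pbtxt file cannot be parsed.
XCTAssertNotNil(expectedResult);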
mediapipe/tasks/ios/test/vision/hand_landmarker/BUILD (new file)
@@ -0,0 +1,62 @@
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load(
    "//mediapipe/framework/tool:ios.bzl",
    "MPP_TASK_MINIMUM_OS_VERSION",
)
load(
    "@org_tensorflow//tensorflow/lite:special_rules.bzl",
    "tflite_ios_lab_runner",
)

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# Default tags for filtering iOS targets. Targets are restricted to Apple platforms.
TFL_DEFAULT_TAGS = [
    "apple",
]

# Following sanitizer tests are not supported by iOS test targets.
TFL_DISABLED_SANITIZER_TAGS = [
    "noasan",
    "nomsan",
    "notsan",
]

objc_library(
    name = "MPPHandLandmarkerObjcTestLibrary",
    testonly = 1,
    srcs = ["MPPHandLandmarkerTests.m"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    data = [
        "//mediapipe/tasks/testdata/vision:hand_landmarker.task",
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_protos",
    ],
    deps = [
        "//mediapipe/tasks/ios/common:MPPCommon",
        "//mediapipe/tasks/ios/test/vision/hand_landmarker/utils:MPPHandLandmarkerResultProtobufHelpers",
        "//mediapipe/tasks/ios/test/vision/utils:MPPImageTestUtils",
        "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarker",
    ] + select({
        "//third_party:opencv_ios_sim_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_arm64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//third_party:opencv_ios_x86_64_source_build": ["@ios_opencv_source//:opencv_xcframework"],
        "//conditions:default": ["@ios_opencv//:OpencvFramework"],
    }),
)

ios_unit_test(
    name = "MPPHandLandmarkerObjcTest",
    minimum_os_version = MPP_TASK_MINIMUM_OS_VERSION,
    runner = tflite_ios_lab_runner("IOS_LATEST"),
    tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
    deps = [
        ":MPPHandLandmarkerObjcTestLibrary",
    ],
)
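For reference, the resulting test target can presumably be run with `bazel test //mediapipe/tasks/ios/test/vision/hand_landmarker:MPPHandLandmarkerObjcTest`, given an iOS build configuration that satisfies one of the OpenCV `select()` branches above.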
mediapipe/tasks/ios/test/vision/hand_landmarker/MPPHandLandmarkerTests.m (new file)
@@ -0,0 +1,557 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <XCTest/XCTest.h>

#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
#import "mediapipe/tasks/ios/test/vision/utils/sources/MPPImage+TestUtils.h"
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarker.h"

static NSString *const kPbFileExtension = @"pbtxt";

typedef NSDictionary<NSString *, NSString *> ResourceFileInfo;

static ResourceFileInfo *const kHandLandmarkerBundleAssetFile =
    @{@"name" : @"hand_landmarker", @"type" : @"task"};

static ResourceFileInfo *const kTwoHandsImage = @{@"name" : @"right_hands", @"type" : @"jpg"};
static ResourceFileInfo *const kNoHandsImage = @{@"name" : @"cats_and_dogs", @"type" : @"jpg"};
static ResourceFileInfo *const kThumbUpImage = @{@"name" : @"thumb_up", @"type" : @"jpg"};
static ResourceFileInfo *const kPointingUpRotatedImage =
    @{@"name" : @"pointing_up_rotated", @"type" : @"jpg"};

static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
    @{@"name" : @"thumb_up_landmarks", @"type" : kPbFileExtension};
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
    @{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};

static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
static const float kLandmarksErrorTolerance = 0.03f;

static NSString *const kLiveStreamTestsDictHandLandmarkerKey = @"hand_landmarker";
static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";

#define AssertEqualErrors(error, expectedError)              \
  XCTAssertNotNil(error);                                    \
  XCTAssertEqualObjects(error.domain, expectedError.domain); \
  XCTAssertEqual(error.code, expectedError.code);            \
  XCTAssertEqualObjects(error.localizedDescription, expectedError.localizedDescription)

#define AssertApproximatelyEqualLandmarks(landmark, expectedLandmark, handIndex, landmarkIndex)   \
  XCTAssertEqualWithAccuracy(landmark.x, expectedLandmark.x, kLandmarksErrorTolerance,            \
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex); \
  XCTAssertEqualWithAccuracy(landmark.y, expectedLandmark.y, kLandmarksErrorTolerance,            \
                             @"hand index = %d landmark index j = %d", handIndex, landmarkIndex);

#define AssertHandLandmarkerResultIsEmpty(handLandmarkerResult) \
  XCTAssertTrue(handLandmarkerResult.handedness.count == 0);    \
  XCTAssertTrue(handLandmarkerResult.landmarks.count == 0);     \
  XCTAssertTrue(handLandmarkerResult.worldLandmarks.count == 0);

@interface MPPHandLandmarkerTests : XCTestCase <MPPHandLandmarkerLiveStreamDelegate> {
  NSDictionary<NSString *, id> *_liveStreamSucceedsTestDict;
  NSDictionary<NSString *, id> *_outOfOrderTimestampTestDict;
}
@end

@implementation MPPHandLandmarkerTests

#pragma mark Results

+ (MPPHandLandmarkerResult *)emptyHandLandmarkerResult {
  return [[MPPHandLandmarkerResult alloc] initWithLandmarks:@[]
                                             worldLandmarks:@[]
                                                 handedness:@[]
                                    timestampInMilliseconds:0];
}

+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
  NSString *filePath = [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];

  return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
                                                         shouldRemoveZPosition:YES];
}

+ (MPPHandLandmarkerResult *)pointingUpRotatedHandLandmarkerResult {
  NSString *filePath =
      [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];

  return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
                                                         shouldRemoveZPosition:YES];
}

- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
    areApproximatelyEqualToExpectedMultiHandLandmarks:
        (NSArray<NSArray<MPPNormalizedLandmark *> *> *)expectedMultiHandLandmarks {
  XCTAssertEqual(multiHandLandmarks.count, expectedMultiHandLandmarks.count);
  if (multiHandLandmarks.count == 0) {
    return;
  }

  NSArray<MPPNormalizedLandmark *> *topHandLandmarks = multiHandLandmarks[0];
  NSArray<MPPNormalizedLandmark *> *expectedTopHandLandmarks = expectedMultiHandLandmarks[0];

  XCTAssertEqual(topHandLandmarks.count, expectedTopHandLandmarks.count);
  for (int i = 0; i < expectedTopHandLandmarks.count; i++) {
    MPPNormalizedLandmark *landmark = topHandLandmarks[i];
    XCTAssertNotNil(landmark);
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandLandmarks[i], 0, i);
  }
}

- (void)assertMultiHandWorldLandmarks:(NSArray<NSArray<MPPLandmark *> *> *)multiHandWorldLandmarks
    areApproximatelyEqualToExpectedMultiHandWorldLandmarks:
        (NSArray<NSArray<MPPLandmark *> *> *)expectedMultiHandWorldLandmarks {
  XCTAssertEqual(multiHandWorldLandmarks.count, expectedMultiHandWorldLandmarks.count);
  if (expectedMultiHandWorldLandmarks.count == 0) {
    return;
  }

  NSArray<MPPLandmark *> *topHandWorldLandmarks = multiHandWorldLandmarks[0];
  NSArray<MPPLandmark *> *expectedTopHandWorldLandmarks = expectedMultiHandWorldLandmarks[0];

  XCTAssertEqual(topHandWorldLandmarks.count, expectedTopHandWorldLandmarks.count);
  for (int i = 0; i < expectedTopHandWorldLandmarks.count; i++) {
    MPPLandmark *landmark = topHandWorldLandmarks[i];
    XCTAssertNotNil(landmark);
    AssertApproximatelyEqualLandmarks(landmark, expectedTopHandWorldLandmarks[i], 0, i);
  }
}

- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
    isApproximatelyEqualToExpectedResult:(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
  [self assertMultiHandLandmarks:handLandmarkerResult.landmarks
      areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
  [self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
      areApproximatelyEqualToExpectedMultiHandWorldLandmarks:expectedHandLandmarkerResult
                                                                 .worldLandmarks];
}

#pragma mark File

+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
  NSString *filePath = [MPPHandLandmarkerTests filePathWithName:fileInfo[@"name"]
                                                      extension:fileInfo[@"type"]];
  return filePath;
}

+ (NSString *)filePathWithName:(NSString *)fileName extension:(NSString *)extension {
  NSString *filePath = [[NSBundle bundleForClass:self.class] pathForResource:fileName
                                                                      ofType:extension];
  return filePath;
}

#pragma mark Hand Landmarker Initializers

- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
    (ResourceFileInfo *)modelFileInfo {
  NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
  MPPHandLandmarkerOptions *handLandmarkerOptions = [[MPPHandLandmarkerOptions alloc] init];
  handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;

  return handLandmarkerOptions;
}

- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
    (MPPHandLandmarkerOptions *)handLandmarkerOptions {
  NSError *error;
  MPPHandLandmarker *handLandmarker =
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
  XCTAssertNotNil(handLandmarker);
  XCTAssertNil(error);

  return handLandmarker;
}

- (void)assertCreateHandLandmarkerWithOptions:(MPPHandLandmarkerOptions *)handLandmarkerOptions
                       failsWithExpectedError:(NSError *)expectedError {
  NSError *error = nil;
  MPPHandLandmarker *handLandmarker =
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];

  XCTAssertNil(handLandmarker);
  AssertEqualErrors(error, expectedError);
}

#pragma mark Assert Hand Landmarker Results

- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo {
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
                                              fileName:fileInfo[@"name"]
                                                ofType:fileInfo[@"type"]];
  XCTAssertNotNil(image);

  return image;
}

- (MPPImage *)imageWithFileInfo:(ResourceFileInfo *)fileInfo
                    orientation:(UIImageOrientation)orientation {
  MPPImage *image = [MPPImage imageFromBundleWithClass:[MPPHandLandmarkerTests class]
                                              fileName:fileInfo[@"name"]
                                                ofType:fileInfo[@"type"]
                                           orientation:orientation];
  XCTAssertNotNil(image);

  return image;
}

- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
                                   usingHandLandmarker:(MPPHandLandmarker *)handLandmarker {
  MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
  XCTAssertNotNil(handLandmarkerResult);

  return handLandmarkerResult;
}

- (void)assertResultsOfDetectInImageWithFileInfo:(ResourceFileInfo *)fileInfo
                             usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
         approximatelyEqualsHandLandmarkerResult:
             (MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:fileInfo
                                                              usingHandLandmarker:handLandmarker];
  [self assertHandLandmarkerResult:handLandmarkerResult
      isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
}

#pragma mark General Tests

- (void)testDetectWithModelPathSucceeds {
  NSString *modelPath =
      [MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
  MPPHandLandmarker *handLandmarker = [[MPPHandLandmarker alloc] initWithModelPath:modelPath
                                                                             error:nil];
  XCTAssertNotNil(handLandmarker);

  [self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
                             usingHandLandmarker:handLandmarker
         approximatelyEqualsHandLandmarkerResult:[MPPHandLandmarkerTests
                                                     thumbUpHandLandmarkerResult]];
}

- (void)testDetectWithEmptyResultsSucceeds {
  MPPHandLandmarkerOptions *handLandmarkerOptions =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];

  MPPHandLandmarker *handLandmarker =
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];

  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kNoHandsImage
                                                              usingHandLandmarker:handLandmarker];
  AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
}

- (void)testDetectWithNumHandsSucceeds {
  MPPHandLandmarkerOptions *handLandmarkerOptions =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];

  const NSInteger numHands = 2;
  handLandmarkerOptions.numHands = numHands;

  MPPHandLandmarker *handLandmarker =
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];

  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kTwoHandsImage
                                                              usingHandLandmarker:handLandmarker];

  XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
}

- (void)testDetectWithRotationSucceeds {
  MPPHandLandmarkerOptions *handLandmarkerOptions =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];

  MPPHandLandmarker *handLandmarker =
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];

  MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
                                   orientation:UIImageOrientationRight];

  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];

  [self assertHandLandmarkerResult:handLandmarkerResult
      isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests
                                               pointingUpRotatedHandLandmarkerResult]];
}

#pragma mark Running Mode Tests

- (void)testCreateHandLandmarkerFailsWithDelegateInNonLiveStreamMode {
  MPPRunningMode runningModesToTest[] = {MPPRunningModeImage, MPPRunningModeVideo};
  for (int i = 0; i < sizeof(runningModesToTest) / sizeof(runningModesToTest[0]); i++) {
    MPPHandLandmarkerOptions *options =
        [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];

    options.runningMode = runningModesToTest[i];
    options.handLandmarkerLiveStreamDelegate = self;

    [self
        assertCreateHandLandmarkerWithOptions:options
                       failsWithExpectedError:
                           [NSError errorWithDomain:kExpectedErrorDomain
                                               code:MPPTasksErrorCodeInvalidArgumentError
                                           userInfo:@{
                                             NSLocalizedDescriptionKey :
                                                 @"The vision task is in image or video mode. The "
                                                 @"delegate must not be set in the task's options."
                                           }]];
  }
}

- (void)testCreateHandLandmarkerFailsWithMissingDelegateInLiveStreamMode {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];

  options.runningMode = MPPRunningModeLiveStream;

  [self assertCreateHandLandmarkerWithOptions:options
                       failsWithExpectedError:
                           [NSError errorWithDomain:kExpectedErrorDomain
                                               code:MPPTasksErrorCodeInvalidArgumentError
                                           userInfo:@{
                                             NSLocalizedDescriptionKey :
                                                 @"The vision task is in live stream mode. An "
                                                 @"object must be set as the delegate of the task "
                                                 @"in its options to ensure asynchronous delivery "
                                                 @"of results."
                                           }]];
}

- (void)testDetectFailsWithCallingWrongApiInImageMode {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];

  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  NSError *liveStreamApiCallError;
  XCTAssertFalse([handLandmarker detectAsyncInImage:image
                            timestampInMilliseconds:0
                                              error:&liveStreamApiCallError]);

  NSError *expectedLiveStreamApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
                                                    @"stream mode. Current Running Mode: Image"
                      }];

  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);

  NSError *videoApiCallError;
  XCTAssertFalse([handLandmarker detectInVideoFrame:image
                            timestampInMilliseconds:0
                                              error:&videoApiCallError]);

  NSError *expectedVideoApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"video mode. Current Running Mode: Image"
                      }];
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}

- (void)testDetectFailsWithCallingWrongApiInVideoMode {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
  options.runningMode = MPPRunningModeVideo;

  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  NSError *liveStreamApiCallError;
  XCTAssertFalse([handLandmarker detectAsyncInImage:image
                            timestampInMilliseconds:0
                                              error:&liveStreamApiCallError]);

  NSError *expectedLiveStreamApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with live "
                                                    @"stream mode. Current Running Mode: Video"
                      }];

  AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);

  NSError *imageApiCallError;
  XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);

  NSError *expectedImageApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"image mode. Current Running Mode: Video"
                      }];
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
}

- (void)testDetectFailsWithCallingWrongApiInLiveStreamMode {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
  options.runningMode = MPPRunningModeLiveStream;
  options.handLandmarkerLiveStreamDelegate = self;

  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  NSError *imageApiCallError;
  XCTAssertFalse([handLandmarker detectInImage:image error:&imageApiCallError]);

  NSError *expectedImageApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"image mode. Current Running Mode: Live Stream"
                      }];
  AssertEqualErrors(imageApiCallError, expectedImageApiCallError);

  NSError *videoApiCallError;
  XCTAssertFalse([handLandmarker detectInVideoFrame:image
                            timestampInMilliseconds:0
                                              error:&videoApiCallError]);

  NSError *expectedVideoApiCallError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey : @"The vision task is not initialized with "
                                                    @"video mode. Current Running Mode: Live Stream"
                      }];
  AssertEqualErrors(videoApiCallError, expectedVideoApiCallError);
}

- (void)testDetectWithVideoModeSucceeds {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
  options.runningMode = MPPRunningModeVideo;

  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  for (int i = 0; i < 3; i++) {
    MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInVideoFrame:image
                                                               timestampInMilliseconds:i
                                                                                 error:nil];
    [self assertHandLandmarkerResult:handLandmarkerResult
        isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];
  }
}

- (void)testDetectWithOutOfOrderTimestampsAndLiveStreamModeFails {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
  options.runningMode = MPPRunningModeLiveStream;
  options.handLandmarkerLiveStreamDelegate = self;

  XCTestExpectation *expectation = [[XCTestExpectation alloc]
      initWithDescription:@"detectWithOutOfOrderTimestampsAndLiveStream"];

  expectation.expectedFulfillmentCount = 1;

  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];

  _outOfOrderTimestampTestDict = @{
    kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
    kLiveStreamTestsDictExpectationKey : expectation
  };

  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:1 error:nil]);

  NSError *error;
  XCTAssertFalse([handLandmarker detectAsyncInImage:image timestampInMilliseconds:0 error:&error]);

  NSError *expectedError =
      [NSError errorWithDomain:kExpectedErrorDomain
                          code:MPPTasksErrorCodeInvalidArgumentError
                      userInfo:@{
                        NSLocalizedDescriptionKey :
                            @"INVALID_ARGUMENT: Input timestamp must be monotonically increasing."
                      }];
  AssertEqualErrors(error, expectedError);

  NSTimeInterval timeout = 0.5f;
  [self waitForExpectations:@[ expectation ] timeout:timeout];
}

- (void)testDetectWithLiveStreamModeSucceeds {
  MPPHandLandmarkerOptions *options =
      [self handLandmarkerOptionsWithModelFileInfo:kHandLandmarkerBundleAssetFile];
  options.runningMode = MPPRunningModeLiveStream;
  options.handLandmarkerLiveStreamDelegate = self;

  NSInteger iterationCount = 100;

  // Because of flow limiting, we cannot ensure that the callback will be invoked `iterationCount`
  // times. A normal expectation will fail if expectation.fulfill() is not called
  // `expectation.expectedFulfillmentCount` times. If `expectation.isInverted = true`, the test
  // will only succeed if the expectation is not fulfilled for the specified
  // `expectedFulfillmentCount`. Since in our case we cannot predict how many times the
  // expectation will be fulfilled, setting
  // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
  // `expectation.isInverted = true` ensures that the test succeeds if the expectation is
  // fulfilled at most `iterationCount` times.
  XCTestExpectation *expectation =
      [[XCTestExpectation alloc] initWithDescription:@"detectWithLiveStream"];

  expectation.expectedFulfillmentCount = iterationCount + 1;
  expectation.inverted = YES;

  MPPHandLandmarker *handLandmarker = [self createHandLandmarkerWithOptionsSucceeds:options];

  _liveStreamSucceedsTestDict = @{
    kLiveStreamTestsDictHandLandmarkerKey : handLandmarker,
    kLiveStreamTestsDictExpectationKey : expectation
  };

  // TODO: Mimic initialization from CMSampleBuffer as live stream mode is most likely to be used
  // with the iOS camera. AVCaptureVideoDataOutput sample buffer delegates provide frames of type
  // `CMSampleBuffer`.
  MPPImage *image = [self imageWithFileInfo:kThumbUpImage];

  for (int i = 0; i < iterationCount; i++) {
    XCTAssertTrue([handLandmarker detectAsyncInImage:image timestampInMilliseconds:i error:nil]);
  }

  NSTimeInterval timeout = 0.5f;
  [self waitForExpectations:@[ expectation ] timeout:timeout];
}

- (void)handLandmarker:(MPPHandLandmarker *)handLandmarker
    didFinishDetectionWithResult:(MPPHandLandmarkerResult *)handLandmarkerResult
         timestampInMilliseconds:(NSInteger)timestampInMilliseconds
                           error:(NSError *)error {
  [self assertHandLandmarkerResult:handLandmarkerResult
      isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests thumbUpHandLandmarkerResult]];

  if (handLandmarker == _outOfOrderTimestampTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
    [_outOfOrderTimestampTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
  } else if (handLandmarker == _liveStreamSucceedsTestDict[kLiveStreamTestsDictHandLandmarkerKey]) {
    [_liveStreamSucceedsTestDict[kLiveStreamTestsDictExpectationKey] fulfill];
  }
}

@end
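The inverted-expectation pattern used in `testDetectWithLiveStreamModeSucceeds` above inverts the usual XCTest semantics, so it is worth distilling. A minimal sketch of just the expectation setup (standard `XCTestExpectation` API, independent of the landmarker code):

// An inverted expectation FAILS the test once it is fulfilled
// `expectedFulfillmentCount` times, so this setup succeeds only if the
// result callback fires at most `iterationCount` times.
XCTestExpectation *expectation =
    [[XCTestExpectation alloc] initWithDescription:@"at most iterationCount callbacks"];
expectation.expectedFulfillmentCount = iterationCount + 1;
expectation.inverted = YES;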
mediapipe/tasks/ios/test/vision/hand_landmarker/utils/BUILD (new file)
@@ -0,0 +1,22 @@
package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

objc_library(
    name = "MPPHandLandmarkerResultProtobufHelpers",
    srcs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.mm"],
    hdrs = ["sources/MPPHandLandmarkerResult+ProtobufHelpers.h"],
    copts = [
        "-ObjC++",
        "-std=c++17",
        "-x objective-c++",
    ],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
        "//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
        "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarkerResult",
        "//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
    ],
)
mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h (new file)
@@ -0,0 +1,26 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#import <Foundation/Foundation.h>
#import "mediapipe/tasks/ios/vision/hand_landmarker/sources/MPPHandLandmarkerResult.h"

NS_ASSUME_NONNULL_BEGIN
@interface MPPHandLandmarkerResult (ProtobufHelpers)

+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
                                                    shouldRemoveZPosition:(BOOL)removeZPosition;

@end

NS_ASSUME_NONNULL_END
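Like its gesture recognizer counterpart, this category is consumed by the tests above to load expected results from `.pbtxt` files. A minimal usage sketch (illustrative, not part of this commit; the file name is an assumption):

// Hypothetical: the helper returns nil when the .pbtxt cannot be parsed,
// so callers should assert on the result.
NSString *filePath = [[NSBundle bundleForClass:[self class]] pathForResource:@"thumb_up_landmarks"
                                                                      ofType:@"pbtxt"];
MPPHandLandmarkerResult *expectedResult =
    [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
                                                    shouldRemoveZPosition:YES];
XCTAssertNotNil(expectedResult);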
mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.mm (new file)
@@ -0,0 +1,58 @@
 | 
				
			||||||
 | 
					// Copyright 2023 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h"
 | 
				
			||||||
 | 
					#import "mediapipe/tasks/ios/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+Helpers.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/classification.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/ios/test/vision/utils/sources/parse_proto_utils.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					using ClassificationListProto = ::mediapipe::ClassificationList;
 | 
				
			||||||
 | 
					using ClassificationProto = ::mediapipe::Classification;
 | 
				
			||||||
 | 
					using LandmarksDetectionResultProto =
 | 
				
			||||||
 | 
					    ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 | 
				
			||||||
 | 
					using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
 | 
				
			||||||
 | 
					}  // anonymous namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@implementation MPPHandLandmarkerResult (ProtobufHelpers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
 | 
				
			||||||
 | 
					                                                    shouldRemoveZPosition:(BOOL)removeZPosition {
 | 
				
			||||||
 | 
					  LandmarksDetectionResultProto landmarkDetectionResultProto;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
 | 
				
			||||||
 | 
					    return nil;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (removeZPosition) {
 | 
				
			||||||
 | 
					    // Remove z position of landmarks, because they are not used in correctness testing. For video
 | 
				
			||||||
 | 
					    // or live stream mode, the z positions varies a lot during tracking from frame to frame.
 | 
				
			||||||
 | 
					    for (int i = 0; i < landmarkDetectionResultProto.landmarks().landmark().size(); i++) {
 | 
				
			||||||
 | 
					      auto &landmark = *landmarkDetectionResultProto.mutable_landmarks()->mutable_landmark(i);
 | 
				
			||||||
 | 
					      landmark.clear_z();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return [MPPHandLandmarkerResult
 | 
				
			||||||
 | 
					      handLandmarkerResultWithLandmarksProto:{landmarkDetectionResultProto.landmarks()}
 | 
				
			||||||
 | 
					                         worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
 | 
				
			||||||
 | 
					                             handednessProto:{landmarkDetectionResultProto.classifications()}
 | 
				
			||||||
 | 
					                     timestampInMilliSeconds:0];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@end
 | 
				
			||||||
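The helper above parses a `LandmarksDetectionResult` textproto into an `MPPHandLandmarkerResult` that tests can compare against live detection output. A minimal sketch of how a test might call it follows; the test class and the `@"expected_hand_landmarks"` asset name are hypothetical, not from this diff.

// Hypothetical XCTest usage of the ProtobufHelpers category above; the asset
// name "expected_hand_landmarks" is an assumption for illustration.
#import <XCTest/XCTest.h>

#import "mediapipe/tasks/ios/test/vision/hand_landmarker/utils/sources/MPPHandLandmarkerResult+ProtobufHelpers.h"

@interface MPPHandLandmarkerResultProtobufHelpersExample : XCTestCase
@end

@implementation MPPHandLandmarkerResultProtobufHelpersExample

- (void)testLoadsExpectedResult {
  // Z positions are stripped because they fluctuate from frame to frame in
  // video and live-stream modes, so correctness tests ignore them.
  MPPHandLandmarkerResult *expectedResult = [MPPHandLandmarkerResult
      handLandmarkerResultFromProtobufFileWithName:@"expected_hand_landmarks"
                             shouldRemoveZPosition:YES];
  XCTAssertNotNil(expectedResult);
}

@end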
@@ -673,10 +673,10 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation";
   // If `expectation.isInverted = true`, the test will only succeed if
   // expectation is not fulfilled for the specified `expectedFulfillmentCount`.
   // Since in our case we cannot predict how many times the expectation is
-  // supposed to be fullfilled setting,
+  // supposed to be fulfilled setting,
   // `expectation.expectedFulfillmentCount` = `iterationCount` + 1 and
   // `expectation.isInverted = true` ensures that test succeeds if
-  // expectation is fullfilled <= `iterationCount` times.
+  // expectation is fulfilled <= `iterationCount` times.
   XCTestExpectation *expectation =
       [[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
 
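For context, a minimal sketch of the inverted-expectation pattern this comment describes; `iterationCount` and the fulfilling callback live elsewhere in the test file, and the values here are illustrative.

// Inside an XCTestCase test method (illustrative values, not from this diff).
NSInteger iterationCount = 100;
XCTestExpectation *expectation =
    [[XCTestExpectation alloc] initWithDescription:@"classifyWithLiveStream"];
expectation.expectedFulfillmentCount = iterationCount + 1;
// Inverted: the test fails if the expectation reaches its
// expectedFulfillmentCount, i.e. if it is fulfilled more than iterationCount times.
expectation.inverted = YES;

for (NSInteger i = 0; i < iterationCount; i++) {
  // Each live-stream result callback calls [expectation fulfill] exactly once.
}
[self waitForExpectations:@[ expectation ] timeout:1.0];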
mediapipe/tasks/ios/vision/core/BUILD
@@ -64,3 +64,17 @@ objc_library(
         "@com_google_absl//absl/status:statusor",
     ],
 )
+
+objc_library(
+    name = "MPPMask",
+    srcs = ["sources/MPPMask.mm"],
+    hdrs = ["sources/MPPMask.h"],
+    copts = [
+        "-ObjC++",
+        "-std=c++17",
+    ],
+    deps = [
+        "//mediapipe/tasks/ios/common:MPPCommon",
+        "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
+    ],
+)
mediapipe/tasks/ios/vision/core/sources/MPPMask.h (new file, 118 lines)
@@ -0,0 +1,118 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/** The underlying type of the segmentation mask. */
+typedef NS_ENUM(NSUInteger, MPPMaskDataType) {
+
+  /** Represents the native `UInt8 *` type. */
+  MPPMaskDataTypeUInt8,
+
+  /** Represents the native `float *` type. */
+  MPPMaskDataTypeFloat32,
+
+} NS_SWIFT_NAME(MaskDataType);
+
+/**
+ * The wrapper class for MediaPipe segmentation masks.
+ *
+ * Masks are stored as `UInt8 *` or `float *` objects.
+ * Every mask has an underlying type which can be accessed using `dataType`. You can access the
+ * mask as any other type using the appropriate properties. For example, if the underlying type is
+ * `MPPMaskDataTypeUInt8`, in addition to accessing the mask using `uint8Data`, you can access
+ * `float32Data` to get the 32 bit float data (with values ranging from 0.0 to 1.0). The first
+ * time you access the data as a type different from the underlying type, an expensive type
+ * conversion is performed. Subsequent accesses return a pointer to the memory location of the same
+ * type converted array. As type conversions can be expensive, it is recommended to limit the
+ * accesses to data of types different from the underlying type.
+ *
+ * Masks that are returned from a MediaPipe task are owned by the underlying C++ task. If you
+ * need to extend the lifetime of these objects, you can invoke the `[MPPMask copy:]` method.
+ */
+NS_SWIFT_NAME(Mask)
+@interface MPPMask : NSObject <NSCopying>
+
+/** The width of the mask. */
+@property(nonatomic, readonly) NSInteger width;
+
+/** The height of the mask. */
+@property(nonatomic, readonly) NSInteger height;
+
+/** The data type of the mask. */
+@property(nonatomic, readonly) MPPMaskDataType dataType;
+
+/**
+ * The pointer to the memory location where the underlying mask as a single channel `UInt8` array is
+ * stored. `UInt8` values use the full range from 0 to 255.
+ */
+@property(nonatomic, readonly, assign) const UInt8 *uint8Data;
+
+/**
+ * The pointer to the memory location where the underlying mask as a single channel float 32 array
+ * is stored. Float values range from 0.0 to 1.0.
+ */
+@property(nonatomic, readonly, assign) const float *float32Data;
+
+/**
+ * Initializes an `MPPMask` object of type `MPPMaskDataTypeUInt8` with the given `UInt8*` data,
+ * width and height.
+ *
+ * If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
+ * `uint8Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
+ * the `MPPMask` must outlive the passed in `uint8Data`.
+ *
+ * @param uint8Data A pointer to the memory location of the `UInt8` data array.
+ * @param width The width of the mask.
+ * @param height The height of the mask.
+ * @param shouldCopy Whether the passed in `uint8Data` should be deep copied.
+ *
+ * @return A new `MPPMask` instance with the given `UInt8*` data, width and height.
+ */
+- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
+                                     width:(NSInteger)width
+                                    height:(NSInteger)height
+                                shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
+
+/**
+ * Initializes an `MPPMask` object of type `MPPMaskDataTypeFloat32` with the given `float*` data,
+ * width and height.
+ *
+ * If `shouldCopy` is set to `YES`, the newly created `MPPMask` stores a reference to a deep copied
+ * `float32Data`. Since deep copies are expensive, it is recommended to not set `shouldCopy` unless
+ * the `MPPMask` must outlive the passed in `float32Data`.
+ *
+ * @param float32Data A pointer to the memory location of the `float` data array.
+ * @param width The width of the mask.
+ * @param height The height of the mask.
+ *
+ * @return A new `MPPMask` instance with the given `float*` data, width and height.
+ */
+- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
+                                       width:(NSInteger)width
+                                      height:(NSInteger)height
+                                  shouldCopy:(BOOL)shouldCopy NS_DESIGNATED_INITIALIZER;
+
+// TODO: Add methods for CVPixelBuffer conversion.
+
+/** Unavailable. */
+- (instancetype)init NS_UNAVAILABLE;
+
++ (instancetype)new NS_UNAVAILABLE;
+
+@end
+
+NS_ASSUME_NONNULL_END
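A minimal usage sketch of the contract documented above (not part of this diff): a `MPPMaskDataTypeFloat32` mask read back as `UInt8` data, triggering one lazy conversion that is then cached.

#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"

static void MPPMaskConversionExample(void) {
  // 2x2 float mask; shouldCopy:YES because the stack buffer will not outlive
  // the mask in real code.
  float pixels[] = {0.0f, 0.25f, 0.5f, 1.0f};
  MPPMask *mask = [[MPPMask alloc] initWithFloat32Data:pixels
                                                 width:2
                                                height:2
                                            shouldCopy:YES];
  // First access converts 0.0..1.0 floats to 0..255 bytes and caches the array.
  const UInt8 *quantized = mask.uint8Data;  // {0, 63, 127, 255} after truncation
  // Subsequent accesses return the cached pointer; no second conversion.
  const UInt8 *cached = mask.uint8Data;
  (void)quantized;
  (void)cached;
}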
mediapipe/tasks/ios/vision/core/sources/MPPMask.mm (new file, 135 lines)
@@ -0,0 +1,135 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "mediapipe/tasks/ios/vision/core/sources/MPPMask.h"
+#import "mediapipe/tasks/ios/common/sources/MPPCommon.h"
+#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
+
+@interface MPPMask () {
+  const UInt8 *_uint8Data;
+  const float *_float32Data;
+  std::unique_ptr<UInt8[]> _uint8DataPtr;
+  std::unique_ptr<float[]> _float32DataPtr;
+}
+@end
+
+@implementation MPPMask
+
+- (nullable instancetype)initWithUInt8Data:(const UInt8 *)uint8Data
+                                     width:(NSInteger)width
+                                    height:(NSInteger)height
+                                shouldCopy:(BOOL)shouldCopy {
+
+  self = [super init];
+  if (self) {
+    _width = width;
+    _height = height;
+    _dataType = MPPMaskDataTypeUInt8;
+
+    if (shouldCopy) {
+      size_t length = _width * _height;
+      _uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
+      _uint8Data = _uint8DataPtr.get();
+      memcpy((UInt8 *)_uint8Data, uint8Data, length * sizeof(UInt8));
+    } else {
+      _uint8Data = uint8Data;
+    }
+  }
+  return self;
+}
+
+- (nullable instancetype)initWithFloat32Data:(const float *)float32Data
+                                       width:(NSInteger)width
+                                      height:(NSInteger)height
+                                  shouldCopy:(BOOL)shouldCopy {
+  self = [super init];
+  if (self) {
+    _width = width;
+    _height = height;
+    _dataType = MPPMaskDataTypeFloat32;
+
+    if (shouldCopy) {
+      size_t length = _width * _height;
+      _float32DataPtr = std::unique_ptr<float[]>(new float[length]);
+      _float32Data = _float32DataPtr.get();
+      memcpy((float *)_float32Data, float32Data, length * sizeof(float));
+    } else {
+      _float32Data = float32Data;
+    }
+  }
+  return self;
+}
+
+- (const UInt8 *)uint8Data {
+  switch (_dataType) {
+    case MPPMaskDataTypeUInt8: {
+      return _uint8Data;
+    }
+    case MPPMaskDataTypeFloat32: {
+      if (_uint8DataPtr) {
+        return _uint8DataPtr.get();
+      }
+
+      size_t length = _width * _height;
+      _uint8DataPtr = std::unique_ptr<UInt8[]>(new UInt8[length]);
+      UInt8 *data = _uint8DataPtr.get();
+      for (int i = 0; i < length; i++) {
+        data[i] = _float32Data[i] * 255;
+      }
+      return data;
+    }
+    default:
+      return NULL;
+  }
+}
+
+- (const float *)float32Data {
+  switch (_dataType) {
+    case MPPMaskDataTypeUInt8: {
+      if (_float32DataPtr) {
+        return _float32DataPtr.get();
+      }
+
+      size_t length = _width * _height;
+      _float32DataPtr = std::unique_ptr<float[]>(new float[length]);
+      float *data = _float32DataPtr.get();
+      for (int i = 0; i < length; i++) {
+        data[i] = (float)_uint8Data[i] / 255;
+      }
+      return data;
+    }
+    case MPPMaskDataTypeFloat32: {
+      return _float32Data;
+    }
+    default:
+      return NULL;
+  }
+}
+
+- (id)copyWithZone:(NSZone *)zone {
+  switch (_dataType) {
+    case MPPMaskDataTypeUInt8:
+      return [[MPPMask alloc] initWithUInt8Data:self.uint8Data
+                                          width:self.width
+                                         height:self.height
+                                     shouldCopy:YES];
+    case MPPMaskDataTypeFloat32:
+      return [[MPPMask alloc] initWithFloat32Data:self.float32Data
+                                            width:self.width
+                                           height:self.height
+                                       shouldCopy:YES];
+  }
+}
+
+@end
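One note on the implementation above: `copyWithZone:` always deep copies (`shouldCopy:YES`), which is what makes `copy` suitable for extending the lifetime of task-owned masks. A one-line sketch, where `segmentationMask` is a hypothetical task output:

// Masks returned by a task wrap C++-owned memory that is only valid during
// the result callback; copying detaches the data into an owned buffer.
MPPMask *ownedMask = [segmentationMask copy];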
@@ -165,7 +165,7 @@ static NSString *const kTaskPrefix = @"com.mediapipe.tasks.vision";
   // For 90° and 270° rotations, we need to swap width and height.
   // This is due to the internal behavior of ImageToTensorCalculator, which:
   // - first denormalizes the provided rect by multiplying the rect width or height by the image
-  //   width or height, repectively.
+  //   width or height, respectively.
   // - then rotates this by denormalized rect by the provided rotation, and uses this for cropping,
   // - then finally rotates this back.
   if (rotationDegrees % 180 == 0) {
@@ -28,6 +28,7 @@
 - (id)copyWithZone:(NSZone *)zone {
   MPPFaceDetectorOptions *faceDetectorOptions = [super copyWithZone:zone];
 
+  faceDetectorOptions.runningMode = self.runningMode;
   faceDetectorOptions.minDetectionConfidence = self.minDetectionConfidence;
   faceDetectorOptions.minSuppressionThreshold = self.minSuppressionThreshold;
   faceDetectorOptions.faceDetectorLiveStreamDelegate = self.faceDetectorLiveStreamDelegate;
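A sketch of why the added line matters (not from this diff): before the fix, copying the options dropped `runningMode`, so a copied options object silently fell back to the default mode.

MPPFaceDetectorOptions *options = [[MPPFaceDetectorOptions alloc] init];
options.runningMode = MPPRunningModeLiveStream;

// NSCopying is used internally when a task is created from options.
MPPFaceDetectorOptions *optionsCopy = [options copy];
// With this fix, optionsCopy.runningMode == MPPRunningModeLiveStream;
// previously it reverted to the default (image) mode.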
mediapipe/tasks/ios/vision/face_landmarker/BUILD
@@ -43,9 +43,16 @@ objc_library(
     ],
 )
 
+objc_library(
+    name = "MPPFaceLandmarksConnections",
+    hdrs = ["sources/MPPFaceLandmarksConnections.h"],
+    module_name = "MPPFaceLandmarksConnections",
+    deps = ["//mediapipe/tasks/ios/components/containers:MPPConnection"],
+)
+
 objc_library(
     name = "MPPFaceLandmarker",
-    srcs = ["sources/MPPFaceLandmarker.m"],
+    srcs = ["sources/MPPFaceLandmarker.mm"],
     hdrs = ["sources/MPPFaceLandmarker.h"],
     copts = [
         "-ObjC++",
@@ -55,9 +62,11 @@ objc_library(
     deps = [
         ":MPPFaceLandmarkerOptions",
         ":MPPFaceLandmarkerResult",
+        ":MPPFaceLandmarksConnections",
         "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
         "//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
         "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
+        "//mediapipe/tasks/ios/components/containers:MPPConnection",
         "//mediapipe/tasks/ios/core:MPPTaskInfo",
         "//mediapipe/tasks/ios/vision/core:MPPImage",
         "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator",
mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarker.h
@@ -14,6 +14,7 @@
 
 #import <Foundation/Foundation.h>
 
+#import "mediapipe/tasks/ios/components/containers/sources/MPPConnection.h"
 #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h"
 #import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h"
 #import "mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerResult.h"
@@ -147,6 +148,83 @@ NS_SWIFT_NAME(FaceLandmarker)
                       error:(NSError **)error
     NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
 
+/**
+ * Returns the connections between all the landmarks in the lips.
+ *
+ * @return An array of connections between all the landmarks in the lips.
+ */
++ (NSArray<MPPConnection *> *)lipsConnections;
+
+/**
+ * Returns the connections between all the landmarks in the left eye.
+ *
+ * @return An array of connections between all the landmarks in the left eye.
+ */
++ (NSArray<MPPConnection *> *)leftEyeConnections;
+
+/**
+ * Returns the connections between all the landmarks in the left eyebrow.
+ *
+ * @return An array of connections between all the landmarks in the left eyebrow.
+ */
++ (NSArray<MPPConnection *> *)leftEyebrowConnections;
+
+/**
+ * Returns the connections between all the landmarks in the left iris.
+ *
+ * @return An array of connections between all the landmarks in the left iris.
+ */
++ (NSArray<MPPConnection *> *)leftIrisConnections;
+
+/**
+ * Returns the connections between all the landmarks in the right eye.
+ *
+ * @return An array of connections between all the landmarks in the right eye.
+ */
++ (NSArray<MPPConnection *> *)rightEyeConnections;
+
+/**
+ * Returns the connections between all the landmarks in the right eyebrow.
+ *
+ * @return An array of connections between all the landmarks in the right eyebrow.
+ */
++ (NSArray<MPPConnection *> *)rightEyebrowConnections;
+
+/**
+ * Returns the connections between all the landmarks in the right iris.
+ *
+ * @return An array of connections between all the landmarks in the right iris.
+ */
++ (NSArray<MPPConnection *> *)rightIrisConnections;
+
+/**
+ * Returns the connections between all the landmarks of the face oval.
+ *
+ * @return An array of connections between all the landmarks of the face oval.
+ */
++ (NSArray<MPPConnection *> *)faceOvalConnections;
+
+/**
+ * Returns the connections between all the landmarks making up the contours of the face.
+ *
+ * @return An array of connections between all the contours of the face.
+ */
++ (NSArray<MPPConnection *> *)contoursConnections;
+
+/**
+ * Returns the connections between all the landmarks making up the tesselation of the face.
+ *
+ * @return An array of connections between all the landmarks making up the tesselation of the face.
+ */
++ (NSArray<MPPConnection *> *)tesselationConnections;
+
+/**
+ * Returns the connections between all the landmarks in the face.
+ *
+ * @return An array of connections between all the landmarks in the face.
+ */
++ (NSArray<MPPConnection *> *)faceConnections;
+
 - (instancetype)init NS_UNAVAILABLE;
 
 + (instancetype)new NS_UNAVAILABLE;
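A short sketch of consuming the new connection accessors (not from this diff; `MPPConnection` is assumed to expose `start` and `end` landmark indices, per the MPPConnection container dependency added above):

// Pair each connection's endpoints with landmarks from a detection result to
// draw the lip contour (drawing itself omitted).
NSArray<MPPConnection *> *lipsConnections = [MPPFaceLandmarker lipsConnections];
for (MPPConnection *connection in lipsConnections) {
  NSUInteger startIndex = connection.start;
  NSUInteger endIndex = connection.end;
  // Look up faceLandmarkerResult.faceLandmarks[0][startIndex] and
  // faceLandmarkerResult.faceLandmarks[0][endIndex], then draw a segment.
  (void)startIndex;
  (void)endIndex;
}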
Some files were not shown because too many files have changed in this diff.