diff --git a/WORKSPACE b/WORKSPACE index 66828f61f..97c7f68ec 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -157,22 +157,22 @@ http_archive( # 2020-08-21 http_archive( name = "com_github_glog_glog", - strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6", - sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab", + strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372", + sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb", urls = [ - "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", + "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip", ], ) http_archive( name = "com_github_glog_glog_no_gflags", - strip_prefix = "glog-0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6", - sha256 = "58c9b3b6aaa4dd8b836c0fd8f65d0f941441fb95e27212c5eeb9979cfd3592ab", + strip_prefix = "glog-3a0d4d22c5ae0b9a2216988411cfa6bf860cc372", + sha256 = "170d08f80210b82d95563f4723a15095eff1aad1863000e8eeb569c96a98fefb", build_file = "@//third_party:glog_no_gflags.BUILD", urls = [ - "https://github.com/google/glog/archive/0a2e5931bd5ff22fd3bf8999eb8ce776f159cda6.zip", + "https://github.com/google/glog/archive/3a0d4d22c5ae0b9a2216988411cfa6bf860cc372.zip", ], patches = [ - "@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff", + "@//third_party:com_github_glog_glog.diff", ], patch_args = [ "-p1", diff --git a/mediapipe/BUILD b/mediapipe/BUILD index fd0cbab36..41443c414 100644 --- a/mediapipe/BUILD +++ b/mediapipe/BUILD @@ -68,30 +68,108 @@ config_setting( visibility = ["//visibility:public"], ) -# Note: this cannot just match "apple_platform_type": "macos" because that option -# defaults to "macos" even when building on Linux! -alias( +# Generic MacOS. +config_setting( name = "macos", - actual = select({ - ":macos_i386": ":macos_i386", - ":macos_x86_64": ":macos_x86_64", - ":macos_arm64": ":macos_arm64", - "//conditions:default": ":macos_i386", # Arbitrarily chosen from above. - }), + constraint_values = [ + "@platforms//os:macos", + ], visibility = ["//visibility:public"], ) -# Note: this also matches on crosstool_top so that it does not produce ambiguous -# selectors when used together with "android". +# MacOS x86 64-bit. +config_setting( + name = "macos_x86_64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:x86_64", + ], + visibility = ["//visibility:public"], +) + +# MacOS ARM64. +config_setting( + name = "macos_arm64", + constraint_values = [ + "@platforms//os:macos", + "@platforms//cpu:arm64", + ], + visibility = ["//visibility:public"], +) + +# Generic iOS. config_setting( name = "ios", - values = { - "crosstool_top": "@bazel_tools//tools/cpp:toolchain", - "apple_platform_type": "ios", - }, + constraint_values = [ + "@platforms//os:ios", + ], visibility = ["//visibility:public"], ) +# iOS device ARM32. +config_setting( + name = "ios_armv7", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm", + ], + visibility = ["//visibility:public"], +) + +# iOS device ARM64. +config_setting( + name = "ios_arm64", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm64", + ], + visibility = ["//visibility:public"], +) + +# iOS device ARM64E. +config_setting( + name = "ios_arm64e", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm64e", + ], + visibility = ["//visibility:public"], +) + +# iOS simulator x86 32-bit. 
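# Illustrative aside (not part of the patch): a dependent target can consume
# the constraint-based settings being added here through an ordinary select();
# the library below and its flags are hypothetical, only the setting labels
# come from this file.
cc_library(
    name = "platform_tuned_demo",  # hypothetical target
    srcs = ["platform_tuned_demo.cc"],
    copts = select({
        "//mediapipe:macos_arm64": ["-DDEMO_APPLE_SILICON"],
        "//mediapipe:ios_arm64": ["-DDEMO_IOS_DEVICE"],
        "//conditions:default": [],
    }),
)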
+config_setting( + name = "ios_i386", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:x86_32", + "@build_bazel_apple_support//constraints:simulator", + ], + visibility = ["//visibility:public"], +) + +# iOS simulator x86 64-bit. +config_setting( + name = "ios_x86_64", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:x86_64", + "@build_bazel_apple_support//constraints:simulator", + ], + visibility = ["//visibility:public"], +) + +# iOS simulator ARM64. +config_setting( + name = "ios_sim_arm64", + constraint_values = [ + "@platforms//os:ios", + "@platforms//cpu:arm64", + "@build_bazel_apple_support//constraints:simulator", + ], + visibility = ["//visibility:public"], +) + +# Generic Apple. alias( name = "apple", actual = select({ @@ -102,49 +180,6 @@ alias( visibility = ["//visibility:public"], ) -config_setting( - name = "macos_i386", - values = { - "apple_platform_type": "macos", - "cpu": "darwin", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "macos_x86_64", - values = { - "apple_platform_type": "macos", - "cpu": "darwin_x86_64", - }, - visibility = ["//visibility:public"], -) - -config_setting( - name = "macos_arm64", - values = { - "apple_platform_type": "macos", - "cpu": "darwin_arm64", - }, - visibility = ["//visibility:public"], -) - -[ - config_setting( - name = arch, - values = {"cpu": arch}, - visibility = ["//visibility:public"], - ) - for arch in [ - "ios_i386", - "ios_x86_64", - "ios_armv7", - "ios_arm64", - "ios_arm64e", - "ios_sim_arm64", - ] -] - config_setting( name = "windows", values = {"cpu": "x64_windows"}, diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index fbdbbab0a..7f6528ec1 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -210,6 +210,23 @@ REGISTER_CALCULATOR(SpectrogramCalculator); // Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0). const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518; +namespace { +std::unique_ptr MakeWindowFun( + const SpectrogramCalculatorOptions::WindowType window_type) { + switch (window_type) { + // The cosine window and square root of Hann are equivalent. 
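// Illustrative aside (not part of the patch): the template arguments elided
// above plausibly name the audio_dsp window classes used by the code this
// hunk deletes, so a caller such as Open() can sample any configured window
// the same way; the frame length and double-precision buffer are assumptions.
std::unique_ptr<audio_dsp::WindowFunction> window_fun =
    MakeWindowFun(SpectrogramCalculatorOptions::SQRT_HANN);  // same as COSINE
std::vector<double> window;
window_fun->GetPeriodicSamples(/*frame_duration_samples=*/400, &window);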
+ case SpectrogramCalculatorOptions::COSINE: + case SpectrogramCalculatorOptions::SQRT_HANN: + return std::make_unique(); + case SpectrogramCalculatorOptions::HANN: + return std::make_unique(); + case SpectrogramCalculatorOptions::HAMMING: + return std::make_unique(); + } + return nullptr; +} +} // namespace + absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { SpectrogramCalculatorOptions spectrogram_options = cc->Options(); @@ -266,28 +283,14 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { output_scale_ = spectrogram_options.output_scale(); - std::vector window; - switch (spectrogram_options.window_type()) { - case SpectrogramCalculatorOptions::COSINE: - audio_dsp::CosineWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - break; - case SpectrogramCalculatorOptions::HANN: - audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - break; - case SpectrogramCalculatorOptions::HAMMING: - audio_dsp::HammingWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - break; - case SpectrogramCalculatorOptions::SQRT_HANN: { - audio_dsp::HannWindow().GetPeriodicSamples(frame_duration_samples_, - &window); - absl::c_transform(window, window.begin(), - [](double x) { return std::sqrt(x); }); - break; - } + auto window_fun = MakeWindowFun(spectrogram_options.window_type()); + if (window_fun == nullptr) { + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("Invalid window type ", + spectrogram_options.window_type())); } + std::vector window; + window_fun->GetPeriodicSamples(frame_duration_samples_, &window); // Propagate settings down to the actual Spectrogram object. spectrogram_generators_.clear(); diff --git a/mediapipe/calculators/audio/spectrogram_calculator.proto b/mediapipe/calculators/audio/spectrogram_calculator.proto index ddfca1d1c..d8bca3f76 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.proto +++ b/mediapipe/calculators/audio/spectrogram_calculator.proto @@ -68,7 +68,7 @@ message SpectrogramCalculatorOptions { HANN = 0; HAMMING = 1; COSINE = 2; - SQRT_HANN = 4; + SQRT_HANN = 4; // Alias of COSINE. } optional WindowType window_type = 6 [default = HANN]; diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 99a63f633..7c5dfe81f 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -381,17 +381,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "clip_detection_vector_size_calculator", - srcs = ["clip_detection_vector_size_calculator.cc"], - deps = [ - ":clip_vector_size_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", - ], - alwayslink = 1, -) - cc_test( name = "clip_vector_size_calculator_test", srcs = ["clip_vector_size_calculator_test.cc"], diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc index 7da90989b..d030bbbde 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.cc +++ b/mediapipe/calculators/core/begin_loop_calculator.cc @@ -76,4 +76,9 @@ REGISTER_CALCULATOR(BeginLoopGpuBufferCalculator); // A calculator to process std::vector. typedef BeginLoopCalculator> BeginLoopImageCalculator; REGISTER_CALCULATOR(BeginLoopImageCalculator); + +// A calculator to process std::vector. 
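// Illustrative aside (not part of the patch): the element type is elided in
// the comment and typedef around this point; given the calculator name it
// plausibly reads BeginLoopCalculator<std::vector<float>>, which makes the
// usual BEGIN_LOOP wiring available for float vectors, e.g. (stream names
// hypothetical):
//
//   node {
//     calculator: "BeginLoopFloatCalculator"
//     input_stream: "ITERABLE:float_vector"
//     output_stream: "ITEM:float_item"
//     output_stream: "BATCH_END:iterable_ts"
//   }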
+typedef BeginLoopCalculator> BeginLoopFloatCalculator; +REGISTER_CALCULATOR(BeginLoopFloatCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc index d67e6c061..36ee0f2d7 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator.cc @@ -123,7 +123,10 @@ class PreviousLoopbackCalculator : public Node { // However, LOOP packet is empty. kPrevLoop(cc).SetNextTimestampBound(main_spec.timestamp + 1); } else { - kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp)); + // Avoids sending leftovers to a stream that's already closed. + if (!kPrevLoop(cc).IsClosed()) { + kPrevLoop(cc).Send(loop_candidate.At(main_spec.timestamp)); + } } loop_packets_.pop_front(); main_packet_specs_.pop_front(); diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 20e5ebda4..4f3059a51 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -135,7 +135,6 @@ cc_library( deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", ], diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index 6bb43dc00..88f1d4c12 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -112,7 +112,7 @@ class BilateralFilterCalculator : public CalculatorBase { REGISTER_CALCULATOR(BilateralFilterCalculator); absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { diff --git a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc index 81732f904..db0d38325 100644 --- a/mediapipe/calculators/image/segmentation_smoothing_calculator.cc +++ b/mediapipe/calculators/image/segmentation_smoothing_calculator.cc @@ -110,7 +110,7 @@ REGISTER_CALCULATOR(SegmentationSmoothingCalculator); absl::Status SegmentationSmoothingCalculator::GetContract( CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); cc->Inputs().Tag(kCurrentMaskTag).Set(); cc->Inputs().Tag(kPreviousMaskTag).Set(); diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index e20621e8d..9c381f62d 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -142,7 +142,7 @@ class SetAlphaCalculator : public CalculatorBase { REGISTER_CALCULATOR(SetAlphaCalculator); absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; diff --git a/mediapipe/calculators/image/yuv_to_image_calculator.cc b/mediapipe/calculators/image/yuv_to_image_calculator.cc index e84eee74e..6a82877c3 100644 --- a/mediapipe/calculators/image/yuv_to_image_calculator.cc +++ b/mediapipe/calculators/image/yuv_to_image_calculator.cc @@ -38,7 +38,7 @@ 
std::string FourCCToString(libyuv::FourCC fourcc) { buf[0] = (fourcc >> 24) & 0xff; buf[1] = (fourcc >> 16) & 0xff; buf[2] = (fourcc >> 8) & 0xff; - buf[3] = (fourcc)&0xff; + buf[3] = (fourcc) & 0xff; buf[4] = 0; return std::string(buf); } diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc index 47617b375..01cc60a15 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc @@ -282,18 +282,23 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { if (options.has_volume_gain_db()) { gain_ = pow(10, options.volume_gain_db() / 20.0); } - RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ - !kAudioIn(cc).Header().IsEmpty()) - << "Must either specify the time series header of the \"AUDIO\" stream " - "or have the \"SAMPLE_RATE\" stream connected."; - if (!kAudioIn(cc).Header().IsEmpty()) { - mediapipe::TimeSeriesHeader input_header; - MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid( - kAudioIn(cc).Header(), &input_header)); - if (stream_mode_) { - MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate())); - } else { - source_sample_rate_ = input_header.sample_rate(); + if (options.has_source_sample_rate()) { + source_sample_rate_ = options.source_sample_rate(); + } else { + RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ + !kAudioIn(cc).Header().IsEmpty()) + << "Must either specify the time series header of the \"AUDIO\" stream " + "or have the \"SAMPLE_RATE\" stream connected."; + if (!kAudioIn(cc).Header().IsEmpty()) { + mediapipe::TimeSeriesHeader input_header; + MP_RETURN_IF_ERROR( + mediapipe::time_series_util::FillTimeSeriesHeaderIfValid( + kAudioIn(cc).Header(), &input_header)); + if (stream_mode_) { + MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate())); + } else { + source_sample_rate_ = input_header.sample_rate(); + } } } AppendZerosToSampleBuffer(padding_samples_before_); diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto index 5b7d61bcb..948c82a36 100644 --- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto +++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto @@ -85,4 +85,7 @@ message AudioToTensorCalculatorOptions { // The volume gain, measured in dB. // Scale the input audio amplitude by 10^(volume_gain_db/20). optional double volume_gain_db = 12; + + // The source number of samples per second (hertz) of the input audio buffers. 
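// Illustrative aside (not part of the patch): with the new field below, a
// graph can pin the input sample rate in the node options instead of
// providing an "AUDIO" stream header or a "SAMPLE_RATE" stream; stream names
// and values here are made up.
//
//   node {
//     calculator: "AudioToTensorCalculator"
//     input_stream: "AUDIO:audio_in"
//     output_stream: "TENSORS:audio_tensors"
//     options {
//       [mediapipe.AudioToTensorCalculatorOptions.ext] {
//         source_sample_rate: 16000.0
//       }
//     }
//   }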
+ optional double source_sample_rate = 13; } diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 8aee46185..e265eaee7 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -69,6 +69,7 @@ class InferenceCalculatorGlAdvancedImpl gpu_delegate_options); absl::Status ReadGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; absl::Status SaveGpuCaches(tflite::gpu::TFLiteGPURunner* gpu_runner) const; + bool UseSerializedModel() const { return use_serialized_model_; } private: bool use_kernel_caching_ = false; @@ -150,8 +151,6 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( } absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Close() { - MP_RETURN_IF_ERROR( - on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get())); return gpu_helper_.RunInGlContext([this]() -> absl::Status { tflite_gpu_runner_.reset(); return absl::OkStatus(); @@ -226,9 +225,14 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::InitTFLiteGPURunner( tflite_gpu_runner_->GetOutputShapes()[i].c}; } + if (on_disk_cache_helper_.UseSerializedModel()) { + tflite_gpu_runner_->ForceOpenCLInitFromSerializedModel(); + } + MP_RETURN_IF_ERROR( on_disk_cache_helper_.ReadGpuCaches(tflite_gpu_runner_.get())); - return tflite_gpu_runner_->Build(); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); + return on_disk_cache_helper_.SaveGpuCaches(tflite_gpu_runner_.get()); } #if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_CHROMIUMOS) diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index c8dd0e2a0..246269de1 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -256,6 +256,7 @@ class TensorsToDetectionsCalculator : public Node { bool gpu_inited_ = false; bool gpu_input_ = false; + bool gpu_has_enough_work_groups_ = true; bool anchors_init_ = false; }; MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator); @@ -291,7 +292,7 @@ absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) { auto output_detections = absl::make_unique>(); bool gpu_processing = false; - if (CanUseGpu()) { + if (CanUseGpu() && gpu_has_enough_work_groups_) { // Use GPU processing only if at least one input tensor is already on GPU // (to avoid CPU->GPU overhead). for (const auto& tensor : *kInTensors(cc)) { @@ -321,11 +322,20 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) { RET_CHECK(!has_custom_box_indices_); } - if (gpu_processing) { - if (!gpu_inited_) { - MP_RETURN_IF_ERROR(GpuInit(cc)); + if (gpu_processing && !gpu_inited_) { + auto status = GpuInit(cc); + if (status.ok()) { gpu_inited_ = true; + } else if (status.code() == absl::StatusCode::kFailedPrecondition) { + // For initialization errors caused by hardware limitations, fall back to + // CPU processing. + LOG(WARNING) << status.message(); + } else { + // For other errors, let the error propagate.
+ return status; } + } + if (gpu_processing && gpu_inited_) { MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); } else { MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); } @@ -346,17 +356,41 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU( // TODO: Add flexible input tensor size handling. auto raw_box_tensor = &input_tensors[tensor_mapping_.detections_tensor_index()]; - RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3); - RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options"; - RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); - RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); + if (raw_box_tensor->shape().dims.size() == 3) { + // The tensors from CPU inference have dim 3. + RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); + RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); + RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); + } else if (raw_box_tensor->shape().dims.size() == 4) { + // The tensors from GPU inference have dim 4. For gpu-cpu fallback support, + // we allow tensors with 4 dims. + RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); + RET_CHECK_EQ(raw_box_tensor->shape().dims[1], 1); + RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_boxes_); + RET_CHECK_EQ(raw_box_tensor->shape().dims[3], num_coords_); + } else { + return absl::InvalidArgumentError( + "The dimensions of box Tensor must be 3 or 4."); + } auto raw_score_tensor = &input_tensors[tensor_mapping_.scores_tensor_index()]; - RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3); - RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); - RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_); - RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_); + if (raw_score_tensor->shape().dims.size() == 3) { + // The tensors from CPU inference have dim 3. + RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); + RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_); + RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_classes_); + } else if (raw_score_tensor->shape().dims.size() == 4) { + // The tensors from GPU inference have dim 4. For gpu-cpu fallback support, + // we allow tensors with 4 dims. + RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); + RET_CHECK_EQ(raw_score_tensor->shape().dims[1], 1); + RET_CHECK_EQ(raw_score_tensor->shape().dims[2], num_boxes_); + RET_CHECK_EQ(raw_score_tensor->shape().dims[3], num_classes_); + } else { + return absl::InvalidArgumentError( + "The dimensions of score Tensor must be 3 or 4."); + } auto raw_box_view = raw_box_tensor->GetCpuReadView(); auto raw_boxes = raw_box_view.buffer(); auto raw_scores_view = raw_score_tensor->GetCpuReadView(); @@ -1111,8 +1145,13 @@ void main() { int max_wg_size; // typically <= 1024 glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &max_wg_size); // y-dim - CHECK_LT(num_classes_, max_wg_size) - << "# classes must be < " << max_wg_size; + gpu_has_enough_work_groups_ = num_classes_ < max_wg_size; + if (!gpu_has_enough_work_groups_) { + return absl::FailedPreconditionError(absl::StrFormat( + "Hardware limitation: Processing will be done on CPU, because " + "num_classes %d exceeds the max work_group size %d.", + num_classes_, max_wg_size)); + } // TODO support better filtering.
if (class_index_set_.is_allowlist) { CHECK_EQ(class_index_set_.values.size(), @@ -1370,7 +1409,13 @@ kernel void scoreKernel( Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2}); // # filter classes supported is hardware dependent. int max_wg_size = score_program_.maxTotalThreadsPerThreadgroup; - CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size; + gpu_has_enough_work_groups_ = num_classes_ < max_wg_size; + if (!gpu_has_enough_work_groups_) { + return absl::FailedPreconditionError(absl::StrFormat( + "Hardware limitation: Processing will be done on CPU, because " + "num_classes %d exceeds the max work_group size %d.", + num_classes_, max_wg_size)); + } } #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index feee2372a..2d6948671 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -406,8 +406,13 @@ cc_library( alwayslink = 1, ) -# This dependency removed tensorflow_jellyfish_deps and xprofilez_with_server because they failed -# Boq conformance test. Weigh your use case to see if this will work for you. +# This dependency removed the following 3 targets because they failed Boq conformance test: +# +# tensorflow_jellyfish_deps +# jfprof_lib +# xprofilez_with_server +# +# If you need them plz consider tensorflow_inference_calculator_no_envelope_loader. cc_library( name = "tensorflow_inference_calculator_for_boq", srcs = ["tensorflow_inference_calculator.cc"], @@ -927,7 +932,6 @@ cc_test( "//mediapipe/framework:timestamp", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 34136440d..4bb2093da 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -164,8 +164,8 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } - CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || - cc->OutputSidePackets().HasTag(kSequenceExampleTag)) + RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) || + cc->OutputSidePackets().HasTag(kSequenceExampleTag)) << "Neither the output stream nor the output side packet is set to " "output the sequence example."; if (cc->Outputs().HasTag(kSequenceExampleTag)) { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 752db621e..9d45e38e2 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -23,7 +23,6 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" #include "mediapipe/framework/port/gmock.h" @@ -96,7 +95,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { mpms::SetClipMediaId(test_video_id, 
input_sequence.get()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); @@ -139,7 +139,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { mpms::SetClipMediaId(test_video_id, input_sequence.get()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); @@ -378,7 +379,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { Adopt(input_sequence.release()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); auto image_ptr = @@ -410,7 +412,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoForwardFlowEncodeds) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); std::string test_flow_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_flow; encoded_flow.set_encoded_image(test_flow_string); @@ -618,7 +621,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { } cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(width); @@ -767,7 +771,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamOK) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); std::string test_flow_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_flow; encoded_flow.set_encoded_image(test_flow_string); @@ -813,7 +818,8 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) { mpms::SetClipMediaId(test_video_id, input_sequence.get()); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); std::string test_flow_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_flow; encoded_flow.set_encoded_image(test_flow_string); @@ -970,7 +976,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { auto input_sequence = ::absl::make_unique(); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, 
{cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); @@ -1021,7 +1028,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { auto input_sequence = ::absl::make_unique(); cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; - ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); OpenCvImageEncoderCalculatorResults encoded_image; encoded_image.set_encoded_image(bytes.data(), bytes.size()); int height = 2; diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 34093702c..5afede99d 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -172,7 +172,7 @@ class AnnotationOverlayCalculator : public CalculatorBase { REGISTER_CALCULATOR(AnnotationOverlayCalculator); absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { - CHECK_GE(cc->Inputs().NumEntries(), 1); + RET_CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; @@ -189,13 +189,13 @@ absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { #if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); - CHECK(cc->Outputs().HasTag(kGpuBufferTag)); + RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag)); use_gpu = true; } #endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); - CHECK(cc->Outputs().HasTag(kImageFrameTag)); + RET_CHECK(cc->Outputs().HasTag(kImageFrameTag)); } // Data streams to render. 
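A brief sketch of the CHECK_* to RET_CHECK_* pattern used in the GetContract() changes above: the RET_CHECK macros report a violated contract as a returned absl::Status instead of aborting the process, so graph validation can surface the error. The calculator name below is hypothetical.

absl::Status MyCalculator::GetContract(mediapipe::CalculatorContract* cc) {
  RET_CHECK_GE(cc->Inputs().NumEntries(), 1)
      << "At least one input stream is required.";
  RET_CHECK(cc->Outputs().HasTag("IMAGE"));
  return absl::OkStatus();
}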
diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index 263ef85c6..b0d4f4175 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -322,27 +322,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { options_.presence_threshold(), options_.connection_color(), thickness, /*normalized=*/false, render_data.get()); } - for (int i = 0; i < landmarks.landmark_size(); ++i) { - const Landmark& landmark = landmarks.landmark(i); + if (options_.render_landmarks()) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const Landmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibleAndPresent( - landmark, options_.utilize_visibility(), - options_.visibility_threshold(), options_.utilize_presence(), - options_.presence_threshold())) { - continue; - } + if (!IsLandmarkVisibleAndPresent( + landmark, options_.utilize_visibility(), + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold())) { + continue; + } - auto* landmark_data_render = AddPointRenderData( - options_.landmark_color(), thickness, render_data.get()); - if (visualize_depth) { - SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render, - options_.min_depth_circle_thickness(), - options_.max_depth_circle_thickness()); + auto* landmark_data_render = AddPointRenderData( + options_.landmark_color(), thickness, render_data.get()); + if (visualize_depth) { + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, + landmark_data_render, + options_.min_depth_circle_thickness(), + options_.max_depth_circle_thickness()); + } + auto* landmark_data = landmark_data_render->mutable_point(); + landmark_data->set_normalized(false); + landmark_data->set_x(landmark.x()); + landmark_data->set_y(landmark.y()); } - auto* landmark_data = landmark_data_render->mutable_point(); - landmark_data->set_normalized(false); - landmark_data->set_x(landmark.x()); - landmark_data->set_y(landmark.y()); } } @@ -368,27 +371,30 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { options_.presence_threshold(), options_.connection_color(), thickness, /*normalized=*/true, render_data.get()); } - for (int i = 0; i < landmarks.landmark_size(); ++i) { - const NormalizedLandmark& landmark = landmarks.landmark(i); + if (options_.render_landmarks()) { + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const NormalizedLandmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibleAndPresent( - landmark, options_.utilize_visibility(), - options_.visibility_threshold(), options_.utilize_presence(), - options_.presence_threshold())) { - continue; - } + if (!IsLandmarkVisibleAndPresent( + landmark, options_.utilize_visibility(), + options_.visibility_threshold(), options_.utilize_presence(), + options_.presence_threshold())) { + continue; + } - auto* landmark_data_render = AddPointRenderData( - options_.landmark_color(), thickness, render_data.get()); - if (visualize_depth) { - SetColorSizeValueFromZ(landmark.z(), z_min, z_max, landmark_data_render, - options_.min_depth_circle_thickness(), - options_.max_depth_circle_thickness()); + auto* landmark_data_render = AddPointRenderData( + options_.landmark_color(), thickness, render_data.get()); + if (visualize_depth) { + SetColorSizeValueFromZ(landmark.z(), z_min, z_max, + landmark_data_render, + 
options_.min_depth_circle_thickness(), + options_.max_depth_circle_thickness()); + } + auto* landmark_data = landmark_data_render->mutable_point(); + landmark_data->set_normalized(true); + landmark_data->set_x(landmark.x()); + landmark_data->set_y(landmark.y()); } - auto* landmark_data = landmark_data_render->mutable_point(); - landmark_data->set_normalized(true); - landmark_data->set_x(landmark.x()); - landmark_data->set_y(landmark.y()); } } diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto index 990919540..67dca84ad 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.proto @@ -32,6 +32,10 @@ message LandmarksToRenderDataCalculatorOptions { // Color of the landmarks. optional Color landmark_color = 2; + + // Whether to render landmarks as points. + optional bool render_landmarks = 14 [default = true]; + // Color of the connections. optional Color connection_color = 3; diff --git a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc index 59b21d574..30dc11dbe 100644 --- a/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc +++ b/mediapipe/calculators/util/refine_landmarks_from_heatmap_calculator.cc @@ -124,7 +124,7 @@ absl::StatusOr RefineLandmarksFromHeatMap( int center_row = out_lms.landmark(lm_index).y() * hm_height; // Point is outside of the image let's keep it intact. if (center_col < 0 || center_col >= hm_width || center_row < 0 || - center_col >= hm_height) { + center_row >= hm_height) { continue; } diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 7245b13c2..569fd8bad 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -130,7 +130,6 @@ cc_library( "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_video", - "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/tool:status_util", ], @@ -341,7 +340,6 @@ cc_test( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:test_util", - "@com_google_absl//absl/flags:flag", ], ) @@ -367,7 +365,6 @@ cc_test( "//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:test_util", - "@com_google_absl//absl/flags:flag", ], ) @@ -451,7 +448,6 @@ cc_test( "//mediapipe/framework/tool:test_util", "//mediapipe/util/tracking:box_tracker_cc_proto", "//mediapipe/util/tracking:tracking_cc_proto", - "@com_google_absl//absl/flags:flag", ], ) diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar index 943f0cbfa..afba10928 100644 Binary files a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar and b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 508322917..4e86b9270 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ 
b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip networkTimeout=10000 zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/mediapipe/examples/ios/facedetectioncpu/BUILD b/mediapipe/examples/ios/facedetectioncpu/BUILD index 9424fddea..300901909 100644 --- a/mediapipe/examples/ios/facedetectioncpu/BUILD +++ b/mediapipe/examples/ios/facedetectioncpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "facedetectioncpu", diff --git a/mediapipe/examples/ios/facedetectiongpu/BUILD b/mediapipe/examples/ios/facedetectiongpu/BUILD index 8ed689b4f..d3725aa33 100644 --- a/mediapipe/examples/ios/facedetectiongpu/BUILD +++ b/mediapipe/examples/ios/facedetectiongpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "facedetectiongpu", diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 1152bed33..c9415068b 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "faceeffect", diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index 6caf8c09c..250a8bca1 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "facemeshgpu", diff --git a/mediapipe/examples/ios/handdetectiongpu/BUILD b/mediapipe/examples/ios/handdetectiongpu/BUILD index 9b9255374..6deb1be1d 100644 --- a/mediapipe/examples/ios/handdetectiongpu/BUILD +++ b/mediapipe/examples/ios/handdetectiongpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "handdetectiongpu", diff --git a/mediapipe/examples/ios/handtrackinggpu/BUILD b/mediapipe/examples/ios/handtrackinggpu/BUILD index c5b8e7b58..b8f1442fe 100644 --- a/mediapipe/examples/ios/handtrackinggpu/BUILD +++ b/mediapipe/examples/ios/handtrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "handtrackinggpu", diff --git a/mediapipe/examples/ios/helloworld/BUILD b/mediapipe/examples/ios/helloworld/BUILD index 6bfcfaaef..3bed74843 100644 --- a/mediapipe/examples/ios/helloworld/BUILD +++ b/mediapipe/examples/ios/helloworld/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "helloworld", diff --git a/mediapipe/examples/ios/holistictrackinggpu/BUILD b/mediapipe/examples/ios/holistictrackinggpu/BUILD index cd10877de..56c74148c 100644 --- a/mediapipe/examples/ios/holistictrackinggpu/BUILD +++ b/mediapipe/examples/ios/holistictrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "holistictrackinggpu", diff --git a/mediapipe/examples/ios/iristrackinggpu/BUILD b/mediapipe/examples/ios/iristrackinggpu/BUILD index 646d2e5a2..78d4bbd1e 100644 --- a/mediapipe/examples/ios/iristrackinggpu/BUILD +++ 
b/mediapipe/examples/ios/iristrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "iristrackinggpu", diff --git a/mediapipe/examples/ios/objectdetectioncpu/BUILD b/mediapipe/examples/ios/objectdetectioncpu/BUILD index 7638c7413..47bde166e 100644 --- a/mediapipe/examples/ios/objectdetectioncpu/BUILD +++ b/mediapipe/examples/ios/objectdetectioncpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "objectdetectioncpu", diff --git a/mediapipe/examples/ios/objectdetectiongpu/BUILD b/mediapipe/examples/ios/objectdetectiongpu/BUILD index 3b925c078..174db7582 100644 --- a/mediapipe/examples/ios/objectdetectiongpu/BUILD +++ b/mediapipe/examples/ios/objectdetectiongpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "objectdetectiongpu", diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD index 2236c5257..cb8626cc3 100644 --- a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD +++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "objectdetectiontrackinggpu", diff --git a/mediapipe/examples/ios/posetrackinggpu/BUILD b/mediapipe/examples/ios/posetrackinggpu/BUILD index 4fbc2280c..855d32954 100644 --- a/mediapipe/examples/ios/posetrackinggpu/BUILD +++ b/mediapipe/examples/ios/posetrackinggpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "posetrackinggpu", diff --git a/mediapipe/examples/ios/selfiesegmentationgpu/BUILD b/mediapipe/examples/ios/selfiesegmentationgpu/BUILD index 1ba7997ed..2abf05617 100644 --- a/mediapipe/examples/ios/selfiesegmentationgpu/BUILD +++ b/mediapipe/examples/ios/selfiesegmentationgpu/BUILD @@ -24,7 +24,7 @@ load( licenses(["notice"]) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" alias( name = "selfiesegmentationgpu", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 93e9475f3..6dca0ba98 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -44,6 +44,9 @@ bzl_library( "encode_binary_proto.bzl", ], visibility = ["//visibility:public"], + deps = [ + "@bazel_skylib//lib:paths", + ], ) alias( diff --git a/mediapipe/framework/api2/node.h b/mediapipe/framework/api2/node.h index 14c098246..58cebf1ea 100644 --- a/mediapipe/framework/api2/node.h +++ b/mediapipe/framework/api2/node.h @@ -64,58 +64,13 @@ class CalculatorBaseFactoryFor< namespace api2 { namespace internal { -// Defining a member of this type causes P to be ODR-used, which forces its -// instantiation if it's a static member of a template. -// Previously we depended on the pointer's value to determine whether the size -// of a character array is 0 or 1, forcing it to be instantiated so the -// compiler can determine the object's layout. But using it as a template -// argument is more compact. -template -struct ForceStaticInstantiation { -#ifdef _MSC_VER - // Just having it as the template argument does not count as a use for - // MSVC. 
- static constexpr bool Use() { return P != nullptr; } - char force_static[Use()]; -#endif // _MSC_VER -}; +MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE( + NodeRegistrator, mediapipe::CalculatorBaseRegistry, T::kCalculatorName, + absl::make_unique>) -// Helper template for forcing the definition of a static registration token. -template -struct NodeRegistrationStatic { - static NoDestructor registration; - - static mediapipe::RegistrationToken Make() { - return mediapipe::CalculatorBaseRegistry::Register( - T::kCalculatorName, - absl::make_unique>, - __FILE__, __LINE__); - } - - using RequireStatics = ForceStaticInstantiation<®istration>; -}; - -// Static members of template classes can be defined in the header. -template -NoDestructor - NodeRegistrationStatic::registration(NodeRegistrationStatic::Make()); - -template -struct SubgraphRegistrationImpl { - static NoDestructor registration; - - static mediapipe::RegistrationToken Make() { - return mediapipe::SubgraphRegistry::Register( - T::kCalculatorName, absl::make_unique, __FILE__, __LINE__); - } - - using RequireStatics = ForceStaticInstantiation<®istration>; -}; - -template -NoDestructor - SubgraphRegistrationImpl::registration( - SubgraphRegistrationImpl::Make()); +MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(SubgraphRegistrator, + mediapipe::SubgraphRegistry, + T::kCalculatorName, absl::make_unique) } // namespace internal @@ -128,14 +83,7 @@ template class RegisteredNode; template -class RegisteredNode : public Node { - private: - // The member below triggers instantiation of the registration static. - // Note that the constructor of calculator subclasses is only invoked through - // the registration token, and so we cannot simply use the static in the - // constructor. - typename internal::NodeRegistrationStatic::RequireStatics register_; -}; +class RegisteredNode : public Node, private internal::NodeRegistrator {}; // No-op version for backwards compatibility. template <> @@ -217,31 +165,27 @@ class NodeImpl : public RegisteredNode, public Intf { // TODO: verify that the subgraph config fully implements the // declared interface. template -class SubgraphImpl : public Subgraph, public Intf { - private: - typename internal::SubgraphRegistrationImpl::RequireStatics register_; -}; +class SubgraphImpl : public Subgraph, + public Intf, + private internal::SubgraphRegistrator {}; // This macro is used to register a calculator that does not use automatic // registration. Deprecated. -#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ - static mediapipe::NoDestructor \ - REGISTRY_STATIC_VAR(calculator_registration, \ - __LINE__)(mediapipe::CalculatorBaseRegistry::Register( \ - Impl::kCalculatorName, \ - absl::make_unique>, \ - __FILE__, __LINE__)) +#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ + MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \ + mediapipe::CalculatorBaseRegistry, calculator_registration, \ + Impl::kCalculatorName, \ + absl::make_unique>) // This macro is used to register a non-split-contract calculator. Deprecated. #define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name) // This macro is used to define a subgraph that does not use automatic // registration. Deprecated. 
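// Illustrative aside (not part of the patch): after this change the
// deprecated macro simply forwards to the generic registration macro, so an
// explicit registration still looks the same to callers (subgraph name
// hypothetical):
//
//   MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(MyDetectionSubgraph);
//
// which now expands to MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED over
// mediapipe::SubgraphRegistry with absl::make_unique<MyDetectionSubgraph>.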
-#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ - static mediapipe::NoDestructor \ - REGISTRY_STATIC_VAR(subgraph_registration, \ - __LINE__)(mediapipe::SubgraphRegistry::Register( \ - Impl::kCalculatorName, absl::make_unique, __FILE__, __LINE__)) +#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ + MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED( \ + mediapipe::SubgraphRegistry, subgraph_registration, \ + Impl::kCalculatorName, absl::make_unique) } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/calculator_base_test.cc b/mediapipe/framework/calculator_base_test.cc index c26006e0f..42c03696c 100644 --- a/mediapipe/framework/calculator_base_test.cc +++ b/mediapipe/framework/calculator_base_test.cc @@ -183,8 +183,7 @@ TEST(CalculatorTest, CreateByNameWhitelisted) { CalculatorBaseRegistry::Register( "::mediapipe::test_ns::whitelisted_ns::DeadCalculator", absl::make_unique>, - __FILE__, __LINE__); + mediapipe::test_ns::whitelisted_ns::DeadCalculator>>); // A whitelisted calculator can be found in its own namespace. MP_EXPECT_OK(CalculatorBaseRegistry::CreateByNameInNamespace( // diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index 7965539b6..c67f07305 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -16,7 +16,6 @@ #define MEDIAPIPE_DEPS_REGISTRATION_H_ #include -#include #include #include #include @@ -145,6 +144,23 @@ template struct WrapStatusOr> { using type = absl::StatusOr; }; + +// Defining a member of this type causes P to be ODR-used, which forces its +// instantiation if it's a static member of a template. +// Previously we depended on the pointer's value to determine whether the size +// of a character array is 0 or 1, forcing it to be instantiated so the +// compiler can determine the object's layout. But using it as a template +// argument is more compact. +template +struct ForceStaticInstantiation { +#ifdef _MSC_VER + // Just having it as the template argument does not count as a use for + // MSVC. + static constexpr bool Use() { return P != nullptr; } + char force_static[Use()]; +#endif // _MSC_VER +}; + } // namespace registration_internal class NamespaceAllowlist { @@ -162,8 +178,7 @@ class FunctionRegistry { FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete; - RegistrationToken Register(absl::string_view name, Function func, - std::string filename, uint64_t line) + RegistrationToken Register(absl::string_view name, Function func) ABSL_LOCKS_EXCLUDED(lock_) { std::string normalized_name = GetNormalizedName(name); absl::WriterMutexLock lock(&lock_); @@ -173,21 +188,10 @@ class FunctionRegistry { } if (functions_.insert(std::make_pair(normalized_name, std::move(func))) .second) { -#ifndef NDEBUG - locations_.emplace(normalized_name, - std::make_pair(std::move(filename), line)); -#endif return RegistrationToken( [this, normalized_name]() { Unregister(normalized_name); }); } -#ifndef NDEBUG - LOG(FATAL) << "Function with name " << name << " already registered." 
- << " First registration at " - << locations_.at(normalized_name).first << ":" - << locations_.at(normalized_name).second; -#else LOG(FATAL) << "Function with name " << name << " already registered."; -#endif return RegistrationToken([]() {}); } @@ -316,11 +320,6 @@ class FunctionRegistry { private: mutable absl::Mutex lock_; absl::flat_hash_map functions_ ABSL_GUARDED_BY(lock_); -#ifndef NDEBUG - // Stores filename and line number for useful debug log. - absl::flat_hash_map> locations_ - ABSL_GUARDED_BY(lock_); -#endif // For names included in NamespaceAllowlist, strips the namespace. std::string GetAdjustedName(absl::string_view name) { @@ -351,10 +350,8 @@ class GlobalFactoryRegistry { public: static RegistrationToken Register(absl::string_view name, - typename Functions::Function func, - std::string filename, uint64_t line) { - return functions()->Register(name, std::move(func), std::move(filename), - line); + typename Functions::Function func) { + return functions()->Register(name, std::move(func)); } // Invokes the specified factory function and returns the result. @@ -414,12 +411,77 @@ class GlobalFactoryRegistry { #define MEDIAPIPE_REGISTER_FACTORY_FUNCTION(RegistryType, name, ...) \ static auto* REGISTRY_STATIC_VAR(registration_##name, __LINE__) = \ new mediapipe::RegistrationToken( \ - RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__)) + RegistryType::Register(#name, __VA_ARGS__)) +#define MEDIAPIPE_REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, \ + name, ...) \ + static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \ + new mediapipe::RegistrationToken( \ + RegistryType::Register(name, __VA_ARGS__)) + +// TODO: migrate to the above. #define REGISTER_FACTORY_FUNCTION_QUALIFIED(RegistryType, var_name, name, ...) \ static auto* REGISTRY_STATIC_VAR(var_name, __LINE__) = \ new mediapipe::RegistrationToken( \ - RegistryType::Register(#name, __VA_ARGS__, __FILE__, __LINE__)) + RegistryType::Register(#name, __VA_ARGS__)) + +// Defines a utility registrator class which can be used to automatically +// register factory functions. +// +// Example: +// === Defining a registry ================================================ +// +// class Component {}; +// +// using ComponentRegistry = GlobalFactoryRegistry>; +// +// === Defining a registrator ============================================= +// +// MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(ComponentRegistrator, +// ComponentRegistry, T::kName, +// absl::make_unique); +// +// === Defining and registering a new component. ========================== +// +// class MyComponent : public Component, +// private ComponentRegistrator { +// public: +// static constexpr char kName[] = "MyComponent"; +// ... +// }; +// +// NOTE: +// - MyComponent is automatically registered in ComponentRegistry by +// "MyComponent" name. +// - Every component is require to provide its name (T::kName here.) +#define MEDIAPIPE_STATIC_REGISTRATOR_TEMPLATE(RegistratorName, RegistryType, \ + name, ...) \ + template \ + struct Internal##RegistratorName { \ + static NoDestructor registration; \ + \ + static mediapipe::RegistrationToken Make() { \ + return RegistryType::Register(name, __VA_ARGS__); \ + } \ + \ + using RequireStatics = \ + registration_internal::ForceStaticInstantiation<®istration>; \ + }; \ + /* Static members of template classes can be defined in the header. 
*/ \ + template \ + NoDestructor \ + Internal##RegistratorName::registration( \ + Internal##RegistratorName::Make()); \ + \ + template \ + class RegistratorName { \ + private: \ + /* The member below triggers instantiation of the registration static. */ \ + /* Note that the constructor of calculator subclasses is only invoked */ \ + /* through the registration token, and so we cannot simply use the */ \ + /* static in theconstructor. */ \ + typename Internal##RegistratorName::RequireStatics register_; \ + }; } // namespace mediapipe diff --git a/mediapipe/framework/encode_binary_proto.bzl b/mediapipe/framework/encode_binary_proto.bzl index e849d971f..e0e9ae680 100644 --- a/mediapipe/framework/encode_binary_proto.bzl +++ b/mediapipe/framework/encode_binary_proto.bzl @@ -37,29 +37,33 @@ Args: output: The desired name of the output file. Optional. """ +load("@bazel_skylib//lib:paths.bzl", "paths") + PROTOC = "@com_google_protobuf//:protoc" -def _canonicalize_proto_path_oss(all_protos, genfile_path): - """For the protos from external repository, canonicalize the proto path and the file name. +def _canonicalize_proto_path_oss(f): + if not f.root.path: + return struct( + proto_path = ".", + file_name = f.short_path, + ) - Returns: - Proto path list and proto source file list. - """ - proto_paths = [] - proto_file_names = [] - for s in all_protos.to_list(): - if s.path.startswith(genfile_path): - repo_name, _, file_name = s.path[len(genfile_path + "/external/"):].partition("/") + # `f.path` looks like "/external//(_virtual_imports//)?" + repo_name, _, file_name = f.path[len(paths.join(f.root.path, "external") + "/"):].partition("/") + if file_name.startswith("_virtual_imports/"): + # This is a virtual import; move "_virtual_imports/" from `repo_name` to `file_name`. + repo_name = paths.join(repo_name, *file_name.split("/", 2)[:2]) + file_name = file_name.split("/", 2)[-1] + return struct( + proto_path = paths.join(f.root.path, "external", repo_name), + file_name = file_name, + ) - # handle virtual imports - if file_name.startswith("_virtual_imports"): - repo_name = repo_name + "/" + "/".join(file_name.split("/", 2)[:2]) - file_name = file_name.split("/", 2)[-1] - proto_paths.append(genfile_path + "/external/" + repo_name) - proto_file_names.append(file_name) - else: - proto_file_names.append(s.path) - return ([" --proto_path=" + path for path in proto_paths], proto_file_names) +def _map_root_path(f): + return _canonicalize_proto_path_oss(f).proto_path + +def _map_short_path(f): + return _canonicalize_proto_path_oss(f).file_name def _get_proto_provider(dep): """Get the provider for protocol buffers from a dependnecy. @@ -90,24 +94,35 @@ def _encode_binary_proto_impl(ctx): sibling = textpb, ) - path_list, file_list = _canonicalize_proto_path_oss(all_protos, ctx.genfiles_dir.path) + args = ctx.actions.args() + args.add(textpb) + args.add(binarypb) + args.add(ctx.executable._proto_compiler) + args.add(ctx.attr.message_type, format = "--encode=%s") + args.add("--proto_path=.") + args.add_all( + all_protos, + map_each = _map_root_path, + format_each = "--proto_path=%s", + uniquify = True, + ) + args.add_all( + all_protos, + map_each = _map_short_path, + uniquify = True, + ) # Note: the combination of absolute_paths and proto_path, as well as the exact # order of gendir before ., is needed for the proto compiler to resolve # import statements that reference proto files produced by a genrule. 
ctx.actions.run_shell( - tools = all_protos.to_list() + [textpb, ctx.executable._proto_compiler], - outputs = [binarypb], - command = " ".join( - [ - ctx.executable._proto_compiler.path, - "--encode=" + ctx.attr.message_type, - "--proto_path=" + ctx.genfiles_dir.path, - "--proto_path=" + ctx.bin_dir.path, - "--proto_path=.", - ] + path_list + file_list + - ["<", textpb.path, ">", binarypb.path], + tools = depset( + direct = [textpb, ctx.executable._proto_compiler], + transitive = [all_protos], ), + outputs = [binarypb], + command = "${@:3} < $1 > $2", + arguments = [args], mnemonic = "EncodeProto", ) diff --git a/mediapipe/framework/formats/body_rig.proto b/mediapipe/framework/formats/body_rig.proto index 5420ccc10..88964d995 100644 --- a/mediapipe/framework/formats/body_rig.proto +++ b/mediapipe/framework/formats/body_rig.proto @@ -19,7 +19,7 @@ package mediapipe; // Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of // the joint and its visibility. message Joint { - // Joint rotation in 6D contineous representation ordered as + // Joint rotation in 6D continuous representation ordered as // [a1, b1, a2, b2, a3, b3]. // // Such representation is more sutable for NN model training and can be diff --git a/mediapipe/framework/mediapipe_cc_test.bzl b/mediapipe/framework/mediapipe_cc_test.bzl index 0fc0a462d..5e1daca7b 100644 --- a/mediapipe/framework/mediapipe_cc_test.bzl +++ b/mediapipe/framework/mediapipe_cc_test.bzl @@ -15,7 +15,7 @@ def mediapipe_cc_test( platforms = ["linux", "android", "ios", "wasm"], exclude_platforms = None, # ios_unit_test arguments - ios_minimum_os_version = "11.0", + ios_minimum_os_version = "12.0", # android_cc_test arguments open_gl_driver = None, emulator_mini_boot = True, diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index af2ec5a98..1024cbc15 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -466,8 +466,7 @@ struct MessageRegistrationImpl { template NoDestructor MessageRegistrationImpl::registration(MessageHolderRegistry::Register( - T{}.GetTypeName(), MessageRegistrationImpl::CreateMessageHolder, - __FILE__, __LINE__)); + T{}.GetTypeName(), MessageRegistrationImpl::CreateMessageHolder)); // For non-Message payloads, this does nothing. template diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index cae439bc0..5894e4715 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -261,8 +261,8 @@ cc_library( ) cc_library( - name = "opencv_highgui", - hdrs = ["opencv_highgui_inc.h"], + name = "opencv_photo", + hdrs = ["opencv_photo_inc.h"], deps = [ ":opencv_core", "//third_party:opencv", @@ -297,6 +297,15 @@ cc_library( ], ) +cc_library( + name = "opencv_highgui", + hdrs = ["opencv_highgui_inc.h"], + deps = [ + ":opencv_core", + "//third_party:opencv", + ], +) + cc_library( name = "opencv_videoio", hdrs = ["opencv_videoio_inc.h"], diff --git a/mediapipe/framework/port/opencv_highgui_inc.h b/mediapipe/framework/port/opencv_highgui_inc.h index c3ca4b7f0..c79804e1f 100644 --- a/mediapipe/framework/port/opencv_highgui_inc.h +++ b/mediapipe/framework/port/opencv_highgui_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ -#define MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ +#ifndef MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_ +#define MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_ #include @@ -25,4 +25,4 @@ #include #endif -#endif // MEDIAPIPE_PORT_OPENCV_HIGHGUI_INC_H_ +#endif // MEDIAPIPE_FRAMEWORK_PORT_OPENCV_HIGHGUI_INC_H_ diff --git a/mediapipe/framework/port/opencv_imgcodecs_inc.h b/mediapipe/framework/port/opencv_imgcodecs_inc.h index 60bcd49e9..4c867ed56 100644 --- a/mediapipe/framework/port/opencv_imgcodecs_inc.h +++ b/mediapipe/framework/port/opencv_imgcodecs_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/mediapipe/calculators/core/clip_detection_vector_size_calculator.cc b/mediapipe/framework/port/opencv_photo_inc.h similarity index 59% rename from mediapipe/calculators/core/clip_detection_vector_size_calculator.cc rename to mediapipe/framework/port/opencv_photo_inc.h index 55bcf2feb..1416fda70 100644 --- a/mediapipe/calculators/core/clip_detection_vector_size_calculator.cc +++ b/mediapipe/framework/port/opencv_photo_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2023 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,15 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#ifndef MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_ +#define MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_ -#include "mediapipe/calculators/core/clip_vector_size_calculator.h" -#include "mediapipe/framework/formats/detection.pb.h" +#include "third_party/OpenCV/photo.hpp" -namespace mediapipe { - -typedef ClipVectorSizeCalculator<::mediapipe::Detection> - ClipDetectionVectorSizeCalculator; -REGISTER_CALCULATOR(ClipDetectionVectorSizeCalculator); - -} // namespace mediapipe +#endif // MEDIAPIPE_PORT_OPENCV_PHOTO_INC_H_ diff --git a/mediapipe/framework/port/opencv_video_inc.h b/mediapipe/framework/port/opencv_video_inc.h index dc84bf59b..5f06d9233 100644 --- a/mediapipe/framework/port/opencv_video_inc.h +++ b/mediapipe/framework/port/opencv_video_inc.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2022 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
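As an aside on the encode_binary_proto.bzl rewrite above: the per-file map functions compute a (proto_path, file_name) pair for every proto input. Below is a plain-Python sketch of that canonicalization (illustrative only, not part of the patch; the example root and repository paths are hypothetical):

import os

def canonicalize_proto_path(root_path, full_path, short_path):
  # Source-tree protos keep "." as the proto path and their workspace-relative name.
  if not root_path:
    return ".", short_path
  # Generated external protos look like
  # "<root>/external/<repo>/(_virtual_imports/<target>/)?<file>".
  prefix = os.path.join(root_path, "external") + "/"
  repo_name, _, file_name = full_path[len(prefix):].partition("/")
  if file_name.startswith("_virtual_imports/"):
    # Fold "_virtual_imports/<target>" into the proto path so imports resolve.
    repo_name = os.path.join(repo_name, *file_name.split("/", 2)[:2])
    file_name = file_name.split("/", 2)[-1]
  return os.path.join(root_path, "external", repo_name), file_name

# Hypothetical generated proto from an external repository:
print(canonicalize_proto_path(
    "bazel-out/k8-fastbuild/bin",
    "bazel-out/k8-fastbuild/bin/external/some_repo/_virtual_imports/foo_proto/foo.proto",
    "../some_repo/_virtual_imports/foo_proto/foo.proto"))
# ("bazel-out/k8-fastbuild/bin/external/some_repo/_virtual_imports/foo_proto", "foo.proto")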
diff --git a/mediapipe/framework/stream_handler/mux_input_stream_handler.cc b/mediapipe/framework/stream_handler/mux_input_stream_handler.cc index 0303a5778..209c3b6f5 100644 --- a/mediapipe/framework/stream_handler/mux_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/mux_input_stream_handler.cc @@ -48,6 +48,18 @@ class MuxInputStreamHandler : public InputStreamHandler { : InputStreamHandler(std::move(tag_map), cc_manager, options, calculator_run_in_parallel) {} + private: + CollectionItemId GetControlStreamId() const { + return input_stream_managers_.EndId() - 1; + } + void RemoveOutdatedDataPackets(Timestamp timestamp) { + const CollectionItemId control_stream_id = GetControlStreamId(); + for (CollectionItemId id = input_stream_managers_.BeginId(); + id < control_stream_id; ++id) { + input_stream_managers_.Get(id)->ErasePacketsEarlierThan(timestamp); + } + } + protected: // In MuxInputStreamHandler, a node is "ready" if: // - the control stream is done (need to call Close() in this case), or @@ -58,9 +70,15 @@ class MuxInputStreamHandler : public InputStreamHandler { absl::MutexLock lock(&input_streams_mutex_); const auto& control_stream = - input_stream_managers_.Get(input_stream_managers_.EndId() - 1); + input_stream_managers_.Get(GetControlStreamId()); bool empty; *min_stream_timestamp = control_stream->MinTimestampOrBound(&empty); + + // Data streams may contain some outdated packets which failed to be popped + // out during "FillInputSet". (This handler doesn't sync input streams, + // hence "FillInputSet" can be triggerred before every input stream is + // filled with packets corresponding to the same timestamp.) + RemoveOutdatedDataPackets(*min_stream_timestamp); if (empty) { if (*min_stream_timestamp == Timestamp::Done()) { // Calculator is done if the control input stream is done. @@ -78,11 +96,6 @@ class MuxInputStreamHandler : public InputStreamHandler { const auto& data_stream = input_stream_managers_.Get( input_stream_managers_.BeginId() + control_value); - // Data stream may contain some outdated packets which failed to be popped - // out during "FillInputSet". (This handler doesn't sync input streams, - // hence "FillInputSet" can be triggerred before every input stream is - // filled with packets corresponding to the same timestamp.) - data_stream->ErasePacketsEarlierThan(*min_stream_timestamp); Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty); if (empty) { if (stream_timestamp <= *min_stream_timestamp) { @@ -111,8 +124,7 @@ class MuxInputStreamHandler : public InputStreamHandler { CHECK(input_set); absl::MutexLock lock(&input_streams_mutex_); - const CollectionItemId control_stream_id = - input_stream_managers_.EndId() - 1; + const CollectionItemId control_stream_id = GetControlStreamId(); auto& control_stream = input_stream_managers_.Get(control_stream_id); int num_packets_dropped = 0; bool stream_is_done = false; @@ -140,15 +152,8 @@ class MuxInputStreamHandler : public InputStreamHandler { AddPacketToShard(&input_set->Get(data_stream_id), std::move(data_packet), stream_is_done); - // Discard old packets on other streams. - // Note that control_stream_id is the last valid id. - auto next_timestamp = input_timestamp.NextAllowedInStream(); - for (CollectionItemId id = input_stream_managers_.BeginId(); - id < control_stream_id; ++id) { - if (id == data_stream_id) continue; - auto& other_stream = input_stream_managers_.Get(id); - other_stream->ErasePacketsEarlierThan(next_timestamp); - } + // Discard old packets on data streams. 
+ RemoveOutdatedDataPackets(input_timestamp.NextAllowedInStream()); } private: diff --git a/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc index f19a3ddec..78b2bb3f7 100644 --- a/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/mux_input_stream_handler_test.cc @@ -645,5 +645,41 @@ TEST(MuxInputStreamHandlerTest, MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST(MuxInputStreamHandlerTest, RemovesUnusedDataStreamPackets) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input0" + input_stream: "input1" + input_stream: "select" + node { + calculator: "MuxCalculator" + input_stream: "INPUT:0:input0" + input_stream: "INPUT:1:input1" + input_stream: "SELECT:select" + output_stream: "OUTPUT:output" + input_stream_handler { input_stream_handler: "MuxInputStreamHandler" } + } + )pb"); + config.set_max_queue_size(1); + config.set_report_deadlock(true); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "select", MakePacket(0).At(Timestamp(2)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input0", MakePacket(1000).At(Timestamp(2)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Add two delayed packets to the deselected input. They should be discarded + // instead of triggering the deadlock detection (max_queue_size = 1). + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input1", MakePacket(900).At(Timestamp(1)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input1", MakePacket(900).At(Timestamp(2)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc index 6c18c9cac..7cbde28bf 100644 --- a/mediapipe/framework/subgraph.cc +++ b/mediapipe/framework/subgraph.cc @@ -64,13 +64,13 @@ GraphRegistry::GraphRegistry( void GraphRegistry::Register( const std::string& type_name, std::function()> factory) { - local_factories_.Register(type_name, factory, __FILE__, __LINE__); + local_factories_.Register(type_name, factory); } // TODO: Remove this convenience function. void GraphRegistry::Register(const std::string& type_name, const CalculatorGraphConfig& config) { - Register(type_name, [config] { + local_factories_.Register(type_name, [config] { auto result = absl::make_unique(config); return std::unique_ptr(result.release()); }); @@ -79,7 +79,7 @@ void GraphRegistry::Register(const std::string& type_name, // TODO: Remove this convenience function. 
void GraphRegistry::Register(const std::string& type_name, const CalculatorGraphTemplate& templ) { - Register(type_name, [templ] { + local_factories_.Register(type_name, [templ] { auto result = absl::make_unique(templ); return std::unique_ptr(result.release()); }); diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 64b5072c5..5e712ecf5 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -228,7 +228,9 @@ absl::Status CompareAndSaveImageOutput( auto status = CompareImageFrames(**expected, actual, options.max_color_diff, options.max_alpha_diff, options.max_avg_diff, diff_img); - ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff")); + if (diff_img) { + ASSIGN_OR_RETURN(auto diff_img_path, SavePngTestOutput(*diff_img, "diff")); + } return status; } diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index ee32b91e2..bc5fb95fc 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -1121,7 +1121,7 @@ objc_library( alwayslink = 1, ) -MIN_IOS_VERSION = "11.0" +MIN_IOS_VERSION = "12.0" test_suite( name = "ios", diff --git a/mediapipe/gpu/gl_context_webgl.cc b/mediapipe/gpu/gl_context_webgl.cc index 25cbed83d..1bbb42c84 100644 --- a/mediapipe/gpu/gl_context_webgl.cc +++ b/mediapipe/gpu/gl_context_webgl.cc @@ -109,9 +109,8 @@ absl::Status GlContext::CreateContext( } MP_RETURN_IF_ERROR(status); - LOG(INFO) << "Successfully created a WebGL context with major version " - << gl_major_version_ << " and handle " << context_; - + VLOG(1) << "Successfully created a WebGL context with major version " + << gl_major_version_ << " and handle " << context_; return absl::OkStatus(); } diff --git a/mediapipe/gpu/gl_scaler_calculator.cc b/mediapipe/gpu/gl_scaler_calculator.cc index fa06c8854..14540b52d 100644 --- a/mediapipe/gpu/gl_scaler_calculator.cc +++ b/mediapipe/gpu/gl_scaler_calculator.cc @@ -104,6 +104,7 @@ class GlScalerCalculator : public CalculatorBase { bool vertical_flip_output_; bool horizontal_flip_output_; FrameScaleMode scale_mode_ = FrameScaleMode::kStretch; + bool use_nearest_neighbor_interpolation_ = false; }; REGISTER_CALCULATOR(GlScalerCalculator); @@ -186,7 +187,8 @@ absl::Status GlScalerCalculator::Open(CalculatorContext* cc) { scale_mode_ = FrameScaleModeFromProto(options.scale_mode(), FrameScaleMode::kStretch); } - + use_nearest_neighbor_interpolation_ = + options.use_nearest_neighbor_interpolation(); if (HasTagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1)) { const auto& dimensions = TagOrIndex(cc->InputSidePackets(), "OUTPUT_DIMENSIONS", 1) @@ -297,6 +299,11 @@ absl::Status GlScalerCalculator::Process(CalculatorContext* cc) { glBindTexture(src2.target(), src2.name()); } + if (use_nearest_neighbor_interpolation_) { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); + } + MP_RETURN_IF_ERROR(renderer->GlRender( src1.width(), src1.height(), dst.width(), dst.height(), scale_mode_, rotation_, horizontal_flip_output_, vertical_flip_output_, diff --git a/mediapipe/gpu/gl_scaler_calculator.proto b/mediapipe/gpu/gl_scaler_calculator.proto index 99c0d439a..f746a30f8 100644 --- a/mediapipe/gpu/gl_scaler_calculator.proto +++ b/mediapipe/gpu/gl_scaler_calculator.proto @@ -19,7 +19,7 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; import "mediapipe/gpu/scale_mode.proto"; -// Next id: 8. +// Next id: 9. 
message GlScalerCalculatorOptions { extend CalculatorOptions { optional GlScalerCalculatorOptions ext = 166373014; @@ -39,4 +39,7 @@ message GlScalerCalculatorOptions { // Flip the output texture horizontally. This is applied after rotation. optional bool flip_horizontal = 5; optional ScaleMode.Mode scale_mode = 6; + // Whether to use nearest neighbor interpolation. Default to use linear + // interpolation. + optional bool use_nearest_neighbor_interpolation = 8 [default = false]; } diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 00ee9e248..e88aa602e 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -100,6 +100,10 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, #endif // TARGET_OS_OSX }}, + {GpuBufferFormat::kOneComponent8Alpha, + { + {GL_ALPHA, GL_ALPHA, GL_UNSIGNED_BYTE, 1}, + }}, {GpuBufferFormat::kOneComponent8Red, { {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, @@ -221,6 +225,7 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kRGBA32: // TODO: this likely maps to ImageFormat::SRGBA case GpuBufferFormat::kGrayHalf16: + case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Red: case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponentHalf16: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 5d77afeb6..06eabda77 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -43,6 +43,7 @@ enum class GpuBufferFormat : uint32_t { kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'), kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'), kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'), + kOneComponent8Alpha = MEDIAPIPE_FOURCC('A', '0', '0', '8'), kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'), kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'), kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'), @@ -101,6 +102,7 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { return kCVPixelFormatType_OneComponent32Float; case GpuBufferFormat::kOneComponent8: return kCVPixelFormatType_OneComponent8; + case GpuBufferFormat::kOneComponent8Alpha: case GpuBufferFormat::kOneComponent8Red: return -1; case GpuBufferFormat::kTwoComponent8: diff --git a/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java index 20c63c069..242cd616a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java @@ -78,17 +78,21 @@ public class AppTextureFrame implements TextureFrame { * Use {@link waitUntilReleasedWithGpuSync} whenever possible. */ public void waitUntilReleased() throws InterruptedException { + GlSyncToken tokenToRelease = null; synchronized (this) { while (inUse && releaseSyncToken == null) { wait(); } if (releaseSyncToken != null) { - releaseSyncToken.waitOnCpu(); - releaseSyncToken.release(); + tokenToRelease = releaseSyncToken; inUse = false; releaseSyncToken = null; } } + if (tokenToRelease != null) { + tokenToRelease.waitOnCpu(); + tokenToRelease.release(); + } } /** @@ -98,17 +102,21 @@ public class AppTextureFrame implements TextureFrame { * TextureFrame. 
*/ public void waitUntilReleasedWithGpuSync() throws InterruptedException { + GlSyncToken tokenToRelease = null; synchronized (this) { while (inUse && releaseSyncToken == null) { wait(); } if (releaseSyncToken != null) { - releaseSyncToken.waitOnGpu(); - releaseSyncToken.release(); + tokenToRelease = releaseSyncToken; inUse = false; releaseSyncToken = null; } } + if (tokenToRelease != null) { + tokenToRelease.waitOnGpu(); + tokenToRelease.release(); + } } /** diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 1c1daadcc..5ea12872a 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -239,7 +239,7 @@ public final class PacketGetter { /** * Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer - * array has the the same size of image list packet, and assumes the output buffer stores pixels + * array has the same size of image list packet, and assumes the output buffer stores pixels * contiguously. It returns false if this assumption does not hold. * *

If deepCopy is true, it assumes the given buffersArray has allocated the required size of diff --git a/mediapipe/model_maker/python/BUILD b/mediapipe/model_maker/python/BUILD index 775ac82dd..42681fadb 100644 --- a/mediapipe/model_maker/python/BUILD +++ b/mediapipe/model_maker/python/BUILD @@ -24,6 +24,7 @@ package_group( package_group( name = "1p_client", packages = [ + "//cloud/ml/applications/vision/model_garden/model_oss/mediapipe/...", "//research/privacy/learning/fl_eval/pcvr/...", ], ) diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD index 1c2fb7a44..4364b7744 100644 --- a/mediapipe/model_maker/python/core/data/BUILD +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -57,3 +57,14 @@ py_test( srcs = ["classification_dataset_test.py"], deps = [":classification_dataset"], ) + +py_library( + name = "cache_files", + srcs = ["cache_files.py"], +) + +py_test( + name = "cache_files_test", + srcs = ["cache_files_test.py"], + deps = [":cache_files"], +) diff --git a/mediapipe/model_maker/python/core/data/cache_files.py b/mediapipe/model_maker/python/core/data/cache_files.py new file mode 100644 index 000000000..13d3d5b61 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/cache_files.py @@ -0,0 +1,112 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common TFRecord cache files library.""" + +import dataclasses +import os +import tempfile +from typing import Any, Mapping, Sequence + +import tensorflow as tf +import yaml + + +# Suffix of the meta data file name. +METADATA_FILE_SUFFIX = '_metadata.yaml' + + +@dataclasses.dataclass(frozen=True) +class TFRecordCacheFiles: + """TFRecordCacheFiles dataclass to store and load cached TFRecord files. + + Attributes: + cache_prefix_filename: The cache prefix filename. This is usually provided + as a hash of the original data source to avoid different data sources + resulting in the same cache file. + cache_dir: The cache directory to save TFRecord and metadata file. When + cache_dir is None, a temporary folder will be created and will not be + removed automatically after training which makes it can be used later. + num_shards: Number of shards for output tfrecord files. 
+ """ + + cache_prefix_filename: str = 'cache_prefix' + cache_dir: str = dataclasses.field(default_factory=tempfile.mkdtemp) + num_shards: int = 1 + + def __post_init__(self): + if not tf.io.gfile.exists(self.cache_dir): + tf.io.gfile.makedirs(self.cache_dir) + if not self.cache_prefix_filename: + raise ValueError('cache_prefix_filename cannot be empty.') + if self.num_shards <= 0: + raise ValueError( + f'num_shards must be greater than 0, got {self.num_shards}' + ) + + @property + def cache_prefix(self) -> str: + """The cache prefix including the cache directory and the cache prefix filename.""" + return os.path.join(self.cache_dir, self.cache_prefix_filename) + + @property + def tfrecord_files(self) -> Sequence[str]: + """The TFRecord files.""" + tfrecord_files = [ + self.cache_prefix + '-%05d-of-%05d.tfrecord' % (i, self.num_shards) + for i in range(self.num_shards) + ] + return tfrecord_files + + @property + def metadata_file(self) -> str: + """The metadata file.""" + return self.cache_prefix + METADATA_FILE_SUFFIX + + def get_writers(self) -> Sequence[tf.io.TFRecordWriter]: + """Gets an array of TFRecordWriter objects. + + Note that these writers should each be closed using .close() when done. + + Returns: + Array of TFRecordWriter objects + """ + return [tf.io.TFRecordWriter(path) for path in self.tfrecord_files] + + def save_metadata(self, metadata): + """Writes metadata to file. + + Args: + metadata: A dictionary of metadata content to write. Exact format is + dependent on the specific dataset, but typically includes a 'size' and + 'label_names' entry. + """ + with tf.io.gfile.GFile(self.metadata_file, 'w') as f: + yaml.dump(metadata, f) + + def load_metadata(self) -> Mapping[Any, Any]: + """Reads metadata from file. + + Returns: + Dictionary object containing metadata + """ + if not tf.io.gfile.exists(self.metadata_file): + return {} + with tf.io.gfile.GFile(self.metadata_file, 'r') as f: + metadata = yaml.load(f, Loader=yaml.FullLoader) + return metadata + + def is_cached(self) -> bool: + """Checks whether this CacheFiles is already cached.""" + all_cached_files = list(self.tfrecord_files) + [self.metadata_file] + return all(tf.io.gfile.exists(f) for f in all_cached_files) diff --git a/mediapipe/model_maker/python/core/data/cache_files_test.py b/mediapipe/model_maker/python/core/data/cache_files_test.py new file mode 100644 index 000000000..ac727b3fe --- /dev/null +++ b/mediapipe/model_maker/python/core/data/cache_files_test.py @@ -0,0 +1,77 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import cache_files + + +class CacheFilesTest(tf.test.TestCase): + + def test_tfrecord_cache_files(self): + cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='tfrecord', + cache_dir='/tmp/cache_dir', + num_shards=2, + ) + self.assertEqual(cf.cache_prefix, '/tmp/cache_dir/tfrecord') + self.assertEqual( + cf.metadata_file, + '/tmp/cache_dir/tfrecord' + cache_files.METADATA_FILE_SUFFIX, + ) + expected_tfrecord_files = [ + '/tmp/cache_dir/tfrecord-%05d-of-%05d.tfrecord' % (i, 2) + for i in range(2) + ] + self.assertEqual(cf.tfrecord_files, expected_tfrecord_files) + + # Writing TFRecord Files + self.assertFalse(cf.is_cached()) + for tfrecord_file in cf.tfrecord_files: + self.assertFalse(tf.io.gfile.exists(tfrecord_file)) + writers = cf.get_writers() + for writer in writers: + writer.close() + for tfrecord_file in cf.tfrecord_files: + self.assertTrue(tf.io.gfile.exists(tfrecord_file)) + self.assertFalse(cf.is_cached()) + + # Writing Metadata Files + original_metadata = {'size': 10, 'label_names': ['label1', 'label2']} + cf.save_metadata(original_metadata) + self.assertTrue(cf.is_cached()) + metadata = cf.load_metadata() + self.assertEqual(metadata, original_metadata) + + def test_recordio_cache_files_error(self): + with self.assertRaisesRegex( + ValueError, 'cache_prefix_filename cannot be empty' + ): + cache_files.TFRecordCacheFiles( + cache_prefix_filename='', + cache_dir='/tmp/cache_dir', + num_shards=2, + ) + with self.assertRaisesRegex( + ValueError, 'num_shards must be greater than 0, got 0' + ): + cache_files.TFRecordCacheFiles( + cache_prefix_filename='tfrecord', + cache_dir='/tmp/cache_dir', + num_shards=0, + ) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/classification_dataset.py b/mediapipe/model_maker/python/core/data/classification_dataset.py index b1df3b6d4..352caca6f 100644 --- a/mediapipe/model_maker/python/core/data/classification_dataset.py +++ b/mediapipe/model_maker/python/core/data/classification_dataset.py @@ -13,7 +13,7 @@ # limitations under the License. """Common classification dataset library.""" -from typing import List, Tuple +from typing import List, Optional, Tuple import tensorflow as tf @@ -23,8 +23,12 @@ from mediapipe.model_maker.python.core.data import dataset as ds class ClassificationDataset(ds.Dataset): """Dataset Loader for classification models.""" - def __init__(self, dataset: tf.data.Dataset, size: int, - label_names: List[str]): + def __init__( + self, + dataset: tf.data.Dataset, + label_names: List[str], + size: Optional[int] = None, + ): super().__init__(dataset, size) self._label_names = label_names diff --git a/mediapipe/model_maker/python/core/data/classification_dataset_test.py b/mediapipe/model_maker/python/core/data/classification_dataset_test.py index d21803f43..dfcea7da6 100644 --- a/mediapipe/model_maker/python/core/data/classification_dataset_test.py +++ b/mediapipe/model_maker/python/core/data/classification_dataset_test.py @@ -36,9 +36,14 @@ class ClassificationDatasetTest(tf.test.TestCase): value: A value variable stored by the mock dataset class for testing. 
""" - def __init__(self, dataset: tf.data.Dataset, size: int, - label_names: List[str], value: Any): - super().__init__(dataset=dataset, size=size, label_names=label_names) + def __init__( + self, + dataset: tf.data.Dataset, + label_names: List[str], + value: Any, + size: int, + ): + super().__init__(dataset=dataset, label_names=label_names, size=size) self.value = value def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]: @@ -52,7 +57,8 @@ class ClassificationDatasetTest(tf.test.TestCase): # Create data loader from sample data. ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) data = MagicClassificationDataset( - dataset=ds, size=len(ds), label_names=label_names, value=magic_value) + dataset=ds, label_names=label_names, value=magic_value, size=len(ds) + ) # Train/Test data split. fraction = .25 diff --git a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py index bfdc5b0f1..0cfccb149 100644 --- a/mediapipe/model_maker/python/core/data/dataset.py +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -56,15 +56,14 @@ class Dataset(object): def size(self) -> Optional[int]: """Returns the size of the dataset. - Note that this function may return None becuase the exact size of the - dataset isn't a necessary parameter to create an instance of this class, - and tf.data.Dataset donesn't support a function to get the length directly - since it's lazy-loaded and may be infinite. - In most cases, however, when an instance of this class is created by helper - functions like 'from_folder', the size of the dataset will be preprocessed, - and this function can return an int representing the size of the dataset. + Same functionality as calling __len__. See the __len__ method definition for + more information. + + Raises: + TypeError if self._size is not set and the cardinality of self._dataset + is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY. """ - return self._size + return self.__len__() def gen_tf_dataset( self, @@ -116,8 +115,22 @@ class Dataset(object): # here. return dataset - def __len__(self): - """Returns the number of element of the dataset.""" + def __len__(self) -> int: + """Returns the number of element of the dataset. + + If size is not set, this method will fallback to using the __len__ method + of the tf.data.Dataset in self._dataset. Calling __len__ on a + tf.data.Dataset instance may throw a TypeError because the dataset may + be lazy-loaded with an unknown size or have infinite size. + + In most cases, however, when an instance of this class is created by helper + functions like 'from_folder', the size of the dataset will be preprocessed, + and the _size instance variable will be already set. + + Raises: + TypeError if self._size is not set and the cardinality of self._dataset + is INFINITE_CARDINALITY or UNKNOWN_CARDINALITY. + """ if self._size is not None: return self._size else: @@ -152,15 +165,25 @@ class Dataset(object): Returns: The splitted two sub datasets. + + Raises: + ValueError: if the provided fraction is not between 0 and 1. + ValueError: if this dataset does not have a set size. """ - assert (fraction > 0 and fraction < 1) + if not (fraction > 0 and fraction < 1): + raise ValueError(f'Fraction must be between 0 and 1. Got:{fraction}') + if not self._size: + raise ValueError( + 'Dataset size unknown. Cannot split the dataset when ' + 'the size is unknown.' 
+ ) dataset = self._dataset train_size = int(self._size * fraction) - trainset = self.__class__(dataset.take(train_size), train_size, *args) + trainset = self.__class__(dataset.take(train_size), *args, size=train_size) test_size = self._size - train_size - testset = self.__class__(dataset.skip(train_size), test_size, *args) + testset = self.__class__(dataset.skip(train_size), *args, size=test_size) return trainset, testset diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py index 224716550..92e1856cc 100644 --- a/mediapipe/model_maker/python/core/hyperparameters.py +++ b/mediapipe/model_maker/python/core/hyperparameters.py @@ -15,7 +15,7 @@ import dataclasses import tempfile -from typing import Optional +from typing import Mapping, Optional import tensorflow as tf @@ -36,6 +36,8 @@ class BaseHParams: steps_per_epoch: An optional integer indicate the number of training steps per epoch. If not set, the training pipeline calculates the default steps per epoch as the training dataset size divided by batch size. + class_weights: An optional mapping of indices to weights for weighting the + loss function during training. shuffle: True if the dataset is shuffled before training. export_dir: The location of the model checkpoint files. distribution_strategy: A string specifying which Distribution Strategy to @@ -57,6 +59,7 @@ class BaseHParams: batch_size: int epochs: int steps_per_epoch: Optional[int] = None + class_weights: Optional[Mapping[int, float]] = None # Dataset-related parameters shuffle: bool = False diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index a042c0ec7..d504defbe 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel): # dataset is exhausted even if there are epochs remaining. steps_per_epoch=None, validation_data=validation_dataset, - callbacks=self._callbacks) + callbacks=self._callbacks, + class_weight=self._hparams.class_weights, + ) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: """Evaluates the classifier with the provided evaluation dataset. diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py index 504ba91ef..c741e4282 100644 --- a/mediapipe/model_maker/python/core/utils/loss_functions.py +++ b/mediapipe/model_maker/python/core/utils/loss_functions.py @@ -59,7 +59,7 @@ class FocalLoss(tf.keras.losses.Loss): """ def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None): - """Constructor. + """Initializes FocalLoss. Args: gamma: Focal loss gamma, as described in class docs. @@ -115,6 +115,51 @@ class FocalLoss(tf.keras.losses.Loss): return tf.reduce_sum(losses) / batch_size +class SparseFocalLoss(FocalLoss): + """Sparse implementation of Focal Loss. + + This is the same as FocalLoss, except the labels are expected to be class ids + instead of 1-hot encoded vectors. See FocalLoss class documentation defined + in this same file for more details. + + Example usage: + >>> y_true = [1, 2] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> gamma = 2 + >>> focal_loss = SparseFocalLoss(gamma, 3) + >>> focal_loss(y_true, y_pred).numpy() + 0.9326 + + >>> # Calling with 'sample_weight'. 
+ >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() + 0.6528 + """ + + def __init__( + self, gamma, num_classes, class_weight: Optional[Sequence[float]] = None + ): + """Initializes SparseFocalLoss. + + Args: + gamma: Focal loss gamma, as described in class docs. + num_classes: Number of classes. + class_weight: A weight to apply to the loss, one for each class. The + weight is applied for each input where the ground truth label matches. + """ + super().__init__(gamma, class_weight=class_weight) + self._num_classes = num_classes + + def __call__( + self, + y_true: tf.Tensor, + y_pred: tf.Tensor, + sample_weight: Optional[tf.Tensor] = None, + ) -> tf.Tensor: + y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32) + y_true_one_hot = tf.one_hot(y_true, self._num_classes) + return super().__call__(y_true_one_hot, y_pred, sample_weight=sample_weight) + + @dataclasses.dataclass class PerceptualLossWeight: """The weight for each perceptual loss. diff --git a/mediapipe/model_maker/python/core/utils/loss_functions_test.py b/mediapipe/model_maker/python/core/utils/loss_functions_test.py index 01f9a667d..3a14567ed 100644 --- a/mediapipe/model_maker/python/core/utils/loss_functions_test.py +++ b/mediapipe/model_maker/python/core/utils/loss_functions_test.py @@ -101,6 +101,23 @@ class FocalLossTest(tf.test.TestCase, parameterized.TestCase): self.assertNear(loss, expected_loss, 1e-4) +class SparseFocalLossTest(tf.test.TestCase): + + def test_sparse_focal_loss_matches_focal_loss(self): + num_classes = 2 + y_pred = tf.constant([[0.8, 0.2], [0.3, 0.7]]) + y_true = tf.constant([1, 0]) + y_true_one_hot = tf.one_hot(y_true, num_classes) + for gamma in [0.0, 0.5, 1.0]: + expected_loss_fn = loss_functions.FocalLoss(gamma=gamma) + loss_fn = loss_functions.SparseFocalLoss( + gamma=gamma, num_classes=num_classes + ) + expected_loss = expected_loss_fn(y_true_one_hot, y_pred) + loss = loss_fn(y_true, y_pred) + self.assertNear(loss, expected_loss, 1e-4) + + class MockPerceptualLoss(loss_functions.PerceptualLoss): """A mock class with implementation of abstract methods for testing.""" diff --git a/mediapipe/model_maker/python/text/core/bert_model_spec.py b/mediapipe/model_maker/python/text/core/bert_model_spec.py index 792c2c9a6..80e92a06a 100644 --- a/mediapipe/model_maker/python/text/core/bert_model_spec.py +++ b/mediapipe/model_maker/python/text/core/bert_model_spec.py @@ -46,13 +46,17 @@ class BertModelSpec: """ downloaded_files: file_util.DownloadedFiles - hparams: hp.BaseHParams = hp.BaseHParams( - epochs=3, - batch_size=32, - learning_rate=3e-5, - distribution_strategy='mirrored') - model_options: bert_model_options.BertModelOptions = ( - bert_model_options.BertModelOptions()) + hparams: hp.BaseHParams = dataclasses.field( + default_factory=lambda: hp.BaseHParams( + epochs=3, + batch_size=32, + learning_rate=3e-5, + distribution_strategy='mirrored', + ) + ) + model_options: bert_model_options.BertModelOptions = dataclasses.field( + default_factory=bert_model_options.BertModelOptions + ) do_lower_case: bool = True tflite_input_name: Dict[str, str] = dataclasses.field( default_factory=lambda: _DEFAULT_TFLITE_INPUT_NAME) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 64ace4ba0..322b1e1e5 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and 
# limitations under the License. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. # Placeholder for internal Python strict test compatibility macro. package(default_visibility = ["//mediapipe:__subpackages__"]) @@ -76,7 +76,10 @@ py_test( py_library( name = "dataset", srcs = ["dataset.py"], - deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], + deps = [ + "//mediapipe/model_maker/python/core/data:cache_files", + "//mediapipe/model_maker/python/core/data:classification_dataset", + ], ) py_test( @@ -88,7 +91,10 @@ py_test( py_library( name = "preprocessor", srcs = ["preprocessor.py"], - deps = [":dataset"], + deps = [ + ":dataset", + "//mediapipe/model_maker/python/core/data:cache_files", + ], ) py_test( @@ -99,6 +105,7 @@ py_test( ":dataset", ":model_spec", ":preprocessor", + "//mediapipe/model_maker/python/core/data:cache_files", ], ) @@ -124,6 +131,7 @@ py_library( ":text_classifier_options", "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:loss_functions", "//mediapipe/model_maker/python/core/utils:metrics", "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", @@ -147,6 +155,7 @@ py_test( ], deps = [ ":text_classifier_import", + "//mediapipe/model_maker/python/core/utils:loss_functions", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/text/text_classifier/dataset.py b/mediapipe/model_maker/python/text/text_classifier/dataset.py index 63605b477..1f8798df7 100644 --- a/mediapipe/model_maker/python/text/text_classifier/dataset.py +++ b/mediapipe/model_maker/python/text/text_classifier/dataset.py @@ -15,11 +15,15 @@ import csv import dataclasses +import hashlib +import os import random +import tempfile +from typing import List, Optional, Sequence -from typing import Optional, Sequence import tensorflow as tf +from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib from mediapipe.model_maker.python.core.data import classification_dataset @@ -46,21 +50,49 @@ class CSVParameters: class Dataset(classification_dataset.ClassificationDataset): """Dataset library for text classifier.""" + def __init__( + self, + dataset: tf.data.Dataset, + label_names: List[str], + tfrecord_cache_files: Optional[cache_files_lib.TFRecordCacheFiles] = None, + size: Optional[int] = None, + ): + super().__init__(dataset, label_names, size) + if not tfrecord_cache_files: + tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles( + cache_prefix_filename="tfrecord", num_shards=1 + ) + self.tfrecord_cache_files = tfrecord_cache_files + @classmethod - def from_csv(cls, - filename: str, - csv_params: CSVParameters, - shuffle: bool = True) -> "Dataset": + def from_csv( + cls, + filename: str, + csv_params: CSVParameters, + shuffle: bool = True, + cache_dir: Optional[str] = None, + num_shards: int = 1, + ) -> "Dataset": """Loads text with labels from a CSV file. Args: filename: Name of the CSV file. csv_params: Parameters used for reading the CSV file. shuffle: If True, randomly shuffle the data. + cache_dir: Optional parameter to specify where to store the preprocessed + dataset. Only used for BERT models. + num_shards: Optional parameter for num shards of the preprocessed dataset. + Note that using more than 1 shard will reorder the dataset. 
Only used + for BERT models. Returns: Dataset containing (text, label) pairs and other related info. """ + if cache_dir is None: + cache_dir = tempfile.mkdtemp() + # calculate hash for cache based off of files + hasher = hashlib.md5() + hasher.update(os.path.basename(filename).encode("utf-8")) with tf.io.gfile.GFile(filename, "r") as f: reader = csv.DictReader( f, @@ -69,6 +101,9 @@ class Dataset(classification_dataset.ClassificationDataset): quotechar=csv_params.quotechar) lines = list(reader) + for line in lines: + hasher.update(str(line).encode("utf-8")) + if shuffle: random.shuffle(lines) @@ -81,8 +116,18 @@ class Dataset(classification_dataset.ClassificationDataset): index_by_label[line[csv_params.label_column]] for line in lines ] label_index_ds = tf.data.Dataset.from_tensor_slices( - tf.cast(label_indices, tf.int64)) + tf.cast(label_indices, tf.int64) + ) text_label_ds = tf.data.Dataset.zip((text_ds, label_index_ds)) + hasher.update(str(num_shards).encode("utf-8")) + cache_prefix_filename = hasher.hexdigest() + tfrecord_cache_files = cache_files_lib.TFRecordCacheFiles( + cache_prefix_filename, cache_dir, num_shards + ) return Dataset( - dataset=text_label_ds, size=len(texts), label_names=label_names) + dataset=text_label_ds, + label_names=label_names, + tfrecord_cache_files=tfrecord_cache_files, + size=len(texts), + ) diff --git a/mediapipe/model_maker/python/text/text_classifier/dataset_test.py b/mediapipe/model_maker/python/text/text_classifier/dataset_test.py index 012476e0b..2fa90b860 100644 --- a/mediapipe/model_maker/python/text/text_classifier/dataset_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/dataset_test.py @@ -53,7 +53,7 @@ class DatasetTest(tf.test.TestCase): def test_split(self): ds = tf.data.Dataset.from_tensor_slices(['good', 'bad', 'neutral', 'odd']) - data = dataset.Dataset(ds, 4, ['pos', 'neg']) + data = dataset.Dataset(ds, ['pos', 'neg'], size=4) train_data, test_data = data.split(0.5) expected_train_data = [b'good', b'bad'] expected_test_data = [b'neutral', b'odd'] diff --git a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py index ae0a9a627..71470edb3 100644 --- a/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py +++ b/mediapipe/model_maker/python/text/text_classifier/hyperparameters.py @@ -15,7 +15,7 @@ import dataclasses import enum -from typing import Union +from typing import Sequence, Union from mediapipe.model_maker.python.core import hyperparameters as hp @@ -39,16 +39,34 @@ class BertHParams(hp.BaseHParams): Attributes: learning_rate: Learning rate to use for gradient descent training. - batch_size: Batch size for training. - epochs: Number of training iterations over the dataset. - optimizer: Optimizer to use for training. Only supported values are "adamw" - and "lamb". + end_learning_rate: End learning rate for linear decay. Defaults to 0. + batch_size: Batch size for training. Defaults to 48. + epochs: Number of training iterations over the dataset. Defaults to 2. + optimizer: Optimizer to use for training. Supported values are defined in + BertOptimizer enum: ADAMW and LAMB. + weight_decay: Weight decay of the optimizer. Defaults to 0.01. + desired_precisions: If specified, adds a RecallAtPrecision metric per + desired_precisions[i] entry which tracks the recall given the constraint + on precision. Only supported for binary classification. 
+ desired_recalls: If specified, adds a PrecisionAtRecall metric per + desired_recalls[i] entry which tracks the precision given the constraint + on recall. Only supported for binary classification. + gamma: Gamma parameter for focal loss. To use cross entropy loss, set this + value to 0. Defaults to 2.0. """ learning_rate: float = 3e-5 + end_learning_rate: float = 0.0 + batch_size: int = 48 epochs: int = 2 optimizer: BertOptimizer = BertOptimizer.ADAMW + weight_decay: float = 0.01 + + desired_precisions: Sequence[float] = dataclasses.field(default_factory=list) + desired_recalls: Sequence[float] = dataclasses.field(default_factory=list) + + gamma: float = 2.0 HParams = Union[BertHParams, AverageWordEmbeddingHParams] diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec.py b/mediapipe/model_maker/python/text/text_classifier/model_spec.py index 452e22679..724aaf377 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec.py @@ -47,11 +47,14 @@ class AverageWordEmbeddingClassifierSpec: """ # `learning_rate` is unused for the average word embedding model - hparams: hp.AverageWordEmbeddingHParams = hp.AverageWordEmbeddingHParams( - epochs=10, batch_size=32, learning_rate=0 + hparams: hp.AverageWordEmbeddingHParams = dataclasses.field( + default_factory=lambda: hp.AverageWordEmbeddingHParams( + epochs=10, batch_size=32, learning_rate=0 + ) + ) + model_options: mo.AverageWordEmbeddingModelOptions = dataclasses.field( + default_factory=mo.AverageWordEmbeddingModelOptions ) - model_options: mo.AverageWordEmbeddingModelOptions = ( - mo.AverageWordEmbeddingModelOptions()) name: str = 'AverageWordEmbedding' average_word_embedding_classifier_spec = functools.partial( @@ -66,7 +69,7 @@ class BertClassifierSpec(bert_model_spec.BertModelSpec): inherited from the BertModelSpec. 
""" - hparams: hp.BertHParams = hp.BertHParams() + hparams: hp.BertHParams = dataclasses.field(default_factory=hp.BertHParams) mobilebert_classifier_spec = functools.partial( @@ -76,11 +79,6 @@ mobilebert_classifier_spec = functools.partial( epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' ), name='MobileBert', - tflite_input_name={ - 'ids': 'serving_default_input_1:0', - 'segment_ids': 'serving_default_input_2:0', - 'mask': 'serving_default_input_3:0', - }, ) exbert_classifier_spec = functools.partial( @@ -90,11 +88,6 @@ exbert_classifier_spec = functools.partial( epochs=3, batch_size=48, learning_rate=3e-5, distribution_strategy='off' ), name='ExBert', - tflite_input_name={ - 'ids': 'serving_default_input_1:0', - 'segment_ids': 'serving_default_input_2:0', - 'mask': 'serving_default_input_3:0', - }, ) diff --git a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py index 7c45a2675..4d42851d5 100644 --- a/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/model_spec_test.py @@ -46,11 +46,13 @@ class ModelSpecTest(tf.test.TestCase): self.assertTrue(os.path.exists(model_spec_obj.downloaded_files.get_path())) self.assertTrue(model_spec_obj.do_lower_case) self.assertEqual( - model_spec_obj.tflite_input_name, { - 'ids': 'serving_default_input_1:0', - 'mask': 'serving_default_input_3:0', - 'segment_ids': 'serving_default_input_2:0' - }) + model_spec_obj.tflite_input_name, + { + 'ids': 'serving_default_input_word_ids:0', + 'mask': 'serving_default_input_mask:0', + 'segment_ids': 'serving_default_input_type_ids:0', + }, + ) self.assertEqual( model_spec_obj.model_options, classifier_model_options.BertModelOptions( diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py index 15b9d90d0..2a31bbd09 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor.py @@ -15,14 +15,15 @@ """Preprocessors for text classification.""" import collections +import hashlib import os import re -import tempfile from typing import Mapping, Sequence, Tuple, Union import tensorflow as tf import tensorflow_hub +from mediapipe.model_maker.python.core.data import cache_files as cache_files_lib from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds from official.nlp.data import classifier_data_lib from official.nlp.tools import tokenization @@ -75,19 +76,20 @@ def _decode_record( return bert_features, example["label_ids"] -def _single_file_dataset( - input_file: str, name_to_features: Mapping[str, tf.io.FixedLenFeature] +def _tfrecord_dataset( + tfrecord_files: Sequence[str], + name_to_features: Mapping[str, tf.io.FixedLenFeature], ) -> tf.data.TFRecordDataset: """Creates a single-file dataset to be passed for BERT custom training. Args: - input_file: Filepath for the dataset. + tfrecord_files: Filepaths for the dataset. name_to_features: Maps record keys to feature types. Returns: Dataset containing BERT model input features and labels. 
""" - d = tf.data.TFRecordDataset(input_file) + d = tf.data.TFRecordDataset(tfrecord_files) d = d.map( lambda record: _decode_record(record, name_to_features), num_parallel_calls=tf.data.AUTOTUNE) @@ -221,15 +223,23 @@ class BertClassifierPreprocessor: seq_len: Length of the input sequence to the model. vocab_file: File containing the BERT vocab. tokenizer: BERT tokenizer. + model_name: Name of the model provided by the model_spec. Used to associate + cached files with specific Bert model vocab. """ - def __init__(self, seq_len: int, do_lower_case: bool, uri: str): + def __init__( + self, seq_len: int, do_lower_case: bool, uri: str, model_name: str + ): self._seq_len = seq_len # Vocab filepath is tied to the BERT module's URI. self._vocab_file = os.path.join( - tensorflow_hub.resolve(uri), "assets", "vocab.txt") - self._tokenizer = tokenization.FullTokenizer(self._vocab_file, - do_lower_case) + tensorflow_hub.resolve(uri), "assets", "vocab.txt" + ) + self._do_lower_case = do_lower_case + self._tokenizer = tokenization.FullTokenizer( + self._vocab_file, self._do_lower_case + ) + self._model_name = model_name def _get_name_to_features(self): """Gets the dictionary mapping record keys to feature types.""" @@ -244,8 +254,45 @@ class BertClassifierPreprocessor: """Returns the vocab file of the BertClassifierPreprocessor.""" return self._vocab_file + def _get_tfrecord_cache_files( + self, ds_cache_files + ) -> cache_files_lib.TFRecordCacheFiles: + """Helper to regenerate cache prefix filename using preprocessor info. + + We need to update the dataset cache_prefix cache because the actual cached + dataset depends on the preprocessor parameters such as model_name, seq_len, + and do_lower_case in addition to the raw dataset parameters which is already + included in the ds_cache_files.cache_prefix_filename + + Specifically, the new cache_prefix_filename used by the preprocessor will + be a hash generated from the following: + 1. cache_prefix_filename of the initial raw dataset + 2. model_name + 3. seq_len + 4. do_lower_case + + Args: + ds_cache_files: TFRecordCacheFiles from the original raw dataset object + + Returns: + A new TFRecordCacheFiles object which incorporates the preprocessor + parameters. + """ + hasher = hashlib.md5() + hasher.update(ds_cache_files.cache_prefix_filename.encode("utf-8")) + hasher.update(self._model_name.encode("utf-8")) + hasher.update(str(self._seq_len).encode("utf-8")) + hasher.update(str(self._do_lower_case).encode("utf-8")) + cache_prefix_filename = hasher.hexdigest() + return cache_files_lib.TFRecordCacheFiles( + cache_prefix_filename, + ds_cache_files.cache_dir, + ds_cache_files.num_shards, + ) + def preprocess( - self, dataset: text_classifier_ds.Dataset) -> text_classifier_ds.Dataset: + self, dataset: text_classifier_ds.Dataset + ) -> text_classifier_ds.Dataset: """Preprocesses data into input for a BERT-based classifier. Args: @@ -254,32 +301,65 @@ class BertClassifierPreprocessor: Returns: Dataset containing (bert_features, label) data. """ - examples = [] - for index, (text, label) in enumerate(dataset.gen_tf_dataset()): - _validate_text_and_label(text, label) - examples.append( - classifier_data_lib.InputExample( - guid=str(index), - text_a=text.numpy()[0].decode("utf-8"), - text_b=None, - # InputExample expects the label name rather than the int ID - label=dataset.label_names[label.numpy()[0]])) + ds_cache_files = dataset.tfrecord_cache_files + # Get new tfrecord_cache_files by including preprocessor information. 
+ tfrecord_cache_files = self._get_tfrecord_cache_files(ds_cache_files) + if not tfrecord_cache_files.is_cached(): + print(f"Writing new cache files to {tfrecord_cache_files.cache_prefix}") + writers = tfrecord_cache_files.get_writers() + size = 0 + for index, (text, label) in enumerate(dataset.gen_tf_dataset()): + _validate_text_and_label(text, label) + example = classifier_data_lib.InputExample( + guid=str(index), + text_a=text.numpy()[0].decode("utf-8"), + text_b=None, + # InputExample expects the label name rather than the int ID + # label=dataset.label_names[label.numpy()[0]]) + label=label.numpy()[0], + ) + feature = classifier_data_lib.convert_single_example( + index, example, None, self._seq_len, self._tokenizer + ) - tfrecord_file = os.path.join(tempfile.mkdtemp(), "bert_features.tfrecord") - classifier_data_lib.file_based_convert_examples_to_features( - examples=examples, - label_list=dataset.label_names, - max_seq_length=self._seq_len, - tokenizer=self._tokenizer, - output_file=tfrecord_file) - preprocessed_ds = _single_file_dataset(tfrecord_file, - self._get_name_to_features()) + def create_int_feature(values): + f = tf.train.Feature( + int64_list=tf.train.Int64List(value=list(values)) + ) + return f + + features = collections.OrderedDict() + features["input_ids"] = create_int_feature(feature.input_ids) + features["input_mask"] = create_int_feature(feature.input_mask) + features["segment_ids"] = create_int_feature(feature.segment_ids) + features["label_ids"] = create_int_feature([feature.label_id]) + tf_example = tf.train.Example( + features=tf.train.Features(feature=features) + ) + writers[index % len(writers)].write(tf_example.SerializeToString()) + size = index + 1 + for writer in writers: + writer.close() + metadata = {"size": size, "label_names": dataset.label_names} + tfrecord_cache_files.save_metadata(metadata) + else: + print( + f"Using existing cache files at {tfrecord_cache_files.cache_prefix}" + ) + metadata = tfrecord_cache_files.load_metadata() + size = metadata["size"] + label_names = metadata["label_names"] + preprocessed_ds = _tfrecord_dataset( + tfrecord_cache_files.tfrecord_files, self._get_name_to_features() + ) return text_classifier_ds.Dataset( dataset=preprocessed_ds, - size=dataset.size, - label_names=dataset.label_names) + size=size, + label_names=label_names, + tfrecord_cache_files=tfrecord_cache_files, + ) -TextClassifierPreprocessor = ( - Union[BertClassifierPreprocessor, - AverageWordEmbeddingClassifierPreprocessor]) +TextClassifierPreprocessor = Union[ + BertClassifierPreprocessor, AverageWordEmbeddingClassifierPreprocessor +] diff --git a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py index 27e98e262..28c12f96c 100644 --- a/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/preprocessor_test.py @@ -13,14 +13,17 @@ # limitations under the License. 
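Putting the dataset and preprocessor changes together, a hedged end-to-end sketch of the new caching flow (the CSV file, cache directory, and CSVParameters column names are hypothetical):

from mediapipe.model_maker.python.text.text_classifier import dataset as text_ds
from mediapipe.model_maker.python.text.text_classifier import model_spec
from mediapipe.model_maker.python.text.text_classifier import preprocessor

csv_params = text_ds.CSVParameters(text_column='text', label_column='label')
data = text_ds.Dataset.from_csv(
    filename='reviews.csv',       # the basename and rows of this file feed the cache hash
    csv_params=csv_params,
    cache_dir='/tmp/text_cache',  # optional; a temp dir is created when omitted
    num_shards=1,                 # >1 reorders the dataset; only used for BERT models
)

bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value()
bert_preprocessor = preprocessor.BertClassifierPreprocessor(
    seq_len=128,
    do_lower_case=bert_spec.do_lower_case,
    uri=bert_spec.downloaded_files.get_path(),
    model_name=bert_spec.name,    # new argument: folds the model into the cache key
)
processed = bert_preprocessor.preprocess(data)        # first run writes TFRecords and metadata
processed_again = bert_preprocessor.preprocess(data)  # later runs reuse the cached shards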
import csv +import io import os import tempfile from unittest import mock as unittest_mock +import mock import numpy as np import numpy.testing as npt import tensorflow as tf +from mediapipe.model_maker.python.core.data import cache_files from mediapipe.model_maker.python.text.text_classifier import dataset as text_classifier_ds from mediapipe.model_maker.python.text.text_classifier import model_spec from mediapipe.model_maker.python.text.text_classifier import preprocessor @@ -84,11 +87,12 @@ class PreprocessorTest(tf.test.TestCase): csv_file = self._get_csv_file() dataset = text_classifier_ds.Dataset.from_csv( filename=csv_file, csv_params=self.CSV_PARAMS_) - bert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() + bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() bert_preprocessor = preprocessor.BertClassifierPreprocessor( seq_len=5, do_lower_case=bert_spec.do_lower_case, uri=bert_spec.downloaded_files.get_path(), + model_name=bert_spec.name, ) preprocessed_dataset = bert_preprocessor.preprocess(dataset) labels = [] @@ -97,18 +101,91 @@ class PreprocessorTest(tf.test.TestCase): self.assertEqual(label.shape, [1]) labels.append(label.numpy()[0]) self.assertSameElements( - features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids']) + features.keys(), ['input_word_ids', 'input_mask', 'input_type_ids'] + ) for feature in features.values(): self.assertEqual(feature.shape, [1, 5]) input_masks.append(features['input_mask'].numpy()[0]) - npt.assert_array_equal(features['input_type_ids'].numpy()[0], - [0, 0, 0, 0, 0]) + npt.assert_array_equal( + features['input_type_ids'].numpy()[0], [0, 0, 0, 0, 0] + ) npt.assert_array_equal( - np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])) + np.stack(input_masks), np.array([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]) + ) self.assertEqual(labels, [1, 0]) + def test_bert_preprocessor_cache(self): + csv_file = self._get_csv_file() + dataset = text_classifier_ds.Dataset.from_csv( + filename=csv_file, + csv_params=self.CSV_PARAMS_, + cache_dir=self.get_temp_dir(), + ) + bert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() + bert_preprocessor = preprocessor.BertClassifierPreprocessor( + seq_len=5, + do_lower_case=bert_spec.do_lower_case, + uri=bert_spec.downloaded_files.get_path(), + model_name=bert_spec.name, + ) + ds_cache_files = dataset.tfrecord_cache_files + preprocessed_cache_files = bert_preprocessor._get_tfrecord_cache_files( + ds_cache_files + ) + self.assertFalse(preprocessed_cache_files.is_cached()) + preprocessed_dataset = bert_preprocessor.preprocess(dataset) + self.assertTrue(preprocessed_cache_files.is_cached()) + self.assertEqual( + preprocessed_dataset.tfrecord_cache_files, preprocessed_cache_files + ) + + # The second time running preprocessor, it should load from cache directly + mock_stdout = io.StringIO() + with mock.patch('sys.stdout', mock_stdout): + _ = bert_preprocessor.preprocess(dataset) + self.assertEqual( + mock_stdout.getvalue(), + 'Using existing cache files at' + f' {preprocessed_cache_files.cache_prefix}\n', + ) + + def _get_new_prefix(self, cf, bert_spec, seq_len, do_lower_case): + bert_preprocessor = preprocessor.BertClassifierPreprocessor( + seq_len=seq_len, + do_lower_case=do_lower_case, + uri=bert_spec.downloaded_files.get_path(), + model_name=bert_spec.name, + ) + new_cf = bert_preprocessor._get_tfrecord_cache_files(cf) + return new_cf.cache_prefix_filename + + def test_bert_get_tfrecord_cache_files(self): + # Test to ensure regenerated cache_files have different 
prefixes + all_cf_prefixes = set() + cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='cache_prefix', + cache_dir=self.get_temp_dir(), + num_shards=1, + ) + exbert_spec = model_spec.SupportedModels.EXBERT_CLASSIFIER.value() + all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, True)) + all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 10, True)) + all_cf_prefixes.add(self._get_new_prefix(cf, exbert_spec, 5, False)) + mobilebert_spec = model_spec.SupportedModels.MOBILEBERT_CLASSIFIER.value() + all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, True)) + all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 10, True)) + all_cf_prefixes.add(self._get_new_prefix(cf, mobilebert_spec, 5, False)) + new_cf = cache_files.TFRecordCacheFiles( + cache_prefix_filename='new_cache_prefix', + cache_dir=self.get_temp_dir(), + num_shards=1, + ) + all_cf_prefixes.add(self._get_new_prefix(new_cf, exbert_spec, 5, True)) + + # Each item of all_cf_prefixes should be unique, so 7 total. + self.assertLen(all_cf_prefixes, 7) + if __name__ == '__main__': # Load compressed models from tensorflow_hub - os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' tf.test.main() diff --git a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json index 24214a80d..22fb220fb 100644 --- a/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json +++ b/mediapipe/model_maker/python/text/text_classifier/testdata/bert_metadata.json @@ -16,8 +16,8 @@ } }, { - "name": "mask", - "description": "Mask with 1 for real tokens and 0 for padding tokens.", + "name": "segment_ids", + "description": "0 for the first sequence, 1 for the second sequence if exists.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { @@ -27,8 +27,8 @@ } }, { - "name": "segment_ids", - "description": "0 for the first sequence, 1 for the second sequence if exists.", + "name": "mask", + "description": "Mask with 1 for real tokens and 0 for padding tokens.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py index 6c8adc82c..10d88110d 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier.py @@ -24,6 +24,7 @@ import tensorflow_hub as hub from mediapipe.model_maker.python.core.data import dataset as ds from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import loss_functions from mediapipe.model_maker.python.core.utils import metrics from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization @@ -116,17 +117,14 @@ class TextClassifier(classifier.Classifier): options.supported_model == ms.SupportedModels.MOBILEBERT_CLASSIFIER or options.supported_model == ms.SupportedModels.EXBERT_CLASSIFIER ): - text_classifier = ( - _BertClassifier.create_bert_classifier(train_data, validation_data, - options, - train_data.label_names)) + text_classifier = _BertClassifier.create_bert_classifier( + train_data, validation_data, options + ) elif (options.supported_model == ms.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER): - text_classifier = ( - _AverageWordEmbeddingClassifier - 
.create_average_word_embedding_classifier(train_data, validation_data, - options, - train_data.label_names)) + text_classifier = _AverageWordEmbeddingClassifier.create_average_word_embedding_classifier( + train_data, validation_data, options + ) else: raise ValueError(f"Unknown model {options.supported_model}") @@ -166,28 +164,8 @@ class TextClassifier(classifier.Classifier): processed_data = self._text_preprocessor.preprocess(data) dataset = processed_data.gen_tf_dataset(batch_size, is_training=False) - additional_metrics = [] - if desired_precisions and len(data.label_names) == 2: - for precision in desired_precisions: - additional_metrics.append( - metrics.BinarySparseRecallAtPrecision( - precision, name=f"recall_at_precision_{precision}" - ) - ) - if desired_recalls and len(data.label_names) == 2: - for recall in desired_recalls: - additional_metrics.append( - metrics.BinarySparsePrecisionAtRecall( - recall, name=f"precision_at_recall_{recall}" - ) - ) - metric_functions = self._metric_functions + additional_metrics - self._model.compile( - optimizer=self._optimizer, - loss=self._loss_function, - metrics=metric_functions, - ) - return self._model.evaluate(dataset) + with self._hparams.get_strategy().scope(): + return self._model.evaluate(dataset) def export_model( self, @@ -255,16 +233,17 @@ class _AverageWordEmbeddingClassifier(TextClassifier): @classmethod def create_average_word_embedding_classifier( - cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, + cls, + train_data: text_ds.Dataset, + validation_data: text_ds.Dataset, options: text_classifier_options.TextClassifierOptions, - label_names: Sequence[str]) -> "_AverageWordEmbeddingClassifier": + ) -> "_AverageWordEmbeddingClassifier": """Creates, trains, and returns an Average Word Embedding classifier. Args: train_data: Training data. validation_data: Validation data. options: Options for creating and training the text classifier. - label_names: Label names used in the data. Returns: An Average Word Embedding classifier. @@ -370,28 +349,25 @@ class _BertClassifier(TextClassifier): self._callbacks = model_util.get_default_callbacks(self._hparams.export_dir) self._model_options = model_options with self._hparams.get_strategy().scope(): - self._loss_function = tf.keras.losses.SparseCategoricalCrossentropy() - self._metric_functions = [ - tf.keras.metrics.SparseCategoricalAccuracy( - "test_accuracy", dtype=tf.float32 - ), - metrics.SparsePrecision(name="precision", dtype=tf.float32), - metrics.SparseRecall(name="recall", dtype=tf.float32), - ] - self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None + self._loss_function = loss_functions.SparseFocalLoss( + self._hparams.gamma, self._num_classes + ) + self._metric_functions = self._create_metrics() + self._text_preprocessor: preprocessor.BertClassifierPreprocessor = None @classmethod def create_bert_classifier( - cls, train_data: text_ds.Dataset, validation_data: text_ds.Dataset, + cls, + train_data: text_ds.Dataset, + validation_data: text_ds.Dataset, options: text_classifier_options.TextClassifierOptions, - label_names: Sequence[str]) -> "_BertClassifier": + ) -> "_BertClassifier": """Creates, trains, and returns a BERT-based classifier. Args: train_data: Training data. validation_data: Validation data. options: Options for creating and training the text classifier. - label_names: Label names used in the data. Returns: A BERT-based classifier. 
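The constructor above now builds its training loss from loss_functions.SparseFocalLoss(gamma, num_classes). That implementation is not part of this diff; a minimal, illustrative focal loss with the same constructor shape (assuming the model outputs per-class probabilities) might look like the following sketch:

import tensorflow as tf

class SparseFocalLossSketch(tf.keras.losses.Loss):
  """Illustrative focal loss for integer labels; not the MediaPipe implementation."""

  def __init__(self, gamma: float, num_classes: int):
    super().__init__(name="sparse_focal_loss_sketch")
    self._gamma = gamma
    self._num_classes = num_classes

  def call(self, y_true, y_pred):
    # y_true: integer class ids, shape [batch] or [batch, 1].
    # y_pred: per-class probabilities, shape [batch, num_classes].
    labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
    p_t = tf.reduce_sum(tf.one_hot(labels, self._num_classes) * y_pred, axis=-1)
    p_t = tf.clip_by_value(p_t, 1e-7, 1.0)
    # (1 - p_t)^gamma down-weights easy examples; gamma = 0 recovers ordinary
    # sparse categorical cross-entropy.
    return -tf.pow(1.0 - p_t, self._gamma) * tf.math.log(p_t)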
@@ -435,9 +411,59 @@ class _BertClassifier(TextClassifier):
         seq_len=self._model_options.seq_len,
         do_lower_case=self._model_spec.do_lower_case,
         uri=self._model_spec.downloaded_files.get_path(),
+        model_name=self._model_spec.name,
     )
-    return (self._text_preprocessor.preprocess(train_data),
-            self._text_preprocessor.preprocess(validation_data))
+    return (
+        self._text_preprocessor.preprocess(train_data),
+        self._text_preprocessor.preprocess(validation_data),
+    )
+
+  def _create_metrics(self):
+    """Creates metrics for training and evaluation.
+
+    The default metrics are accuracy, precision, and recall.
+
+    For binary classification tasks only (num_classes=2):
+    Users can configure PrecisionAtRecall and RecallAtPrecision metrics using
+    the desired_precisions and desired_recalls fields in BertHParams.
+
+    Returns:
+      A list of tf.keras.Metric subclasses which can be used with model.compile
+    """
+    metric_functions = [
+        tf.keras.metrics.SparseCategoricalAccuracy(
+            "accuracy", dtype=tf.float32
+        ),
+        metrics.SparsePrecision(name="precision", dtype=tf.float32),
+        metrics.SparseRecall(name="recall", dtype=tf.float32),
+    ]
+    if self._num_classes == 2:
+      if self._hparams.desired_precisions:
+        for desired_precision in self._hparams.desired_precisions:
+          metric_functions.append(
+              metrics.BinarySparseRecallAtPrecision(
+                  desired_precision,
+                  name=f"recall_at_precision_{desired_precision}",
+                  num_thresholds=1000,
+              )
+          )
+      if self._hparams.desired_recalls:
+        for desired_recall in self._hparams.desired_recalls:
+          metric_functions.append(
+              metrics.BinarySparsePrecisionAtRecall(
+                  desired_recall,
+                  name=f"precision_at_recall_{desired_recall}",
+                  num_thresholds=1000,
+              )
+          )
+    else:
+      if self._hparams.desired_precisions or self._hparams.desired_recalls:
+        raise ValueError(
+            "desired_recalls and desired_precisions parameters are binary"
+            " metrics and not supported for num_classes > 2. Found"
+            f" num_classes: {self._num_classes}"
+        )
+    return metric_functions
 
   def _create_model(self):
     """Creates a BERT-based classifier model.
@@ -447,11 +473,20 @@ class _BertClassifier(TextClassifier): """ encoder_inputs = dict( input_word_ids=tf.keras.layers.Input( - shape=(self._model_options.seq_len,), dtype=tf.int32), + shape=(self._model_options.seq_len,), + dtype=tf.int32, + name="input_word_ids", + ), input_mask=tf.keras.layers.Input( - shape=(self._model_options.seq_len,), dtype=tf.int32), + shape=(self._model_options.seq_len,), + dtype=tf.int32, + name="input_mask", + ), input_type_ids=tf.keras.layers.Input( - shape=(self._model_options.seq_len,), dtype=tf.int32), + shape=(self._model_options.seq_len,), + dtype=tf.int32, + name="input_type_ids", + ), ) encoder = hub.KerasLayer( self._model_spec.downloaded_files.get_path(), @@ -493,16 +528,21 @@ class _BertClassifier(TextClassifier): lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=initial_lr, decay_steps=total_steps, - end_learning_rate=0.0, - power=1.0) + end_learning_rate=self._hparams.end_learning_rate, + power=1.0, + ) if warmup_steps: lr_schedule = model_util.WarmUp( initial_learning_rate=initial_lr, decay_schedule_fn=lr_schedule, - warmup_steps=warmup_steps) + warmup_steps=warmup_steps, + ) if self._hparams.optimizer == hp.BertOptimizer.ADAMW: self._optimizer = tf.keras.optimizers.experimental.AdamW( - lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0 + lr_schedule, + weight_decay=self._hparams.weight_decay, + epsilon=1e-6, + global_clipnorm=1.0, ) self._optimizer.exclude_from_weight_decay( var_names=["LayerNorm", "layer_norm", "bias"] @@ -510,7 +550,7 @@ class _BertClassifier(TextClassifier): elif self._hparams.optimizer == hp.BertOptimizer.LAMB: self._optimizer = tfa_optimizers.LAMB( lr_schedule, - weight_decay_rate=0.01, + weight_decay_rate=self._hparams.weight_decay, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], global_clipnorm=1.0, diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py index 934bb1c4b..b646a15ad 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_demo.py @@ -84,8 +84,8 @@ def run(data_dir, options) # Gets evaluation results. 
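# The optimizer assembled in _create_optimizer above composes a linear decay
# with a warmup ramp before handing it to AdamW or LAMB. A standalone sketch of
# that composition (step counts and learning rate here are illustrative values):
import tensorflow as tf
from mediapipe.model_maker.python.core.utils import model_util

total_steps, warmup_steps, initial_lr = 1000, 100, 3e-5
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_lr,
    decay_steps=total_steps,
    end_learning_rate=0.0,
    power=1.0,
)
lr_schedule = model_util.WarmUp(
    initial_learning_rate=initial_lr,
    decay_schedule_fn=lr_schedule,
    warmup_steps=warmup_steps,
)
optimizer = tf.keras.optimizers.experimental.AdamW(
    lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0
)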
- _, acc = model.evaluate(validation_data) - print('Eval accuracy: %f' % acc) + metrics = model.evaluate(validation_data) + print('Eval accuracy: %f' % metrics[1]) model.export_model(quantization_config=quantization_config) model.export_labels(export_dir=options.hparams.export_dir) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index e6057059c..be4646f68 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -16,17 +16,17 @@ import csv import filecmp import os import tempfile -import unittest from unittest import mock as unittest_mock +from absl.testing import parameterized import tensorflow as tf +from mediapipe.model_maker.python.core.utils import loss_functions from mediapipe.model_maker.python.text import text_classifier from mediapipe.tasks.python.test import test_utils -@unittest.skip('b/275624089') -class TextClassifierTest(tf.test.TestCase): +class TextClassifierTest(tf.test.TestCase, parameterized.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( test_utils.get_test_data_path('average_word_embedding_metadata.json')) @@ -78,8 +78,8 @@ class TextClassifierTest(tf.test.TestCase): text_classifier.TextClassifier.create(train_data, validation_data, options)) - _, accuracy = average_word_embedding_classifier.evaluate(validation_data) - self.assertGreaterEqual(accuracy, 0.0) + metrics = average_word_embedding_classifier.evaluate(validation_data) + self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy # Test export_model average_word_embedding_classifier.export_model() @@ -98,12 +98,25 @@ class TextClassifierTest(tf.test.TestCase): filecmp.cmp( output_metadata_file, self._AVERAGE_WORD_EMBEDDING_JSON_FILE, - shallow=False)) + shallow=False, + ) + ) - def test_create_and_train_bert(self): + @parameterized.named_parameters( + # Skipping mobilebert b/c OSS test timeout/flakiness: b/275624089 + # dict( + # testcase_name='mobilebert', + # supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, + # ), + dict( + testcase_name='exbert', + supported_model=text_classifier.SupportedModels.EXBERT_CLASSIFIER, + ), + ) + def test_create_and_train_bert(self, supported_model): train_data, validation_data = self._get_data() options = text_classifier.TextClassifierOptions( - supported_model=text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER, + supported_model=supported_model, model_options=text_classifier.BertModelOptions( do_fine_tuning=False, seq_len=2 ), @@ -117,8 +130,8 @@ class TextClassifierTest(tf.test.TestCase): bert_classifier = text_classifier.TextClassifier.create( train_data, validation_data, options) - _, accuracy = bert_classifier.evaluate(validation_data) - self.assertGreaterEqual(accuracy, 0.0) + metrics = bert_classifier.evaluate(validation_data) + self.assertGreaterEqual(metrics[1], 0.0) # metrics[1] is accuracy # Test export_model bert_classifier.export_model() @@ -142,45 +155,93 @@ class TextClassifierTest(tf.test.TestCase): ) def test_label_mismatch(self): - options = ( - text_classifier.TextClassifierOptions( - supported_model=( - text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER))) + options = text_classifier.TextClassifierOptions( + supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER) + ) train_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) - train_data = text_classifier.Dataset(train_tf_dataset, 1, 
['foo']) + train_data = text_classifier.Dataset(train_tf_dataset, ['foo'], 1) validation_tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) - validation_data = text_classifier.Dataset(validation_tf_dataset, 1, ['bar']) + validation_data = text_classifier.Dataset(validation_tf_dataset, ['bar'], 1) with self.assertRaisesRegex( ValueError, - 'Training data label names .* not equal to validation data label names' + 'Training data label names .* not equal to validation data label names', ): - text_classifier.TextClassifier.create(train_data, validation_data, - options) + text_classifier.TextClassifier.create( + train_data, validation_data, options + ) def test_options_mismatch(self): train_data, validation_data = self._get_data() - avg_options = ( - text_classifier.TextClassifierOptions( - supported_model=( - text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER), - model_options=text_classifier.AverageWordEmbeddingModelOptions())) - with self.assertRaisesRegex( - ValueError, 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' - ' SupportedModels.MOBILEBERT_CLASSIFIER'): - text_classifier.TextClassifier.create(train_data, validation_data, - avg_options) + avg_options = text_classifier.TextClassifierOptions( + supported_model=(text_classifier.SupportedModels.EXBERT_CLASSIFIER), + model_options=text_classifier.AverageWordEmbeddingModelOptions(), + ) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Expected AVERAGE_WORD_EMBEDDING_CLASSIFIER, got' + ' SupportedModels.EXBERT_CLASSIFIER', + ): + text_classifier.TextClassifier.create( + train_data, validation_data, avg_options + ) - bert_options = ( - text_classifier.TextClassifierOptions( - supported_model=(text_classifier.SupportedModels - .AVERAGE_WORD_EMBEDDING_CLASSIFIER), - model_options=text_classifier.BertModelOptions())) - with self.assertRaisesRegex( - ValueError, 'Expected MOBILEBERT_CLASSIFIER, got' - ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER'): - text_classifier.TextClassifier.create(train_data, validation_data, - bert_options) + bert_options = text_classifier.TextClassifierOptions( + supported_model=( + text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER + ), + model_options=text_classifier.BertModelOptions(), + ) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Expected a Bert Classifier(MobileBERT or EXBERT), got' + ' SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER', + ): + text_classifier.TextClassifier.create( + train_data, validation_data, bert_options + ) + + def test_bert_loss_and_metrics_creation(self): + train_data, validation_data = self._get_data() + supported_model = text_classifier.SupportedModels.EXBERT_CLASSIFIER + hparams = text_classifier.BertHParams( + desired_recalls=[0.2], + desired_precisions=[0.9], + epochs=1, + batch_size=1, + learning_rate=3e-5, + distribution_strategy='off', + gamma=3.5, + ) + options = text_classifier.TextClassifierOptions( + supported_model=supported_model, hparams=hparams + ) + bert_classifier = text_classifier.TextClassifier.create( + train_data, validation_data, options + ) + loss_fn = bert_classifier._loss_function + self.assertIsInstance(loss_fn, loss_functions.SparseFocalLoss) + self.assertEqual(loss_fn._gamma, 3.5) + self.assertEqual(loss_fn._num_classes, 2) + metric_names = [m.name for m in bert_classifier._metric_functions] + expected_metric_names = [ + 'accuracy', + 'recall', + 'precision', + 'precision_at_recall_0.2', + 'recall_at_precision_0.9', + ] + self.assertCountEqual(metric_names, expected_metric_names) + + # Non-binary data + 
tf_dataset = tf.data.Dataset.from_tensor_slices([[0]]) + data = text_classifier.Dataset(tf_dataset, ['foo', 'bar', 'baz'], 1) + with self.assertRaisesWithLiteralMatch( + ValueError, + 'desired_recalls and desired_precisions parameters are binary metrics' + ' and not supported for num_classes > 2. Found num_classes: 3', + ): + text_classifier.TextClassifier.create(data, data, options) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py index 93478de1b..85802f908 100644 --- a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py @@ -115,5 +115,7 @@ class Dataset(classification_dataset.ClassificationDataset): ', '.join(label_names), ) return Dataset( - dataset=image_label_ds, size=all_image_size, label_names=label_names + dataset=image_label_ds, + label_names=label_names, + size=all_image_size, ) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index ecd2a7125..969887e64 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -13,7 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict test compatibility macro. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py index 1ba626be9..8e2095a33 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/dataset.py @@ -249,5 +249,6 @@ class Dataset(classification_dataset.ClassificationDataset): len(valid_hand_data), len(label_names), ','.join(label_names))) return Dataset( dataset=hand_embedding_label_ds, + label_names=label_names, size=len(valid_hand_data), - label_names=label_names) + ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index 73d1d2f7c..a9d91e845 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict binary and library compatibility macro. # Placeholder for internal Python library rule. 
licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset.py b/mediapipe/model_maker/python/vision/image_classifier/dataset.py index 6bc180be8..f627dfecc 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset.py @@ -15,28 +15,12 @@ import os import random - -from typing import List, Optional import tensorflow as tf -import tensorflow_datasets as tfds from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.vision.core import image_utils -def _create_data( - name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo, - label_names: List[str] -) -> Optional[classification_dataset.ClassificationDataset]: - """Creates a Dataset object from tfds data.""" - if name not in data: - return None - data = data[name] - data = data.map(lambda a: (a['image'], a['label'])) - size = info.splits[name].num_examples - return Dataset(data, size, label_names) - - class Dataset(classification_dataset.ClassificationDataset): """Dataset library for image classifier.""" @@ -99,4 +83,5 @@ class Dataset(classification_dataset.ClassificationDataset): 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, all_label_size, ', '.join(label_names)) return Dataset( - dataset=image_label_ds, size=all_image_size, label_names=label_names) + dataset=image_label_ds, label_names=label_names, size=all_image_size + ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py index 63fa666b3..33101382f 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py @@ -41,7 +41,7 @@ class DatasetTest(tf.test.TestCase): def test_split(self): ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) - data = dataset.Dataset(dataset=ds, size=4, label_names=['pos', 'neg']) + data = dataset.Dataset(dataset=ds, label_names=['pos', 'neg'], size=4) train_data, test_data = data.split(fraction=0.5) self.assertLen(train_data, 2) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 4b1ea607f..71a47d9eb 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -52,8 +52,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): ds = tf.data.Dataset.from_generator( self._gen, (tf.uint8, tf.int64), (tf.TensorShape( [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([]))) - data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3, - ['cyan', 'magenta', 'yellow']) + data = image_classifier.Dataset( + ds, ['cyan', 'magenta', 'yellow'], self.IMAGES_PER_CLASS * 3 + ) return data def setUp(self): diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD index 75c08dbc8..14d378a19 100644 --- a/mediapipe/model_maker/python/vision/object_detector/BUILD +++ b/mediapipe/model_maker/python/vision/object_detector/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library and test compatibility macro. 
+# Placeholder for internal Python strict binary and library compatibility macro. # Placeholder for internal Python strict test compatibility macro. licenses(["notice"]) @@ -54,6 +54,7 @@ py_library( srcs = ["dataset.py"], deps = [ ":dataset_util", + "//mediapipe/model_maker/python/core/data:cache_files", "//mediapipe/model_maker/python/core/data:classification_dataset", ], ) @@ -73,6 +74,7 @@ py_test( py_library( name = "dataset_util", srcs = ["dataset_util.py"], + deps = ["//mediapipe/model_maker/python/core/data:cache_files"], ) py_test( diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset.py b/mediapipe/model_maker/python/vision/object_detector/dataset.py index c18a071b2..f7751915e 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset.py @@ -16,8 +16,8 @@ from typing import Optional import tensorflow as tf -import yaml +from mediapipe.model_maker.python.core.data import cache_files from mediapipe.model_maker.python.core.data import classification_dataset from mediapipe.model_maker.python.vision.object_detector import dataset_util from official.vision.dataloaders import tf_example_decoder @@ -76,14 +76,16 @@ class Dataset(classification_dataset.ClassificationDataset): ValueError: If the label_name for id 0 is set to something other than the 'background' class. """ - cache_files = dataset_util.get_cache_files_coco(data_dir, cache_dir) - if not dataset_util.is_cached(cache_files): + tfrecord_cache_files = dataset_util.get_cache_files_coco( + data_dir, cache_dir + ) + if not tfrecord_cache_files.is_cached(): label_map = dataset_util.get_label_map_coco(data_dir) cache_writer = dataset_util.COCOCacheFilesWriter( label_map=label_map, max_num_images=max_num_images ) - cache_writer.write_files(cache_files, data_dir) - return cls.from_cache(cache_files.cache_prefix) + cache_writer.write_files(tfrecord_cache_files, data_dir) + return cls.from_cache(tfrecord_cache_files) @classmethod def from_pascal_voc_folder( @@ -134,47 +136,48 @@ class Dataset(classification_dataset.ClassificationDataset): Raises: ValueError: if the input data directory is empty. """ - cache_files = dataset_util.get_cache_files_pascal_voc(data_dir, cache_dir) - if not dataset_util.is_cached(cache_files): + tfrecord_cache_files = dataset_util.get_cache_files_pascal_voc( + data_dir, cache_dir + ) + if not tfrecord_cache_files.is_cached(): label_map = dataset_util.get_label_map_pascal_voc(data_dir) cache_writer = dataset_util.PascalVocCacheFilesWriter( label_map=label_map, max_num_images=max_num_images ) - cache_writer.write_files(cache_files, data_dir) + cache_writer.write_files(tfrecord_cache_files, data_dir) - return cls.from_cache(cache_files.cache_prefix) + return cls.from_cache(tfrecord_cache_files) @classmethod - def from_cache(cls, cache_prefix: str) -> 'Dataset': + def from_cache( + cls, tfrecord_cache_files: cache_files.TFRecordCacheFiles + ) -> 'Dataset': """Loads the TFRecord data from cache. Args: - cache_prefix: The cache prefix including the cache directory and the cache - prefix filename, e.g: '/tmp/cache/train'. + tfrecord_cache_files: The TFRecordCacheFiles object containing the already + cached TFRecord and metadata files. Returns: ObjectDetectorDataset object. + + Raises: + ValueError if tfrecord_cache_files are not already cached. 
""" - # Get TFRecord Files - tfrecord_file_pattern = cache_prefix + '*.tfrecord' - matched_files = tf.io.gfile.glob(tfrecord_file_pattern) - if not matched_files: - raise ValueError('TFRecord files are empty.') + if not tfrecord_cache_files.is_cached(): + raise ValueError( + 'Cache files must be already cached to use the from_cache method.' + ) - # Load meta_data. - meta_data_file = cache_prefix + dataset_util.META_DATA_FILE_SUFFIX - if not tf.io.gfile.exists(meta_data_file): - raise ValueError("Metadata file %s doesn't exist." % meta_data_file) - with tf.io.gfile.GFile(meta_data_file, 'r') as f: - meta_data = yaml.load(f, Loader=yaml.FullLoader) + metadata = tfrecord_cache_files.load_metadata() - dataset = tf.data.TFRecordDataset(matched_files) + dataset = tf.data.TFRecordDataset(tfrecord_cache_files.tfrecord_files) decoder = tf_example_decoder.TfExampleDecoder(regenerate_source_id=False) dataset = dataset.map(decoder.decode, num_parallel_calls=tf.data.AUTOTUNE) - label_map = meta_data['label_map'] + label_map = metadata['label_map'] label_names = [label_map[k] for k in sorted(label_map.keys())] return Dataset( - dataset=dataset, size=meta_data['size'], label_names=label_names + dataset=dataset, label_names=label_names, size=metadata['size'] ) diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py index 74d082f9f..fbb821b3b 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util.py @@ -15,25 +15,20 @@ import abc import collections -import dataclasses import hashlib import json import math import os import tempfile -from typing import Any, Dict, List, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional import xml.etree.ElementTree as ET import tensorflow as tf -import yaml +from mediapipe.model_maker.python.core.data import cache_files from official.vision.data import tfrecord_lib -# Suffix of the meta data file name. -META_DATA_FILE_SUFFIX = '_meta_data.yaml' - - def _xml_get(node: ET.Element, name: str) -> ET.Element: """Gets a named child from an XML Element node. @@ -71,18 +66,9 @@ def _get_dir_basename(data_dir: str) -> str: return os.path.basename(os.path.abspath(data_dir)) -@dataclasses.dataclass(frozen=True) -class CacheFiles: - """Cache files for object detection.""" - - cache_prefix: str - tfrecord_files: Sequence[str] - meta_data_file: str - - def _get_cache_files( cache_dir: Optional[str], cache_prefix_filename: str, num_shards: int = 10 -) -> CacheFiles: +) -> cache_files.TFRecordCacheFiles: """Creates an object of CacheFiles class. Args: @@ -96,28 +82,16 @@ def _get_cache_files( An object of CacheFiles class. """ cache_dir = _get_cache_dir_or_create(cache_dir) - # The cache prefix including the cache directory and the cache prefix - # filename, e.g: '/tmp/cache/train'. - cache_prefix = os.path.join(cache_dir, cache_prefix_filename) - tf.compat.v1.logging.info( - 'Cache will be stored in %s with prefix filename %s. Cache_prefix is %s' - % (cache_dir, cache_prefix_filename, cache_prefix) - ) - - # Cached files including the TFRecord files and the meta data file. 
- tfrecord_files = [ - cache_prefix + '-%05d-of-%05d.tfrecord' % (i, num_shards) - for i in range(num_shards) - ] - meta_data_file = cache_prefix + META_DATA_FILE_SUFFIX - return CacheFiles( - cache_prefix=cache_prefix, - tfrecord_files=tuple(tfrecord_files), - meta_data_file=meta_data_file, + return cache_files.TFRecordCacheFiles( + cache_prefix_filename=cache_prefix_filename, + cache_dir=cache_dir, + num_shards=num_shards, ) -def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: +def get_cache_files_coco( + data_dir: str, cache_dir: str +) -> cache_files.TFRecordCacheFiles: """Creates an object of CacheFiles class using a COCO formatted dataset. Args: @@ -152,7 +126,9 @@ def get_cache_files_coco(data_dir: str, cache_dir: str) -> CacheFiles: return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) -def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: +def get_cache_files_pascal_voc( + data_dir: str, cache_dir: str +) -> cache_files.TFRecordCacheFiles: """Gets an object of CacheFiles using a PASCAL VOC formatted dataset. Args: @@ -181,14 +157,6 @@ def get_cache_files_pascal_voc(data_dir: str, cache_dir: str) -> CacheFiles: return _get_cache_files(cache_dir, cache_prefix_filename, num_shards) -def is_cached(cache_files: CacheFiles) -> bool: - """Checks whether cache files are already cached.""" - all_cached_files = list(cache_files.tfrecord_files) + [ - cache_files.meta_data_file - ] - return all(tf.io.gfile.exists(path) for path in all_cached_files) - - class CacheFilesWriter(abc.ABC): """CacheFilesWriter class to write the cached files.""" @@ -208,19 +176,22 @@ class CacheFilesWriter(abc.ABC): self.label_map = label_map self.max_num_images = max_num_images - def write_files(self, cache_files: CacheFiles, *args, **kwargs) -> None: - """Writes TFRecord and meta_data files. + def write_files( + self, + tfrecord_cache_files: cache_files.TFRecordCacheFiles, + *args, + **kwargs, + ) -> None: + """Writes TFRecord and metadata files. Args: - cache_files: CacheFiles object including a list of TFRecord files and the - meta data yaml file to save the meta_data including data size and - label_map. + tfrecord_cache_files: TFRecordCacheFiles object including a list of + TFRecord files and the meta data yaml file to save the metadata + including data size and label_map. *args: Non-keyword of parameters used in the `_get_example` method. **kwargs: Keyword parameters used in the `_get_example` method. """ - writers = [ - tf.io.TFRecordWriter(path) for path in cache_files.tfrecord_files - ] + writers = tfrecord_cache_files.get_writers() # Writes tf.Example into TFRecord files. size = 0 @@ -235,10 +206,9 @@ class CacheFilesWriter(abc.ABC): for writer in writers: writer.close() - # Writes meta_data into meta_data_file. - meta_data = {'size': size, 'label_map': self.label_map} - with tf.io.gfile.GFile(cache_files.meta_data_file, 'w') as f: - yaml.dump(meta_data, f) + # Writes metadata into metadata_file. 
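# A self-contained sketch of the TFRecordCacheFiles round trip used by this
# writer and by the text preprocessor earlier in the diff (the directory,
# payload, and label map below are illustrative):
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import cache_files

os.makedirs("/tmp/od_cache", exist_ok=True)
cf = cache_files.TFRecordCacheFiles(
    cache_prefix_filename="train", cache_dir="/tmp/od_cache", num_shards=1
)
if not cf.is_cached():
  writers = cf.get_writers()
  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[0]))
          }
      )
  )
  writers[0].write(example.SerializeToString())
  for writer in writers:
    writer.close()
  cf.save_metadata({"size": 1, "label_map": {1: "cat"}})
metadata = cf.load_metadata()  # e.g. {"size": 1, "label_map": {1: "cat"}}
dataset = tf.data.TFRecordDataset(cf.tfrecord_files)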
+ metadata = {'size': size, 'label_map': self.label_map} + tfrecord_cache_files.save_metadata(metadata) @abc.abstractmethod def _get_example(self, *args, **kwargs): diff --git a/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py b/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py index 6daea1f47..250c5d45e 100644 --- a/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/dataset_util_test.py @@ -19,7 +19,6 @@ import shutil from unittest import mock as unittest_mock import tensorflow as tf -import yaml from mediapipe.model_maker.python.vision.core import test_utils from mediapipe.model_maker.python.vision.object_detector import dataset_util @@ -30,13 +29,10 @@ class DatasetUtilTest(tf.test.TestCase): def _assert_cache_files_equal(self, cf1, cf2): self.assertEqual(cf1.cache_prefix, cf2.cache_prefix) - self.assertCountEqual(cf1.tfrecord_files, cf2.tfrecord_files) - self.assertEqual(cf1.meta_data_file, cf2.meta_data_file) + self.assertEqual(cf1.num_shards, cf2.num_shards) def _assert_cache_files_not_equal(self, cf1, cf2): self.assertNotEqual(cf1.cache_prefix, cf2.cache_prefix) - self.assertNotEqual(cf1.tfrecord_files, cf2.tfrecord_files) - self.assertNotEqual(cf1.meta_data_file, cf2.meta_data_file) def _get_cache_files_and_assert_neq_fn(self, cache_files_fn): def get_cache_files_and_assert_neq(cf, data_dir, cache_dir): @@ -57,7 +53,7 @@ class DatasetUtilTest(tf.test.TestCase): self.assertEqual( cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' ) - self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') + self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml') def test_matching_get_cache_files_coco(self): cache_dir = self.create_tempdir() @@ -118,7 +114,7 @@ class DatasetUtilTest(tf.test.TestCase): self.assertEqual( cache_files.tfrecord_files[0], '/tmp/train-00000-of-00001.tfrecord' ) - self.assertEqual(cache_files.meta_data_file, '/tmp/train_meta_data.yaml') + self.assertEqual(cache_files.metadata_file, '/tmp/train_metadata.yaml') def test_matching_get_cache_files_pascal_voc(self): cache_dir = self.create_tempdir() @@ -173,13 +169,13 @@ class DatasetUtilTest(tf.test.TestCase): cache_files = dataset_util.get_cache_files_coco( tasks_test_utils.get_test_data_path('coco_data'), cache_dir=tempdir ) - self.assertFalse(dataset_util.is_cached(cache_files)) + self.assertFalse(cache_files.is_cached()) with open(cache_files.tfrecord_files[0], 'w') as f: f.write('test') - self.assertFalse(dataset_util.is_cached(cache_files)) - with open(cache_files.meta_data_file, 'w') as f: + self.assertFalse(cache_files.is_cached()) + with open(cache_files.metadata_file, 'w') as f: f.write('test') - self.assertTrue(dataset_util.is_cached(cache_files)) + self.assertTrue(cache_files.is_cached()) def test_get_label_map_coco(self): coco_dir = tasks_test_utils.get_test_data_path('coco_data') @@ -203,13 +199,11 @@ class DatasetUtilTest(tf.test.TestCase): self.assertTrue(os.path.isfile(cache_files.tfrecord_files[0])) self.assertGreater(os.path.getsize(cache_files.tfrecord_files[0]), 0) - # Checks the meta_data file - self.assertTrue(os.path.isfile(cache_files.meta_data_file)) - self.assertGreater(os.path.getsize(cache_files.meta_data_file), 0) - with tf.io.gfile.GFile(cache_files.meta_data_file, 'r') as f: - meta_data_dict = yaml.load(f, Loader=yaml.FullLoader) - # Size is 3 because some examples are skipped for having poor bboxes - 
self.assertEqual(meta_data_dict['size'], expected_size) + # Checks the metadata file + self.assertTrue(os.path.isfile(cache_files.metadata_file)) + self.assertGreater(os.path.getsize(cache_files.metadata_file), 0) + metadata_dict = cache_files.load_metadata() + self.assertEqual(metadata_dict['size'], expected_size) def test_coco_cache_files_writer(self): tempdir = self.create_tempdir() diff --git a/mediapipe/model_maker/python/vision/object_detector/model.py b/mediapipe/model_maker/python/vision/object_detector/model.py index b1b4951fd..ea78ca8c6 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model.py +++ b/mediapipe/model_maker/python/vision/object_detector/model.py @@ -74,8 +74,8 @@ class ObjectDetectorModel(tf.keras.Model): generator_config: configs.retinanet.DetectionGenerator = configs.retinanet.DetectionGenerator(), ) -> configs.retinanet.RetinaNet: model_config = configs.retinanet.RetinaNet( - min_level=3, - max_level=7, + min_level=self._model_spec.min_level, + max_level=self._model_spec.max_level, num_classes=self._num_classes, input_size=self._model_spec.input_image_shape, anchor=configs.retinanet.Anchor( diff --git a/mediapipe/model_maker/python/vision/object_detector/model_spec.py b/mediapipe/model_maker/python/vision/object_detector/model_spec.py index 9c89c4ed0..ad043e872 100644 --- a/mediapipe/model_maker/python/vision/object_detector/model_spec.py +++ b/mediapipe/model_maker/python/vision/object_detector/model_spec.py @@ -20,18 +20,30 @@ from typing import List from mediapipe.model_maker.python.core.utils import file_util -MOBILENET_V2_FILES = file_util.DownloadedFiles( - 'object_detector/mobilenetv2', +MOBILENET_V2_I256_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetv2_i256', 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i256_ckpt.tar.gz', is_folder=True, ) +MOBILENET_V2_I320_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetv2_i320', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv2_ssd_i320_ckpt.tar.gz', + is_folder=True, +) + MOBILENET_MULTI_AVG_FILES = file_util.DownloadedFiles( 'object_detector/mobilenetmultiavg', 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv3.5_ssd_coco/mobilenetv3.5_ssd_i256_ckpt.tar.gz', is_folder=True, ) +MOBILENET_MULTI_AVG_I384_FILES = file_util.DownloadedFiles( + 'object_detector/mobilenetmultiavg_i384', + 'https://storage.googleapis.com/tf_model_garden/vision/qat/mobilenetv2_ssd_coco/mobilenetv3.5_ssd_i384_ckpt.tar.gz', + is_folder=True, +) + @dataclasses.dataclass class ModelSpec(object): @@ -48,30 +60,66 @@ class ModelSpec(object): input_image_shape: List[int] model_id: str + # Model Config values + min_level: int + max_level: int -mobilenet_v2_spec = functools.partial( + +mobilenet_v2_i256_spec = functools.partial( ModelSpec, - downloaded_files=MOBILENET_V2_FILES, + downloaded_files=MOBILENET_V2_I256_FILES, checkpoint_name='ckpt-277200', input_image_shape=[256, 256, 3], model_id='MobileNetV2', + min_level=3, + max_level=7, ) -mobilenet_multi_avg_spec = functools.partial( +mobilenet_v2_i320_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_V2_I320_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[320, 320, 3], + model_id='MobileNetV2', + min_level=3, + max_level=6, +) + +mobilenet_multi_avg_i256_spec = functools.partial( ModelSpec, downloaded_files=MOBILENET_MULTI_AVG_FILES, checkpoint_name='ckpt-277200', input_image_shape=[256, 256, 3], 
model_id='MobileNetMultiAVG', + min_level=3, + max_level=7, +) + +mobilenet_multi_avg_i384_spec = functools.partial( + ModelSpec, + downloaded_files=MOBILENET_MULTI_AVG_I384_FILES, + checkpoint_name='ckpt-277200', + input_image_shape=[384, 384, 3], + model_id='MobileNetMultiAVG', + min_level=3, + max_level=7, ) @enum.unique class SupportedModels(enum.Enum): - """Predefined object detector model specs supported by Model Maker.""" + """Predefined object detector model specs supported by Model Maker. - MOBILENET_V2 = mobilenet_v2_spec - MOBILENET_MULTI_AVG = mobilenet_multi_avg_spec + Supported models include the following: + - MOBILENET_V2: MobileNetV2 256x256 input + - MOBILENET_V2_I320: MobileNetV2 320x320 input + - MOBILENET_MULTI_AVG: MobileNet-MultiHW-AVG 256x256 input + - MOBILENET_MULTI_AVG_I384: MobileNet-MultiHW-AVG 384x384 input + """ + MOBILENET_V2 = mobilenet_v2_i256_spec + MOBILENET_V2_I320 = mobilenet_v2_i320_spec + MOBILENET_MULTI_AVG = mobilenet_multi_avg_i256_spec + MOBILENET_MULTI_AVG_I384 = mobilenet_multi_avg_i384_spec @classmethod def get(cls, spec: 'SupportedModels') -> 'ModelSpec': diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector.py b/mediapipe/model_maker/python/vision/object_detector/object_detector.py index 486c3ffa9..6c7b9811c 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector.py @@ -395,7 +395,7 @@ class ObjectDetector(classifier.Classifier): ) -> tf.keras.optimizers.Optimizer: """Creates an optimizer with learning rate schedule for regular training. - Uses Keras PiecewiseConstantDecay schedule by default. + Uses Keras CosineDecay schedule by default. Args: steps_per_epoch: Steps per epoch to calculate the step boundaries from the @@ -404,6 +404,8 @@ class ObjectDetector(classifier.Classifier): Returns: A tf.keras.optimizer.Optimizer for model training. 
""" + total_steps = steps_per_epoch * self._hparams.epochs + warmup_steps = int(total_steps * 0.1) init_lr = self._hparams.learning_rate * self._hparams.batch_size / 256 decay_epochs = ( self._hparams.cosine_decay_epochs @@ -415,6 +417,11 @@ class ObjectDetector(classifier.Classifier): steps_per_epoch * decay_epochs, self._hparams.cosine_decay_alpha, ) + learning_rate = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate, + warmup_steps=warmup_steps, + ) return tf.keras.optimizers.experimental.SGD( learning_rate=learning_rate, momentum=0.9 ) diff --git a/mediapipe/model_maker/python/vision/object_detector/preprocessor.py b/mediapipe/model_maker/python/vision/object_detector/preprocessor.py index ebea6a07b..1388cc7df 100644 --- a/mediapipe/model_maker/python/vision/object_detector/preprocessor.py +++ b/mediapipe/model_maker/python/vision/object_detector/preprocessor.py @@ -32,8 +32,8 @@ class Preprocessor(object): self._mean_norm = model_spec.mean_norm self._stddev_norm = model_spec.stddev_norm self._output_size = model_spec.input_image_shape[:2] - self._min_level = 3 - self._max_level = 7 + self._min_level = model_spec.min_level + self._max_level = model_spec.max_level self._num_scales = 3 self._aspect_ratios = [0.5, 1, 2] self._anchor_size = 3 diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 5c78dc582..a1c975c1e 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -3,6 +3,7 @@ mediapipe>=0.10.0 numpy opencv-python tensorflow>=2.10 +tensorflow-addons tensorflow-datasets tensorflow-hub -tf-models-official==2.11.6 +tf-models-official>=2.13.1 diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index 1b8b173f7..a1acc0be2 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -13,17 +13,17 @@ # limitations under the License. """MediaPipe solution drawing utils.""" +import dataclasses import math from typing import List, Mapping, Optional, Tuple, Union import cv2 -import dataclasses import matplotlib.pyplot as plt import numpy as np from mediapipe.framework.formats import detection_pb2 -from mediapipe.framework.formats import location_data_pb2 from mediapipe.framework.formats import landmark_pb2 +from mediapipe.framework.formats import location_data_pb2 _PRESENCE_THRESHOLD = 0.5 _VISIBILITY_THRESHOLD = 0.5 diff --git a/mediapipe/python/solutions/drawing_utils_test.py b/mediapipe/python/solutions/drawing_utils_test.py index 0039f9a90..8943a0581 100644 --- a/mediapipe/python/solutions/drawing_utils_test.py +++ b/mediapipe/python/solutions/drawing_utils_test.py @@ -20,7 +20,6 @@ import cv2 import numpy as np from google.protobuf import text_format - from mediapipe.framework.formats import detection_pb2 from mediapipe.framework.formats import landmark_pb2 from mediapipe.python.solutions import drawing_utils diff --git a/mediapipe/tasks/c/components/containers/BUILD b/mediapipe/tasks/c/components/containers/BUILD new file mode 100644 index 000000000..4d1f190bb --- /dev/null +++ b/mediapipe/tasks/c/components/containers/BUILD @@ -0,0 +1,29 @@ +# TODO: describe this package. + +# Copyright 2022 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "category", + hdrs = ["category.h"], +) + +cc_library( + name = "classification_result", + hdrs = ["classification_result.h"], +) diff --git a/mediapipe/tasks/c/components/containers/category.h b/mediapipe/tasks/c/components/containers/category.h new file mode 100644 index 000000000..565dd65fe --- /dev/null +++ b/mediapipe/tasks/c/components/containers/category.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ +#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ + +// Defines a single classification result. +// +// The label maps packed into the TFLite Model Metadata [1] are used to populate +// the 'category_name' and 'display_name' fields. +// +// [1]: https://www.tensorflow.org/lite/convert/metadata +struct Category { + // The index of the category in the classification model output. + int index; + + // The score for this category, e.g. (but not necessarily) a probability in + // [0,1]. + float score; + + // The optional ID for the category, read from the label map packed in the + // TFLite Model Metadata if present. Not necessarily human-readable. + char* category_name; + + // The optional human-readable name for the category, read from the label map + // packed in the TFLite Model Metadata if present. + char* display_name; +}; + +#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ diff --git a/mediapipe/tasks/c/components/containers/classification_result.h b/mediapipe/tasks/c/components/containers/classification_result.h new file mode 100644 index 000000000..540ab4464 --- /dev/null +++ b/mediapipe/tasks/c/components/containers/classification_result.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
+
+#include <stdbool.h>
+#include <stdint.h>
+
+// Defines classification results for a given classifier head.
+struct Classifications {
+  // The array of predicted categories, usually sorted by descending scores,
+  // e.g. from high to low probability.
+  struct Category* categories;
+  // The number of elements in the categories array.
+  uint32_t categories_count;
+
+  // The index of the classifier head (i.e. output tensor) these categories
+  // refer to. This is useful for multi-head models.
+  int head_index;
+
+  // The optional name of the classifier head, as provided in the TFLite Model
+  // Metadata [1] if present. This is useful for multi-head models.
+  //
+  // [1]: https://www.tensorflow.org/lite/convert/metadata
+  char* head_name;
+};
+
+// Defines classification results of a model.
+struct ClassificationResult {
+  // The classification results for each head of the model.
+  struct Classifications* classifications;
+  // The number of classifications in the classifications array.
+  uint32_t classifications_count;
+
+  // The optional timestamp (in milliseconds) of the start of the chunk of data
+  // corresponding to these results.
+  //
+  // This is only used for classification on time series (e.g. audio
+  // classification). In these use cases, the amount of data to process might
+  // exceed the maximum size that the model can process: to solve this, the
+  // input data is split into multiple chunks starting at different timestamps.
+  int64_t timestamp_ms;
+  // Specifies whether the timestamp contains a valid value.
+  bool has_timestamp_ms;
+};
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
diff --git a/mediapipe/tasks/c/components/processors/BUILD b/mediapipe/tasks/c/components/processors/BUILD
new file mode 100644
index 000000000..24d3a181e
--- /dev/null
+++ b/mediapipe/tasks/c/components/processors/BUILD
@@ -0,0 +1,22 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "classifier_options",
+    hdrs = ["classifier_options.h"],
+)
diff --git a/mediapipe/tasks/c/components/processors/classifier_options.h b/mediapipe/tasks/c/components/processors/classifier_options.h
new file mode 100644
index 000000000..4cce2ce69
--- /dev/null
+++ b/mediapipe/tasks/c/components/processors/classifier_options.h
@@ -0,0 +1,51 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
+#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
+
+#include <stdint.h>
+
+// Classifier options for MediaPipe C classification Tasks.
+struct ClassifierOptions {
+  // The locale to use for display names specified through the TFLite Model
+  // Metadata, if any. Defaults to English.
+  char* display_names_locale;
+
+  // The maximum number of top-scored classification results to return. If < 0,
+  // all available results will be returned. If 0, an invalid argument error is
+  // returned.
+  int max_results;
+
+  // Score threshold to override the one provided in the model metadata (if
+  // any). Results below this value are rejected.
+  float score_threshold;
+
+  // The allowlist of category names. If non-empty, detection results whose
+  // category name is not in this set will be filtered out. Duplicate or
+  // unknown category names are ignored. Mutually exclusive with
+  // category_denylist.
+  char** category_allowlist;
+  // The number of elements in the category allowlist.
+  uint32_t category_allowlist_count;
+
+  // The denylist of category names. If non-empty, detection results whose
+  // category name is in this set will be filtered out. Duplicate or unknown
+  // category names are ignored. Mutually exclusive with category_allowlist.
+  char** category_denylist;
+  // The number of elements in the category denylist.
+  uint32_t category_denylist_count;
+};
+
+#endif  // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
diff --git a/mediapipe/tasks/c/core/BUILD b/mediapipe/tasks/c/core/BUILD
new file mode 100644
index 000000000..60d10857f
--- /dev/null
+++ b/mediapipe/tasks/c/core/BUILD
@@ -0,0 +1,22 @@
+# Copyright 2023 The MediaPipe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "base_options",
+    hdrs = ["base_options.h"],
+)
diff --git a/mediapipe/tasks/c/core/base_options.h b/mediapipe/tasks/c/core/base_options.h
new file mode 100644
index 000000000..f5f6b0318
--- /dev/null
+++ b/mediapipe/tasks/c/core/base_options.h
@@ -0,0 +1,28 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
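A short configuration sketch for the options struct above; the caller owns all string storage, and the category names below are made-up examples rather than anything defined by this change:

```cpp
#include <cstdint>

#include "mediapipe/tasks/c/components/processors/classifier_options.h"

// Example allowlist owned by the caller; the header itself defines no storage.
static char* kAllowedCategories[] = {(char*)"positive", (char*)"negative"};

ClassifierOptions MakeExampleClassifierOptions() {
  ClassifierOptions options = {};
  options.display_names_locale = (char*)"en";
  options.max_results = 3;          // keep only the three best-scored results
  options.score_threshold = 0.25f;  // reject anything scored below 0.25
  options.category_allowlist = kAllowedCategories;
  options.category_allowlist_count = 2;
  // category_denylist is mutually exclusive with the allowlist, so it is
  // deliberately left empty here.
  return options;
}
```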
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ +#define MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ + +// Base options for MediaPipe C Tasks. +struct BaseOptions { + // The model asset file contents as a string. + char* model_asset_buffer; + + // The path to the model asset to open and mmap in memory. + char* model_asset_path; +}; + +#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ diff --git a/mediapipe/tasks/c/text/text_classifier/BUILD b/mediapipe/tasks/c/text/text_classifier/BUILD new file mode 100644 index 000000000..0402689c7 --- /dev/null +++ b/mediapipe/tasks/c/text/text_classifier/BUILD @@ -0,0 +1,28 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_classifier", + hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/c/components/containers:classification_result", + "//mediapipe/tasks/c/components/processors:classifier_options", + "//mediapipe/tasks/c/core:base_options", + ], +) diff --git a/mediapipe/tasks/c/text/text_classifier/text_classifier.h b/mediapipe/tasks/c/text/text_classifier/text_classifier.h new file mode 100644 index 000000000..7439644b8 --- /dev/null +++ b/mediapipe/tasks/c/text/text_classifier/text_classifier.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ +#define MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ + +#include "mediapipe/tasks/c/components/containers/classification_result.h" +#include "mediapipe/tasks/c/components/processors/classifier_options.h" +#include "mediapipe/tasks/c/core/base_options.h" + +typedef ClassificationResult TextClassifierResult; + +// The options for configuring a MediaPipe text classifier task. 
+struct TextClassifierOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + struct ClassifierOptions classifier_options; +}; + +// Creates a TextClassifier from the provided `options`. +void* text_classifier_create(struct TextClassifierOptions options); + +// Performs classification on the input `text`. +TextClassifierResult text_classifier_classify(void* classifier, + char* utf8_text); + +// Shuts down the TextClassifier when all the work is done. Frees all memory. +void text_classifier_close(void* classifier); + +#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ diff --git a/mediapipe/tasks/cc/core/base_options.cc b/mediapipe/tasks/cc/core/base_options.cc index b7987f982..863338fe5 100644 --- a/mediapipe/tasks/cc/core/base_options.cc +++ b/mediapipe/tasks/cc/core/base_options.cc @@ -41,9 +41,15 @@ proto::Acceleration ConvertDelegateOptionsToAccelerationProto( proto::Acceleration acceleration_proto = proto::Acceleration(); auto* gpu = acceleration_proto.mutable_gpu(); gpu->set_use_advanced_gpu_api(true); - gpu->set_cached_kernel_path(options.cached_kernel_path); - gpu->set_serialized_model_dir(options.serialized_model_dir); - gpu->set_model_token(options.model_token); + if (!options.cached_kernel_path.empty()) { + gpu->set_cached_kernel_path(options.cached_kernel_path); + } + if (!options.serialized_model_dir.empty()) { + gpu->set_serialized_model_dir(options.serialized_model_dir); + } + if (!options.model_token.empty()) { + gpu->set_model_token(options.model_token); + } return acceleration_proto; } diff --git a/mediapipe/tasks/cc/core/base_options_test.cc b/mediapipe/tasks/cc/core/base_options_test.cc index af9a55a37..390663515 100644 --- a/mediapipe/tasks/cc/core/base_options_test.cc +++ b/mediapipe/tasks/cc/core/base_options_test.cc @@ -59,14 +59,15 @@ TEST(DelegateOptionsTest, SucceedGpuOptions) { BaseOptions base_options; base_options.delegate = BaseOptions::Delegate::GPU; BaseOptions::GpuOptions gpu_options; - gpu_options.cached_kernel_path = kCachedModelDir; + gpu_options.serialized_model_dir = kCachedModelDir; gpu_options.model_token = kModelToken; base_options.delegate_options = gpu_options; proto::BaseOptions proto = ConvertBaseOptionsToProto(&base_options); ASSERT_TRUE(proto.acceleration().has_gpu()); ASSERT_FALSE(proto.acceleration().has_tflite()); EXPECT_TRUE(proto.acceleration().gpu().use_advanced_gpu_api()); - EXPECT_EQ(proto.acceleration().gpu().cached_kernel_path(), kCachedModelDir); + EXPECT_FALSE(proto.acceleration().gpu().has_cached_kernel_path()); + EXPECT_EQ(proto.acceleration().gpu().serialized_model_dir(), kCachedModelDir); EXPECT_EQ(proto.acceleration().gpu().model_token(), kModelToken); } diff --git a/mediapipe/tasks/cc/vision/face_landmarker/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/BUILD index 16de2271a..36c4bf551 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/BUILD @@ -217,3 +217,8 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "face_landmarks_connections", + hdrs = ["face_landmarks_connections.h"], +) diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h new file mode 100644 index 000000000..360083a7f --- /dev/null 
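Taken together with the BaseOptions and ClassifierOptions structs above, the intended call pattern for this C API is create, classify, close. A hedged end-to-end sketch follows; the model path is a placeholder, and the header does not yet say how the memory owned by the returned result is released:

```cpp
#include <cstdint>
#include <cstdio>

#include "mediapipe/tasks/c/text/text_classifier/text_classifier.h"

int main() {
  TextClassifierOptions options = {};
  // Placeholder path to a TFLite text classification model with metadata.
  options.base_options.model_asset_path =
      (char*)"/path/to/text_classifier.tflite";
  options.classifier_options.max_results = 3;
  options.classifier_options.score_threshold = 0.1f;

  void* classifier = text_classifier_create(options);
  if (classifier == nullptr) return 1;

  TextClassifierResult result =
      text_classifier_classify(classifier, (char*)"What a great day!");
  for (uint32_t i = 0; i < result.classifications_count; ++i) {
    const Classifications& head = result.classifications[i];
    for (uint32_t j = 0; j < head.categories_count; ++j) {
      std::printf("%s: %.3f\n", head.categories[j].category_name,
                  head.categories[j].score);
    }
  }

  // Releases everything owned by the classifier; how the result's own buffers
  // are freed is not specified by this header yet.
  text_classifier_close(classifier);
  return 0;
}
```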
+++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_connections.h
@@ -0,0 +1,651 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKS_CONNECTIONS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKS_CONNECTIONS_H_
+
+#include <array>
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace face_landmarker {
+
+struct FaceLandmarksConnections {
+  static constexpr std::array<std::array<int, 2>, 40> kFaceLandmarksLips{
+      {{61, 146}, {146, 91}, {91, 181}, {181, 84}, {84, 17}, {17, 314},
+       {314, 405}, {405, 321}, {321, 375}, {375, 291}, {61, 185}, {185, 40},
+       {40, 39}, {39, 37}, {37, 0}, {0, 267}, {267, 269}, {269, 270},
+       {270, 409}, {409, 291}, {78, 95}, {95, 88}, {88, 178}, {178, 87},
+       {87, 14}, {14, 317}, {317, 402}, {402, 318}, {318, 324}, {324, 308},
+       {78, 191}, {191, 80}, {80, 81}, {81, 82}, {82, 13}, {13, 312},
+       {312, 311}, {311, 310}, {310, 415}, {415, 308}}};
+
+  static constexpr std::array<std::array<int, 2>, 16> kFaceLandmarksLeftEye{
+      {{263, 249},
+       {249, 390},
+       {390, 373},
+       {373, 374},
+       {374, 380},
+       {380, 381},
+       {381, 382},
+       {382, 362},
+       {263, 466},
+       {466, 388},
+       {388, 387},
+       {387, 386},
+       {386, 385},
+       {385, 384},
+       {384, 398},
+       {398, 362}}};
+
+  static constexpr std::array<std::array<int, 2>, 8> kFaceLandmarksLeftEyeBrow{
+      {{276, 283},
+       {283, 282},
+       {282, 295},
+       {295, 285},
+       {300, 293},
+       {293, 334},
+       {334, 296},
+       {296, 336}}};
+
+  static constexpr std::array<std::array<int, 2>, 4> kFaceLandmarksLeftIris{
+      {{474, 475}, {475, 476}, {476, 477}, {477, 474}}};
+
+  static constexpr std::array<std::array<int, 2>, 16> kFaceLandmarksRightEye{
+      {{33, 7},
+       {7, 163},
+       {163, 144},
+       {144, 145},
+       {145, 153},
+       {153, 154},
+       {154, 155},
+       {155, 133},
+       {33, 246},
+       {246, 161},
+       {161, 160},
+       {160, 159},
+       {159, 158},
+       {158, 157},
+       {157, 173},
+       {173, 133}}};
+
+  static constexpr std::array<std::array<int, 2>, 8> kFaceLandmarksRightEyeBrow{
+      {{46, 53},
+       {53, 52},
+       {52, 65},
+       {65, 55},
+       {70, 63},
+       {63, 105},
+       {105, 66},
+       {66, 107}}};
+
+  static constexpr std::array<std::array<int, 2>, 4> kFaceLandmarksRightIris{
+      {{469, 470}, {470, 471}, {471, 472}, {472, 469}}};
+
+  static constexpr std::array<std::array<int, 2>, 36> kFaceLandmarksFaceOval{
+      {{10, 338}, {338, 297}, {297, 332}, {332, 284}, {284, 251}, {251, 389},
+       {389, 356}, {356, 454}, {454, 323}, {323, 361}, {361, 288}, {288, 397},
+       {397, 365}, {365, 379}, {379, 378}, {378, 400}, {400, 377}, {377, 152},
+       {152, 148}, {148, 176}, {176, 149}, {149, 150}, {150, 136}, {136, 172},
+       {172, 58}, {58, 132}, {132, 93}, {93, 234}, {234, 127}, {127, 162},
+       {162, 21}, {21, 54}, {54, 103}, {103, 67}, {67, 109}, {109, 10}}};
+
+  // Lips + Left Eye + Left Eye Brows + Right Eye + Right Eye Brows + Face Oval.
+ static constexpr std::array, 132> kFaceLandmarksConnectors{ + {{61, 146}, {146, 91}, {91, 181}, {181, 84}, {84, 17}, {17, 314}, + {314, 405}, {405, 321}, {321, 375}, {375, 291}, {61, 185}, {185, 40}, + {40, 39}, {39, 37}, {37, 0}, {0, 267}, {267, 269}, {269, 270}, + {270, 409}, {409, 291}, {78, 95}, {95, 88}, {88, 178}, {178, 87}, + {87, 14}, {14, 317}, {317, 402}, {402, 318}, {318, 324}, {324, 308}, + {78, 191}, {191, 80}, {80, 81}, {81, 82}, {82, 13}, {13, 312}, + {312, 311}, {311, 310}, {310, 415}, {415, 30}, {263, 249}, {249, 390}, + {390, 373}, {373, 374}, {374, 380}, {380, 381}, {381, 382}, {382, 362}, + {263, 466}, {466, 388}, {388, 387}, {387, 386}, {386, 385}, {385, 384}, + {384, 398}, {398, 362}, {276, 283}, {283, 282}, {282, 295}, {295, 285}, + {300, 293}, {293, 334}, {334, 296}, {296, 336}, {33, 7}, {7, 163}, + {163, 144}, {144, 145}, {145, 153}, {153, 154}, {154, 155}, {155, 133}, + {33, 246}, {246, 161}, {161, 160}, {160, 159}, {159, 158}, {158, 157}, + {157, 173}, {173, 13}, {46, 53}, {53, 52}, {52, 65}, {65, 55}, + {70, 63}, {63, 105}, {105, 66}, {66, 107}, {10, 338}, {338, 297}, + {297, 332}, {332, 284}, {284, 251}, {251, 389}, {389, 356}, {356, 454}, + {454, 323}, {323, 361}, {361, 288}, {288, 397}, {397, 365}, {365, 379}, + {379, 378}, {378, 400}, {400, 377}, {377, 152}, {152, 148}, {148, 176}, + {176, 149}, {149, 150}, {150, 136}, {136, 172}, {172, 58}, {58, 132}, + {132, 93}, {93, 234}, {234, 127}, {127, 162}, {162, 21}, {21, 54}, + {54, 103}, {103, 67}, {67, 109}, {109, 10}}}; + + static constexpr std::array, 2556> + kFaceLandmarksTesselation{ + {{127, 34}, {34, 139}, {139, 127}, {11, 0}, {0, 37}, + {37, 11}, {232, 231}, {231, 120}, {120, 232}, {72, 37}, + {37, 39}, {39, 72}, {128, 121}, {121, 47}, {47, 128}, + {232, 121}, {121, 128}, {128, 232}, {104, 69}, {69, 67}, + {67, 104}, {175, 171}, {171, 148}, {148, 175}, {118, 50}, + {50, 101}, {101, 118}, {73, 39}, {39, 40}, {40, 73}, + {9, 151}, {151, 108}, {108, 9}, {48, 115}, {115, 131}, + {131, 48}, {194, 204}, {204, 211}, {211, 194}, {74, 40}, + {40, 185}, {185, 74}, {80, 42}, {42, 183}, {183, 80}, + {40, 92}, {92, 186}, {186, 40}, {230, 229}, {229, 118}, + {118, 230}, {202, 212}, {212, 214}, {214, 202}, {83, 18}, + {18, 17}, {17, 83}, {76, 61}, {61, 146}, {146, 76}, + {160, 29}, {29, 30}, {30, 160}, {56, 157}, {157, 173}, + {173, 56}, {106, 204}, {204, 194}, {194, 106}, {135, 214}, + {214, 192}, {192, 135}, {203, 165}, {165, 98}, {98, 203}, + {21, 71}, {71, 68}, {68, 21}, {51, 45}, {45, 4}, + {4, 51}, {144, 24}, {24, 23}, {23, 144}, {77, 146}, + {146, 91}, {91, 77}, {205, 50}, {50, 187}, {187, 205}, + {201, 200}, {200, 18}, {18, 201}, {91, 106}, {106, 182}, + {182, 91}, {90, 91}, {91, 181}, {181, 90}, {85, 84}, + {84, 17}, {17, 85}, {206, 203}, {203, 36}, {36, 206}, + {148, 171}, {171, 140}, {140, 148}, {92, 40}, {40, 39}, + {39, 92}, {193, 189}, {189, 244}, {244, 193}, {159, 158}, + {158, 28}, {28, 159}, {247, 246}, {246, 161}, {161, 247}, + {236, 3}, {3, 196}, {196, 236}, {54, 68}, {68, 104}, + {104, 54}, {193, 168}, {168, 8}, {8, 193}, {117, 228}, + {228, 31}, {31, 117}, {189, 193}, {193, 55}, {55, 189}, + {98, 97}, {97, 99}, {99, 98}, {126, 47}, {47, 100}, + {100, 126}, {166, 79}, {79, 218}, {218, 166}, {155, 154}, + {154, 26}, {26, 155}, {209, 49}, {49, 131}, {131, 209}, + {135, 136}, {136, 150}, {150, 135}, {47, 126}, {126, 217}, + {217, 47}, {223, 52}, {52, 53}, {53, 223}, {45, 51}, + {51, 134}, {134, 45}, {211, 170}, {170, 140}, {140, 211}, + {67, 69}, {69, 108}, {108, 67}, {43, 106}, {106, 91}, + {91, 43}, 
{230, 119}, {119, 120}, {120, 230}, {226, 130}, + {130, 247}, {247, 226}, {63, 53}, {53, 52}, {52, 63}, + {238, 20}, {20, 242}, {242, 238}, {46, 70}, {70, 156}, + {156, 46}, {78, 62}, {62, 96}, {96, 78}, {46, 53}, + {53, 63}, {63, 46}, {143, 34}, {34, 227}, {227, 143}, + {123, 117}, {117, 111}, {111, 123}, {44, 125}, {125, 19}, + {19, 44}, {236, 134}, {134, 51}, {51, 236}, {216, 206}, + {206, 205}, {205, 216}, {154, 153}, {153, 22}, {22, 154}, + {39, 37}, {37, 167}, {167, 39}, {200, 201}, {201, 208}, + {208, 200}, {36, 142}, {142, 100}, {100, 36}, {57, 212}, + {212, 202}, {202, 57}, {20, 60}, {60, 99}, {99, 20}, + {28, 158}, {158, 157}, {157, 28}, {35, 226}, {226, 113}, + {113, 35}, {160, 159}, {159, 27}, {27, 160}, {204, 202}, + {202, 210}, {210, 204}, {113, 225}, {225, 46}, {46, 113}, + {43, 202}, {202, 204}, {204, 43}, {62, 76}, {76, 77}, + {77, 62}, {137, 123}, {123, 116}, {116, 137}, {41, 38}, + {38, 72}, {72, 41}, {203, 129}, {129, 142}, {142, 203}, + {64, 98}, {98, 240}, {240, 64}, {49, 102}, {102, 64}, + {64, 49}, {41, 73}, {73, 74}, {74, 41}, {212, 216}, + {216, 207}, {207, 212}, {42, 74}, {74, 184}, {184, 42}, + {169, 170}, {170, 211}, {211, 169}, {170, 149}, {149, 176}, + {176, 170}, {105, 66}, {66, 69}, {69, 105}, {122, 6}, + {6, 168}, {168, 122}, {123, 147}, {147, 187}, {187, 123}, + {96, 77}, {77, 90}, {90, 96}, {65, 55}, {55, 107}, + {107, 65}, {89, 90}, {90, 180}, {180, 89}, {101, 100}, + {100, 120}, {120, 101}, {63, 105}, {105, 104}, {104, 63}, + {93, 137}, {137, 227}, {227, 93}, {15, 86}, {86, 85}, + {85, 15}, {129, 102}, {102, 49}, {49, 129}, {14, 87}, + {87, 86}, {86, 14}, {55, 8}, {8, 9}, {9, 55}, + {100, 47}, {47, 121}, {121, 100}, {145, 23}, {23, 22}, + {22, 145}, {88, 89}, {89, 179}, {179, 88}, {6, 122}, + {122, 196}, {196, 6}, {88, 95}, {95, 96}, {96, 88}, + {138, 172}, {172, 136}, {136, 138}, {215, 58}, {58, 172}, + {172, 215}, {115, 48}, {48, 219}, {219, 115}, {42, 80}, + {80, 81}, {81, 42}, {195, 3}, {3, 51}, {51, 195}, + {43, 146}, {146, 61}, {61, 43}, {171, 175}, {175, 199}, + {199, 171}, {81, 82}, {82, 38}, {38, 81}, {53, 46}, + {46, 225}, {225, 53}, {144, 163}, {163, 110}, {110, 144}, + {52, 65}, {65, 66}, {66, 52}, {229, 228}, {228, 117}, + {117, 229}, {34, 127}, {127, 234}, {234, 34}, {107, 108}, + {108, 69}, {69, 107}, {109, 108}, {108, 151}, {151, 109}, + {48, 64}, {64, 235}, {235, 48}, {62, 78}, {78, 191}, + {191, 62}, {129, 209}, {209, 126}, {126, 129}, {111, 35}, + {35, 143}, {143, 111}, {117, 123}, {123, 50}, {50, 117}, + {222, 65}, {65, 52}, {52, 222}, {19, 125}, {125, 141}, + {141, 19}, {221, 55}, {55, 65}, {65, 221}, {3, 195}, + {195, 197}, {197, 3}, {25, 7}, {7, 33}, {33, 25}, + {220, 237}, {237, 44}, {44, 220}, {70, 71}, {71, 139}, + {139, 70}, {122, 193}, {193, 245}, {245, 122}, {247, 130}, + {130, 33}, {33, 247}, {71, 21}, {21, 162}, {162, 71}, + {170, 169}, {169, 150}, {150, 170}, {188, 174}, {174, 196}, + {196, 188}, {216, 186}, {186, 92}, {92, 216}, {2, 97}, + {97, 167}, {167, 2}, {141, 125}, {125, 241}, {241, 141}, + {164, 167}, {167, 37}, {37, 164}, {72, 38}, {38, 12}, + {12, 72}, {38, 82}, {82, 13}, {13, 38}, {63, 68}, + {68, 71}, {71, 63}, {226, 35}, {35, 111}, {111, 226}, + {101, 50}, {50, 205}, {205, 101}, {206, 92}, {92, 165}, + {165, 206}, {209, 198}, {198, 217}, {217, 209}, {165, 167}, + {167, 97}, {97, 165}, {220, 115}, {115, 218}, {218, 220}, + {133, 112}, {112, 243}, {243, 133}, {239, 238}, {238, 241}, + {241, 239}, {214, 135}, {135, 169}, {169, 214}, {190, 173}, + {173, 133}, {133, 190}, {171, 208}, {208, 32}, {32, 171}, + 
{125, 44}, {44, 237}, {237, 125}, {86, 87}, {87, 178}, + {178, 86}, {85, 86}, {86, 179}, {179, 85}, {84, 85}, + {85, 180}, {180, 84}, {83, 84}, {84, 181}, {181, 83}, + {201, 83}, {83, 182}, {182, 201}, {137, 93}, {93, 132}, + {132, 137}, {76, 62}, {62, 183}, {183, 76}, {61, 76}, + {76, 184}, {184, 61}, {57, 61}, {61, 185}, {185, 57}, + {212, 57}, {57, 186}, {186, 212}, {214, 207}, {207, 187}, + {187, 214}, {34, 143}, {143, 156}, {156, 34}, {79, 239}, + {239, 237}, {237, 79}, {123, 137}, {137, 177}, {177, 123}, + {44, 1}, {1, 4}, {4, 44}, {201, 194}, {194, 32}, + {32, 201}, {64, 102}, {102, 129}, {129, 64}, {213, 215}, + {215, 138}, {138, 213}, {59, 166}, {166, 219}, {219, 59}, + {242, 99}, {99, 97}, {97, 242}, {2, 94}, {94, 141}, + {141, 2}, {75, 59}, {59, 235}, {235, 75}, {24, 110}, + {110, 228}, {228, 24}, {25, 130}, {130, 226}, {226, 25}, + {23, 24}, {24, 229}, {229, 23}, {22, 23}, {23, 230}, + {230, 22}, {26, 22}, {22, 231}, {231, 26}, {112, 26}, + {26, 232}, {232, 112}, {189, 190}, {190, 243}, {243, 189}, + {221, 56}, {56, 190}, {190, 221}, {28, 56}, {56, 221}, + {221, 28}, {27, 28}, {28, 222}, {222, 27}, {29, 27}, + {27, 223}, {223, 29}, {30, 29}, {29, 224}, {224, 30}, + {247, 30}, {30, 225}, {225, 247}, {238, 79}, {79, 20}, + {20, 238}, {166, 59}, {59, 75}, {75, 166}, {60, 75}, + {75, 240}, {240, 60}, {147, 177}, {177, 215}, {215, 147}, + {20, 79}, {79, 166}, {166, 20}, {187, 147}, {147, 213}, + {213, 187}, {112, 233}, {233, 244}, {244, 112}, {233, 128}, + {128, 245}, {245, 233}, {128, 114}, {114, 188}, {188, 128}, + {114, 217}, {217, 174}, {174, 114}, {131, 115}, {115, 220}, + {220, 131}, {217, 198}, {198, 236}, {236, 217}, {198, 131}, + {131, 134}, {134, 198}, {177, 132}, {132, 58}, {58, 177}, + {143, 35}, {35, 124}, {124, 143}, {110, 163}, {163, 7}, + {7, 110}, {228, 110}, {110, 25}, {25, 228}, {356, 389}, + {389, 368}, {368, 356}, {11, 302}, {302, 267}, {267, 11}, + {452, 350}, {350, 349}, {349, 452}, {302, 303}, {303, 269}, + {269, 302}, {357, 343}, {343, 277}, {277, 357}, {452, 453}, + {453, 357}, {357, 452}, {333, 332}, {332, 297}, {297, 333}, + {175, 152}, {152, 377}, {377, 175}, {347, 348}, {348, 330}, + {330, 347}, {303, 304}, {304, 270}, {270, 303}, {9, 336}, + {336, 337}, {337, 9}, {278, 279}, {279, 360}, {360, 278}, + {418, 262}, {262, 431}, {431, 418}, {304, 408}, {408, 409}, + {409, 304}, {310, 415}, {415, 407}, {407, 310}, {270, 409}, + {409, 410}, {410, 270}, {450, 348}, {348, 347}, {347, 450}, + {422, 430}, {430, 434}, {434, 422}, {313, 314}, {314, 17}, + {17, 313}, {306, 307}, {307, 375}, {375, 306}, {387, 388}, + {388, 260}, {260, 387}, {286, 414}, {414, 398}, {398, 286}, + {335, 406}, {406, 418}, {418, 335}, {364, 367}, {367, 416}, + {416, 364}, {423, 358}, {358, 327}, {327, 423}, {251, 284}, + {284, 298}, {298, 251}, {281, 5}, {5, 4}, {4, 281}, + {373, 374}, {374, 253}, {253, 373}, {307, 320}, {320, 321}, + {321, 307}, {425, 427}, {427, 411}, {411, 425}, {421, 313}, + {313, 18}, {18, 421}, {321, 405}, {405, 406}, {406, 321}, + {320, 404}, {404, 405}, {405, 320}, {315, 16}, {16, 17}, + {17, 315}, {426, 425}, {425, 266}, {266, 426}, {377, 400}, + {400, 369}, {369, 377}, {322, 391}, {391, 269}, {269, 322}, + {417, 465}, {465, 464}, {464, 417}, {386, 257}, {257, 258}, + {258, 386}, {466, 260}, {260, 388}, {388, 466}, {456, 399}, + {399, 419}, {419, 456}, {284, 332}, {332, 333}, {333, 284}, + {417, 285}, {285, 8}, {8, 417}, {346, 340}, {340, 261}, + {261, 346}, {413, 441}, {441, 285}, {285, 413}, {327, 460}, + {460, 328}, {328, 327}, {355, 371}, {371, 329}, {329, 
355}, + {392, 439}, {439, 438}, {438, 392}, {382, 341}, {341, 256}, + {256, 382}, {429, 420}, {420, 360}, {360, 429}, {364, 394}, + {394, 379}, {379, 364}, {277, 343}, {343, 437}, {437, 277}, + {443, 444}, {444, 283}, {283, 443}, {275, 440}, {440, 363}, + {363, 275}, {431, 262}, {262, 369}, {369, 431}, {297, 338}, + {338, 337}, {337, 297}, {273, 375}, {375, 321}, {321, 273}, + {450, 451}, {451, 349}, {349, 450}, {446, 342}, {342, 467}, + {467, 446}, {293, 334}, {334, 282}, {282, 293}, {458, 461}, + {461, 462}, {462, 458}, {276, 353}, {353, 383}, {383, 276}, + {308, 324}, {324, 325}, {325, 308}, {276, 300}, {300, 293}, + {293, 276}, {372, 345}, {345, 447}, {447, 372}, {352, 345}, + {345, 340}, {340, 352}, {274, 1}, {1, 19}, {19, 274}, + {456, 248}, {248, 281}, {281, 456}, {436, 427}, {427, 425}, + {425, 436}, {381, 256}, {256, 252}, {252, 381}, {269, 391}, + {391, 393}, {393, 269}, {200, 199}, {199, 428}, {428, 200}, + {266, 330}, {330, 329}, {329, 266}, {287, 273}, {273, 422}, + {422, 287}, {250, 462}, {462, 328}, {328, 250}, {258, 286}, + {286, 384}, {384, 258}, {265, 353}, {353, 342}, {342, 265}, + {387, 259}, {259, 257}, {257, 387}, {424, 431}, {431, 430}, + {430, 424}, {342, 353}, {353, 276}, {276, 342}, {273, 335}, + {335, 424}, {424, 273}, {292, 325}, {325, 307}, {307, 292}, + {366, 447}, {447, 345}, {345, 366}, {271, 303}, {303, 302}, + {302, 271}, {423, 266}, {266, 371}, {371, 423}, {294, 455}, + {455, 460}, {460, 294}, {279, 278}, {278, 294}, {294, 279}, + {271, 272}, {272, 304}, {304, 271}, {432, 434}, {434, 427}, + {427, 432}, {272, 407}, {407, 408}, {408, 272}, {394, 430}, + {430, 431}, {431, 394}, {395, 369}, {369, 400}, {400, 395}, + {334, 333}, {333, 299}, {299, 334}, {351, 417}, {417, 168}, + {168, 351}, {352, 280}, {280, 411}, {411, 352}, {325, 319}, + {319, 320}, {320, 325}, {295, 296}, {296, 336}, {336, 295}, + {319, 403}, {403, 404}, {404, 319}, {330, 348}, {348, 349}, + {349, 330}, {293, 298}, {298, 333}, {333, 293}, {323, 454}, + {454, 447}, {447, 323}, {15, 16}, {16, 315}, {315, 15}, + {358, 429}, {429, 279}, {279, 358}, {14, 15}, {15, 316}, + {316, 14}, {285, 336}, {336, 9}, {9, 285}, {329, 349}, + {349, 350}, {350, 329}, {374, 380}, {380, 252}, {252, 374}, + {318, 402}, {402, 403}, {403, 318}, {6, 197}, {197, 419}, + {419, 6}, {318, 319}, {319, 325}, {325, 318}, {367, 364}, + {364, 365}, {365, 367}, {435, 367}, {367, 397}, {397, 435}, + {344, 438}, {438, 439}, {439, 344}, {272, 271}, {271, 311}, + {311, 272}, {195, 5}, {5, 281}, {281, 195}, {273, 287}, + {287, 291}, {291, 273}, {396, 428}, {428, 199}, {199, 396}, + {311, 271}, {271, 268}, {268, 311}, {283, 444}, {444, 445}, + {445, 283}, {373, 254}, {254, 339}, {339, 373}, {282, 334}, + {334, 296}, {296, 282}, {449, 347}, {347, 346}, {346, 449}, + {264, 447}, {447, 454}, {454, 264}, {336, 296}, {296, 299}, + {299, 336}, {338, 10}, {10, 151}, {151, 338}, {278, 439}, + {439, 455}, {455, 278}, {292, 407}, {407, 415}, {415, 292}, + {358, 371}, {371, 355}, {355, 358}, {340, 345}, {345, 372}, + {372, 340}, {346, 347}, {347, 280}, {280, 346}, {442, 443}, + {443, 282}, {282, 442}, {19, 94}, {94, 370}, {370, 19}, + {441, 442}, {442, 295}, {295, 441}, {248, 419}, {419, 197}, + {197, 248}, {263, 255}, {255, 359}, {359, 263}, {440, 275}, + {275, 274}, {274, 440}, {300, 383}, {383, 368}, {368, 300}, + {351, 412}, {412, 465}, {465, 351}, {263, 467}, {467, 466}, + {466, 263}, {301, 368}, {368, 389}, {389, 301}, {395, 378}, + {378, 379}, {379, 395}, {412, 351}, {351, 419}, {419, 412}, + {436, 426}, {426, 322}, {322, 436}, {2, 
164}, {164, 393}, + {393, 2}, {370, 462}, {462, 461}, {461, 370}, {164, 0}, + {0, 267}, {267, 164}, {302, 11}, {11, 12}, {12, 302}, + {268, 12}, {12, 13}, {13, 268}, {293, 300}, {300, 301}, + {301, 293}, {446, 261}, {261, 340}, {340, 446}, {330, 266}, + {266, 425}, {425, 330}, {426, 423}, {423, 391}, {391, 426}, + {429, 355}, {355, 437}, {437, 429}, {391, 327}, {327, 326}, + {326, 391}, {440, 457}, {457, 438}, {438, 440}, {341, 382}, + {382, 362}, {362, 341}, {459, 457}, {457, 461}, {461, 459}, + {434, 430}, {430, 394}, {394, 434}, {414, 463}, {463, 362}, + {362, 414}, {396, 369}, {369, 262}, {262, 396}, {354, 461}, + {461, 457}, {457, 354}, {316, 403}, {403, 402}, {402, 316}, + {315, 404}, {404, 403}, {403, 315}, {314, 405}, {405, 404}, + {404, 314}, {313, 406}, {406, 405}, {405, 313}, {421, 418}, + {418, 406}, {406, 421}, {366, 401}, {401, 361}, {361, 366}, + {306, 408}, {408, 407}, {407, 306}, {291, 409}, {409, 408}, + {408, 291}, {287, 410}, {410, 409}, {409, 287}, {432, 436}, + {436, 410}, {410, 432}, {434, 416}, {416, 411}, {411, 434}, + {264, 368}, {368, 383}, {383, 264}, {309, 438}, {438, 457}, + {457, 309}, {352, 376}, {376, 401}, {401, 352}, {274, 275}, + {275, 4}, {4, 274}, {421, 428}, {428, 262}, {262, 421}, + {294, 327}, {327, 358}, {358, 294}, {433, 416}, {416, 367}, + {367, 433}, {289, 455}, {455, 439}, {439, 289}, {462, 370}, + {370, 326}, {326, 462}, {2, 326}, {326, 370}, {370, 2}, + {305, 460}, {460, 455}, {455, 305}, {254, 449}, {449, 448}, + {448, 254}, {255, 261}, {261, 446}, {446, 255}, {253, 450}, + {450, 449}, {449, 253}, {252, 451}, {451, 450}, {450, 252}, + {256, 452}, {452, 451}, {451, 256}, {341, 453}, {453, 452}, + {452, 341}, {413, 464}, {464, 463}, {463, 413}, {441, 413}, + {413, 414}, {414, 441}, {258, 442}, {442, 441}, {441, 258}, + {257, 443}, {443, 442}, {442, 257}, {259, 444}, {444, 443}, + {443, 259}, {260, 445}, {445, 444}, {444, 260}, {467, 342}, + {342, 445}, {445, 467}, {459, 458}, {458, 250}, {250, 459}, + {289, 392}, {392, 290}, {290, 289}, {290, 328}, {328, 460}, + {460, 290}, {376, 433}, {433, 435}, {435, 376}, {250, 290}, + {290, 392}, {392, 250}, {411, 416}, {416, 433}, {433, 411}, + {341, 463}, {463, 464}, {464, 341}, {453, 464}, {464, 465}, + {465, 453}, {357, 465}, {465, 412}, {412, 357}, {343, 412}, + {412, 399}, {399, 343}, {360, 363}, {363, 440}, {440, 360}, + {437, 399}, {399, 456}, {456, 437}, {420, 456}, {456, 363}, + {363, 420}, {401, 435}, {435, 288}, {288, 401}, {372, 383}, + {383, 353}, {353, 372}, {339, 255}, {255, 249}, {249, 339}, + {448, 261}, {261, 255}, {255, 448}, {133, 243}, {243, 190}, + {190, 133}, {133, 155}, {155, 112}, {112, 133}, {33, 246}, + {246, 247}, {247, 33}, {33, 130}, {130, 25}, {25, 33}, + {398, 384}, {384, 286}, {286, 398}, {362, 398}, {398, 414}, + {414, 362}, {362, 463}, {463, 341}, {341, 362}, {263, 359}, + {359, 467}, {467, 263}, {263, 249}, {249, 255}, {255, 263}, + {466, 467}, {467, 260}, {260, 466}, {75, 60}, {60, 166}, + {166, 75}, {238, 239}, {239, 79}, {79, 238}, {162, 127}, + {127, 139}, {139, 162}, {72, 11}, {11, 37}, {37, 72}, + {121, 232}, {232, 120}, {120, 121}, {73, 72}, {72, 39}, + {39, 73}, {114, 128}, {128, 47}, {47, 114}, {233, 232}, + {232, 128}, {128, 233}, {103, 104}, {104, 67}, {67, 103}, + {152, 175}, {175, 148}, {148, 152}, {119, 118}, {118, 101}, + {101, 119}, {74, 73}, {73, 40}, {40, 74}, {107, 9}, + {9, 108}, {108, 107}, {49, 48}, {48, 131}, {131, 49}, + {32, 194}, {194, 211}, {211, 32}, {184, 74}, {74, 185}, + {185, 184}, {191, 80}, {80, 183}, {183, 191}, {185, 40}, + {40, 
186}, {186, 185}, {119, 230}, {230, 118}, {118, 119}, + {210, 202}, {202, 214}, {214, 210}, {84, 83}, {83, 17}, + {17, 84}, {77, 76}, {76, 146}, {146, 77}, {161, 160}, + {160, 30}, {30, 161}, {190, 56}, {56, 173}, {173, 190}, + {182, 106}, {106, 194}, {194, 182}, {138, 135}, {135, 192}, + {192, 138}, {129, 203}, {203, 98}, {98, 129}, {54, 21}, + {21, 68}, {68, 54}, {5, 51}, {51, 4}, {4, 5}, + {145, 144}, {144, 23}, {23, 145}, {90, 77}, {77, 91}, + {91, 90}, {207, 205}, {205, 187}, {187, 207}, {83, 201}, + {201, 18}, {18, 83}, {181, 91}, {91, 182}, {182, 181}, + {180, 90}, {90, 181}, {181, 180}, {16, 85}, {85, 17}, + {17, 16}, {205, 206}, {206, 36}, {36, 205}, {176, 148}, + {148, 140}, {140, 176}, {165, 92}, {92, 39}, {39, 165}, + {245, 193}, {193, 244}, {244, 245}, {27, 159}, {159, 28}, + {28, 27}, {30, 247}, {247, 161}, {161, 30}, {174, 236}, + {236, 196}, {196, 174}, {103, 54}, {54, 104}, {104, 103}, + {55, 193}, {193, 8}, {8, 55}, {111, 117}, {117, 31}, + {31, 111}, {221, 189}, {189, 55}, {55, 221}, {240, 98}, + {98, 99}, {99, 240}, {142, 126}, {126, 100}, {100, 142}, + {219, 166}, {166, 218}, {218, 219}, {112, 155}, {155, 26}, + {26, 112}, {198, 209}, {209, 131}, {131, 198}, {169, 135}, + {135, 150}, {150, 169}, {114, 47}, {47, 217}, {217, 114}, + {224, 223}, {223, 53}, {53, 224}, {220, 45}, {45, 134}, + {134, 220}, {32, 211}, {211, 140}, {140, 32}, {109, 67}, + {67, 108}, {108, 109}, {146, 43}, {43, 91}, {91, 146}, + {231, 230}, {230, 120}, {120, 231}, {113, 226}, {226, 247}, + {247, 113}, {105, 63}, {63, 52}, {52, 105}, {241, 238}, + {238, 242}, {242, 241}, {124, 46}, {46, 156}, {156, 124}, + {95, 78}, {78, 96}, {96, 95}, {70, 46}, {46, 63}, + {63, 70}, {116, 143}, {143, 227}, {227, 116}, {116, 123}, + {123, 111}, {111, 116}, {1, 44}, {44, 19}, {19, 1}, + {3, 236}, {236, 51}, {51, 3}, {207, 216}, {216, 205}, + {205, 207}, {26, 154}, {154, 22}, {22, 26}, {165, 39}, + {39, 167}, {167, 165}, {199, 200}, {200, 208}, {208, 199}, + {101, 36}, {36, 100}, {100, 101}, {43, 57}, {57, 202}, + {202, 43}, {242, 20}, {20, 99}, {99, 242}, {56, 28}, + {28, 157}, {157, 56}, {124, 35}, {35, 113}, {113, 124}, + {29, 160}, {160, 27}, {27, 29}, {211, 204}, {204, 210}, + {210, 211}, {124, 113}, {113, 46}, {46, 124}, {106, 43}, + {43, 204}, {204, 106}, {96, 62}, {62, 77}, {77, 96}, + {227, 137}, {137, 116}, {116, 227}, {73, 41}, {41, 72}, + {72, 73}, {36, 203}, {203, 142}, {142, 36}, {235, 64}, + {64, 240}, {240, 235}, {48, 49}, {49, 64}, {64, 48}, + {42, 41}, {41, 74}, {74, 42}, {214, 212}, {212, 207}, + {207, 214}, {183, 42}, {42, 184}, {184, 183}, {210, 169}, + {169, 211}, {211, 210}, {140, 170}, {170, 176}, {176, 140}, + {104, 105}, {105, 69}, {69, 104}, {193, 122}, {122, 168}, + {168, 193}, {50, 123}, {123, 187}, {187, 50}, {89, 96}, + {96, 90}, {90, 89}, {66, 65}, {65, 107}, {107, 66}, + {179, 89}, {89, 180}, {180, 179}, {119, 101}, {101, 120}, + {120, 119}, {68, 63}, {63, 104}, {104, 68}, {234, 93}, + {93, 227}, {227, 234}, {16, 15}, {15, 85}, {85, 16}, + {209, 129}, {129, 49}, {49, 209}, {15, 14}, {14, 86}, + {86, 15}, {107, 55}, {55, 9}, {9, 107}, {120, 100}, + {100, 121}, {121, 120}, {153, 145}, {145, 22}, {22, 153}, + {178, 88}, {88, 179}, {179, 178}, {197, 6}, {6, 196}, + {196, 197}, {89, 88}, {88, 96}, {96, 89}, {135, 138}, + {138, 136}, {136, 135}, {138, 215}, {215, 172}, {172, 138}, + {218, 115}, {115, 219}, {219, 218}, {41, 42}, {42, 81}, + {81, 41}, {5, 195}, {195, 51}, {51, 5}, {57, 43}, + {43, 61}, {61, 57}, {208, 171}, {171, 199}, {199, 208}, + {41, 81}, {81, 38}, {38, 41}, {224, 53}, 
{53, 225}, + {225, 224}, {24, 144}, {144, 110}, {110, 24}, {105, 52}, + {52, 66}, {66, 105}, {118, 229}, {229, 117}, {117, 118}, + {227, 34}, {34, 234}, {234, 227}, {66, 107}, {107, 69}, + {69, 66}, {10, 109}, {109, 151}, {151, 10}, {219, 48}, + {48, 235}, {235, 219}, {183, 62}, {62, 191}, {191, 183}, + {142, 129}, {129, 126}, {126, 142}, {116, 111}, {111, 143}, + {143, 116}, {118, 117}, {117, 50}, {50, 118}, {223, 222}, + {222, 52}, {52, 223}, {94, 19}, {19, 141}, {141, 94}, + {222, 221}, {221, 65}, {65, 222}, {196, 3}, {3, 197}, + {197, 196}, {45, 220}, {220, 44}, {44, 45}, {156, 70}, + {70, 139}, {139, 156}, {188, 122}, {122, 245}, {245, 188}, + {139, 71}, {71, 162}, {162, 139}, {149, 170}, {170, 150}, + {150, 149}, {122, 188}, {188, 196}, {196, 122}, {206, 216}, + {216, 92}, {92, 206}, {164, 2}, {2, 167}, {167, 164}, + {242, 141}, {141, 241}, {241, 242}, {0, 164}, {164, 37}, + {37, 0}, {11, 72}, {72, 12}, {12, 11}, {12, 38}, + {38, 13}, {13, 12}, {70, 63}, {63, 71}, {71, 70}, + {31, 226}, {226, 111}, {111, 31}, {36, 101}, {101, 205}, + {205, 36}, {203, 206}, {206, 165}, {165, 203}, {126, 209}, + {209, 217}, {217, 126}, {98, 165}, {165, 97}, {97, 98}, + {237, 220}, {220, 218}, {218, 237}, {237, 239}, {239, 241}, + {241, 237}, {210, 214}, {214, 169}, {169, 210}, {140, 171}, + {171, 32}, {32, 140}, {241, 125}, {125, 237}, {237, 241}, + {179, 86}, {86, 178}, {178, 179}, {180, 85}, {85, 179}, + {179, 180}, {181, 84}, {84, 180}, {180, 181}, {182, 83}, + {83, 181}, {181, 182}, {194, 201}, {201, 182}, {182, 194}, + {177, 137}, {137, 132}, {132, 177}, {184, 76}, {76, 183}, + {183, 184}, {185, 61}, {61, 184}, {184, 185}, {186, 57}, + {57, 185}, {185, 186}, {216, 212}, {212, 186}, {186, 216}, + {192, 214}, {214, 187}, {187, 192}, {139, 34}, {34, 156}, + {156, 139}, {218, 79}, {79, 237}, {237, 218}, {147, 123}, + {123, 177}, {177, 147}, {45, 44}, {44, 4}, {4, 45}, + {208, 201}, {201, 32}, {32, 208}, {98, 64}, {64, 129}, + {129, 98}, {192, 213}, {213, 138}, {138, 192}, {235, 59}, + {59, 219}, {219, 235}, {141, 242}, {242, 97}, {97, 141}, + {97, 2}, {2, 141}, {141, 97}, {240, 75}, {75, 235}, + {235, 240}, {229, 24}, {24, 228}, {228, 229}, {31, 25}, + {25, 226}, {226, 31}, {230, 23}, {23, 229}, {229, 230}, + {231, 22}, {22, 230}, {230, 231}, {232, 26}, {26, 231}, + {231, 232}, {233, 112}, {112, 232}, {232, 233}, {244, 189}, + {189, 243}, {243, 244}, {189, 221}, {221, 190}, {190, 189}, + {222, 28}, {28, 221}, {221, 222}, {223, 27}, {27, 222}, + {222, 223}, {224, 29}, {29, 223}, {223, 224}, {225, 30}, + {30, 224}, {224, 225}, {113, 247}, {247, 225}, {225, 113}, + {99, 60}, {60, 240}, {240, 99}, {213, 147}, {147, 215}, + {215, 213}, {60, 20}, {20, 166}, {166, 60}, {192, 187}, + {187, 213}, {213, 192}, {243, 112}, {112, 244}, {244, 243}, + {244, 233}, {233, 245}, {245, 244}, {245, 128}, {128, 188}, + {188, 245}, {188, 114}, {114, 174}, {174, 188}, {134, 131}, + {131, 220}, {220, 134}, {174, 217}, {217, 236}, {236, 174}, + {236, 198}, {198, 134}, {134, 236}, {215, 177}, {177, 58}, + {58, 215}, {156, 143}, {143, 124}, {124, 156}, {25, 110}, + {110, 7}, {7, 25}, {31, 228}, {228, 25}, {25, 31}, + {264, 356}, {356, 368}, {368, 264}, {0, 11}, {11, 267}, + {267, 0}, {451, 452}, {452, 349}, {349, 451}, {267, 302}, + {302, 269}, {269, 267}, {350, 357}, {357, 277}, {277, 350}, + {350, 452}, {452, 357}, {357, 350}, {299, 333}, {333, 297}, + {297, 299}, {396, 175}, {175, 377}, {377, 396}, {280, 347}, + {347, 330}, {330, 280}, {269, 303}, {303, 270}, {270, 269}, + {151, 9}, {9, 337}, {337, 151}, {344, 278}, {278, 
360}, + {360, 344}, {424, 418}, {418, 431}, {431, 424}, {270, 304}, + {304, 409}, {409, 270}, {272, 310}, {310, 407}, {407, 272}, + {322, 270}, {270, 410}, {410, 322}, {449, 450}, {450, 347}, + {347, 449}, {432, 422}, {422, 434}, {434, 432}, {18, 313}, + {313, 17}, {17, 18}, {291, 306}, {306, 375}, {375, 291}, + {259, 387}, {387, 260}, {260, 259}, {424, 335}, {335, 418}, + {418, 424}, {434, 364}, {364, 416}, {416, 434}, {391, 423}, + {423, 327}, {327, 391}, {301, 251}, {251, 298}, {298, 301}, + {275, 281}, {281, 4}, {4, 275}, {254, 373}, {373, 253}, + {253, 254}, {375, 307}, {307, 321}, {321, 375}, {280, 425}, + {425, 411}, {411, 280}, {200, 421}, {421, 18}, {18, 200}, + {335, 321}, {321, 406}, {406, 335}, {321, 320}, {320, 405}, + {405, 321}, {314, 315}, {315, 17}, {17, 314}, {423, 426}, + {426, 266}, {266, 423}, {396, 377}, {377, 369}, {369, 396}, + {270, 322}, {322, 269}, {269, 270}, {413, 417}, {417, 464}, + {464, 413}, {385, 386}, {386, 258}, {258, 385}, {248, 456}, + {456, 419}, {419, 248}, {298, 284}, {284, 333}, {333, 298}, + {168, 417}, {417, 8}, {8, 168}, {448, 346}, {346, 261}, + {261, 448}, {417, 413}, {413, 285}, {285, 417}, {326, 327}, + {327, 328}, {328, 326}, {277, 355}, {355, 329}, {329, 277}, + {309, 392}, {392, 438}, {438, 309}, {381, 382}, {382, 256}, + {256, 381}, {279, 429}, {429, 360}, {360, 279}, {365, 364}, + {364, 379}, {379, 365}, {355, 277}, {277, 437}, {437, 355}, + {282, 443}, {443, 283}, {283, 282}, {281, 275}, {275, 363}, + {363, 281}, {395, 431}, {431, 369}, {369, 395}, {299, 297}, + {297, 337}, {337, 299}, {335, 273}, {273, 321}, {321, 335}, + {348, 450}, {450, 349}, {349, 348}, {359, 446}, {446, 467}, + {467, 359}, {283, 293}, {293, 282}, {282, 283}, {250, 458}, + {458, 462}, {462, 250}, {300, 276}, {276, 383}, {383, 300}, + {292, 308}, {308, 325}, {325, 292}, {283, 276}, {276, 293}, + {293, 283}, {264, 372}, {372, 447}, {447, 264}, {346, 352}, + {352, 340}, {340, 346}, {354, 274}, {274, 19}, {19, 354}, + {363, 456}, {456, 281}, {281, 363}, {426, 436}, {436, 425}, + {425, 426}, {380, 381}, {381, 252}, {252, 380}, {267, 269}, + {269, 393}, {393, 267}, {421, 200}, {200, 428}, {428, 421}, + {371, 266}, {266, 329}, {329, 371}, {432, 287}, {287, 422}, + {422, 432}, {290, 250}, {250, 328}, {328, 290}, {385, 258}, + {258, 384}, {384, 385}, {446, 265}, {265, 342}, {342, 446}, + {386, 387}, {387, 257}, {257, 386}, {422, 424}, {424, 430}, + {430, 422}, {445, 342}, {342, 276}, {276, 445}, {422, 273}, + {273, 424}, {424, 422}, {306, 292}, {292, 307}, {307, 306}, + {352, 366}, {366, 345}, {345, 352}, {268, 271}, {271, 302}, + {302, 268}, {358, 423}, {423, 371}, {371, 358}, {327, 294}, + {294, 460}, {460, 327}, {331, 279}, {279, 294}, {294, 331}, + {303, 271}, {271, 304}, {304, 303}, {436, 432}, {432, 427}, + {427, 436}, {304, 272}, {272, 408}, {408, 304}, {395, 394}, + {394, 431}, {431, 395}, {378, 395}, {395, 400}, {400, 378}, + {296, 334}, {334, 299}, {299, 296}, {6, 351}, {351, 168}, + {168, 6}, {376, 352}, {352, 411}, {411, 376}, {307, 325}, + {325, 320}, {320, 307}, {285, 295}, {295, 336}, {336, 285}, + {320, 319}, {319, 404}, {404, 320}, {329, 330}, {330, 349}, + {349, 329}, {334, 293}, {293, 333}, {333, 334}, {366, 323}, + {323, 447}, {447, 366}, {316, 15}, {15, 315}, {315, 316}, + {331, 358}, {358, 279}, {279, 331}, {317, 14}, {14, 316}, + {316, 317}, {8, 285}, {285, 9}, {9, 8}, {277, 329}, + {329, 350}, {350, 277}, {253, 374}, {374, 252}, {252, 253}, + {319, 318}, {318, 403}, {403, 319}, {351, 6}, {6, 419}, + {419, 351}, {324, 318}, {318, 325}, {325, 324}, 
{397, 367}, + {367, 365}, {365, 397}, {288, 435}, {435, 397}, {397, 288}, + {278, 344}, {344, 439}, {439, 278}, {310, 272}, {272, 311}, + {311, 310}, {248, 195}, {195, 281}, {281, 248}, {375, 273}, + {273, 291}, {291, 375}, {175, 396}, {396, 199}, {199, 175}, + {312, 311}, {311, 268}, {268, 312}, {276, 283}, {283, 445}, + {445, 276}, {390, 373}, {373, 339}, {339, 390}, {295, 282}, + {282, 296}, {296, 295}, {448, 449}, {449, 346}, {346, 448}, + {356, 264}, {264, 454}, {454, 356}, {337, 336}, {336, 299}, + {299, 337}, {337, 338}, {338, 151}, {151, 337}, {294, 278}, + {278, 455}, {455, 294}, {308, 292}, {292, 415}, {415, 308}, + {429, 358}, {358, 355}, {355, 429}, {265, 340}, {340, 372}, + {372, 265}, {352, 346}, {346, 280}, {280, 352}, {295, 442}, + {442, 282}, {282, 295}, {354, 19}, {19, 370}, {370, 354}, + {285, 441}, {441, 295}, {295, 285}, {195, 248}, {248, 197}, + {197, 195}, {457, 440}, {440, 274}, {274, 457}, {301, 300}, + {300, 368}, {368, 301}, {417, 351}, {351, 465}, {465, 417}, + {251, 301}, {301, 389}, {389, 251}, {394, 395}, {395, 379}, + {379, 394}, {399, 412}, {412, 419}, {419, 399}, {410, 436}, + {436, 322}, {322, 410}, {326, 2}, {2, 393}, {393, 326}, + {354, 370}, {370, 461}, {461, 354}, {393, 164}, {164, 267}, + {267, 393}, {268, 302}, {302, 12}, {12, 268}, {312, 268}, + {268, 13}, {13, 312}, {298, 293}, {293, 301}, {301, 298}, + {265, 446}, {446, 340}, {340, 265}, {280, 330}, {330, 425}, + {425, 280}, {322, 426}, {426, 391}, {391, 322}, {420, 429}, + {429, 437}, {437, 420}, {393, 391}, {391, 326}, {326, 393}, + {344, 440}, {440, 438}, {438, 344}, {458, 459}, {459, 461}, + {461, 458}, {364, 434}, {434, 394}, {394, 364}, {428, 396}, + {396, 262}, {262, 428}, {274, 354}, {354, 457}, {457, 274}, + {317, 316}, {316, 402}, {402, 317}, {316, 315}, {315, 403}, + {403, 316}, {315, 314}, {314, 404}, {404, 315}, {314, 313}, + {313, 405}, {405, 314}, {313, 421}, {421, 406}, {406, 313}, + {323, 366}, {366, 361}, {361, 323}, {292, 306}, {306, 407}, + {407, 292}, {306, 291}, {291, 408}, {408, 306}, {291, 287}, + {287, 409}, {409, 291}, {287, 432}, {432, 410}, {410, 287}, + {427, 434}, {434, 411}, {411, 427}, {372, 264}, {264, 383}, + {383, 372}, {459, 309}, {309, 457}, {457, 459}, {366, 352}, + {352, 401}, {401, 366}, {1, 274}, {274, 4}, {4, 1}, + {418, 421}, {421, 262}, {262, 418}, {331, 294}, {294, 358}, + {358, 331}, {435, 433}, {433, 367}, {367, 435}, {392, 289}, + {289, 439}, {439, 392}, {328, 462}, {462, 326}, {326, 328}, + {94, 2}, {2, 370}, {370, 94}, {289, 305}, {305, 455}, + {455, 289}, {339, 254}, {254, 448}, {448, 339}, {359, 255}, + {255, 446}, {446, 359}, {254, 253}, {253, 449}, {449, 254}, + {253, 252}, {252, 450}, {450, 253}, {252, 256}, {256, 451}, + {451, 252}, {256, 341}, {341, 452}, {452, 256}, {414, 413}, + {413, 463}, {463, 414}, {286, 441}, {441, 414}, {414, 286}, + {286, 258}, {258, 441}, {441, 286}, {258, 257}, {257, 442}, + {442, 258}, {257, 259}, {259, 443}, {443, 257}, {259, 260}, + {260, 444}, {444, 259}, {260, 467}, {467, 445}, {445, 260}, + {309, 459}, {459, 250}, {250, 309}, {305, 289}, {289, 290}, + {290, 305}, {305, 290}, {290, 460}, {460, 305}, {401, 376}, + {376, 435}, {435, 401}, {309, 250}, {250, 392}, {392, 309}, + {376, 411}, {411, 433}, {433, 376}, {453, 341}, {341, 464}, + {464, 453}, {357, 453}, {453, 465}, {465, 357}, {343, 357}, + {357, 412}, {412, 343}, {437, 343}, {343, 399}, {399, 437}, + {344, 360}, {360, 440}, {440, 344}, {420, 437}, {437, 456}, + {456, 420}, {360, 420}, {420, 363}, {363, 360}, {361, 401}, + {401, 288}, {288, 361}, {265, 
372}, {372, 353}, {353, 265}, + {390, 339}, {339, 249}, {249, 390}, {339, 448}, {448, 255}, + {255, 339}}}; +}; + +} // namespace face_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKS_CONNECTIONS_H_ diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc index d9825b15f..9e3fdc0ca 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc @@ -111,6 +111,7 @@ class TensorsToImageCalculator : public Node { private: TensorsToImageCalculatorOptions options_; absl::Status CpuProcess(CalculatorContext* cc); + int tensor_position_; #if !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_METAL_ENABLED @@ -166,6 +167,7 @@ absl::Status TensorsToImageCalculator::Open(CalculatorContext* cc) { << "Must specify either `input_tensor_float_range` or " "`input_tensor_uint_range` in the calculator options"; } + tensor_position_ = options_.tensor_position(); return absl::OkStatus(); } @@ -202,17 +204,23 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); - const auto& input_tensor = input_tensors[0]; + const auto& input_tensor = input_tensors[tensor_position_]; const int tensor_in_height = input_tensor.shape().dims[1]; const int tensor_in_width = input_tensor.shape().dims[2]; const int tensor_in_channels = input_tensor.shape().dims[3]; - RET_CHECK_EQ(tensor_in_channels, 3); + RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1); - auto output_frame = std::make_shared( - mediapipe::ImageFormat::SRGB, tensor_in_width, tensor_in_height); + auto format = mediapipe::ImageFormat::SRGB; + if (tensor_in_channels == 1) { + format = mediapipe::ImageFormat::GRAY8; + } + + auto output_frame = + std::make_shared(format, tensor_in_width, tensor_in_height); cv::Mat output_matview = mediapipe::formats::MatView(output_frame.get()); constexpr float kOutputImageRangeMin = 0.0f; @@ -227,8 +235,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { GetValueRangeTransformation( input_range.min(), input_range.max(), kOutputImageRangeMin, kOutputImageRangeMax)); - tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale, - transform.offset); + tensor_matview.convertTo(output_matview, + CV_MAKETYPE(CV_8U, tensor_in_channels), + transform.scale, transform.offset); } else if (input_tensor.element_type() == Tensor::ElementType::kUInt8) { cv::Mat tensor_matview( cv::Size(tensor_in_width, tensor_in_height), @@ -239,8 +248,9 @@ absl::Status TensorsToImageCalculator::CpuProcess(CalculatorContext* cc) { GetValueRangeTransformation( input_range.min(), input_range.max(), kOutputImageRangeMin, kOutputImageRangeMax)); - tensor_matview.convertTo(output_matview, CV_8UC3, transform.scale, - transform.offset); + tensor_matview.convertTo(output_matview, + CV_MAKETYPE(CV_8U, tensor_in_channels), + transform.scale, transform.offset); } else { return absl::InvalidArgumentError( 
absl::Substitute("Type of tensor must be kFloat32 or kUInt8, got: $0", @@ -264,10 +274,14 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); - const int tensor_width = input_tensors[0].shape().dims[2]; - const int tensor_height = input_tensors[0].shape().dims[1]; + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); + const int tensor_width = input_tensors[tensor_position_].shape().dims[2]; + const int tensor_height = input_tensors[tensor_position_].shape().dims[1]; + const int tensor_channels = input_tensors[tensor_position_].shape().dims[3]; + // TODO: Add 1 channel support. + RET_CHECK(tensor_channels == 3); // TODO: Fix unused variable [[maybe_unused]] id device = gpu_helper_.mtlDevice; @@ -277,8 +291,8 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { [command_buffer computeCommandEncoder]; [compute_encoder setComputePipelineState:to_buffer_program_]; - auto input_view = - mediapipe::MtlBufferView::GetReadView(input_tensors[0], command_buffer); + auto input_view = mediapipe::MtlBufferView::GetReadView( + input_tensors[tensor_position_], command_buffer); [compute_encoder setBuffer:input_view.buffer() offset:0 atIndex:0]; mediapipe::GpuBuffer output = @@ -355,7 +369,7 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { absl::StrCat(tflite::gpu::gl::GetShaderHeader(workgroup_size_), R"( precision highp float; layout(rgba8, binding = 0) writeonly uniform highp image2D output_texture; - uniform ivec2 out_size; + uniform ivec3 out_size; )"); const std::string shader_body = R"( @@ -366,10 +380,11 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { void main() { int out_width = out_size.x; int out_height = out_size.y; + int out_channels = out_size.z; ivec2 gid = ivec2(gl_GlobalInvocationID.xy); if (gid.x >= out_width || gid.y >= out_height) { return; } - int linear_index = 3 * (gid.y * out_width + gid.x); + int linear_index = out_channels * (gid.y * out_width + gid.x); #ifdef FLIP_Y_COORD int y_coord = out_height - gid.y - 1; @@ -377,8 +392,14 @@ absl::Status TensorsToImageCalculator::GlSetup(CalculatorContext* cc) { int y_coord = gid.y; #endif // defined(FLIP_Y_COORD) + vec4 out_value; ivec2 out_coordinate = ivec2(gid.x, y_coord); - vec4 out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0); + if (out_channels == 3) { + out_value = vec4(input_data.elements[linear_index], input_data.elements[linear_index + 1], input_data.elements[linear_index + 2], 1.0); + } else { + float in_value = input_data.elements[linear_index]; + out_value = vec4(in_value, in_value, in_value, 1.0); + } imageStore(output_texture, out_coordinate, out_value); })"; @@ -438,10 +459,15 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { return absl::OkStatus(); } const auto& input_tensors = kInputTensors(cc).Get(); - RET_CHECK_EQ(input_tensors.size(), 1) - << "Expect 1 input tensor, but have " << input_tensors.size(); - const int tensor_width = input_tensors[0].shape().dims[2]; - const int tensor_height = input_tensors[0].shape().dims[1]; + RET_CHECK_GT(input_tensors.size(), tensor_position_) + << "Expect input tensor 
at position " << tensor_position_ + << ", but have tensors of size " << input_tensors.size(); + + const auto& input_tensor = input_tensors[tensor_position_]; + const int tensor_width = input_tensor.shape().dims[2]; + const int tensor_height = input_tensor.shape().dims[1]; + const int tensor_in_channels = input_tensor.shape().dims[3]; + RET_CHECK(tensor_in_channels == 3 || tensor_in_channels == 1); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 @@ -454,7 +480,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { glBindImageTexture(output_index, out_texture->id(), 0, GL_FALSE, 0, GL_WRITE_ONLY, GL_RGBA8); - auto read_view = input_tensors[0].GetOpenGlBufferReadView(); + auto read_view = input_tensor.GetOpenGlBufferReadView(); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name()); const tflite::gpu::uint3 workload = {tensor_width, tensor_height, 1}; @@ -462,8 +488,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { tflite::gpu::DivideRoundUp(workload, workgroup_size_); glUseProgram(gl_compute_program_->id()); - glUniform2i(glGetUniformLocation(gl_compute_program_->id(), "out_size"), - tensor_width, tensor_height); + glUniform3i(glGetUniformLocation(gl_compute_program_->id(), "out_size"), + tensor_width, tensor_height, tensor_in_channels); MP_RETURN_IF_ERROR(gl_compute_program_->Dispatch(workgroups)); @@ -481,8 +507,8 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { #else - if (!input_tensors[0].ready_as_opengl_texture_2d()) { - (void)input_tensors[0].GetCpuReadView(); + if (!input_tensor.ready_as_opengl_texture_2d()) { + (void)input_tensor.GetCpuReadView(); } auto output_texture = @@ -490,7 +516,7 @@ absl::Status TensorsToImageCalculator::GlProcess(CalculatorContext* cc) { gl_helper_.BindFramebuffer(output_texture); // GL_TEXTURE0 glActiveTexture(GL_TEXTURE1); glBindTexture(GL_TEXTURE_2D, - input_tensors[0].GetOpenGlTexture2dReadView().name()); + input_tensor.GetOpenGlTexture2dReadView().name()); MP_RETURN_IF_ERROR(gl_renderer_->GlRender( tensor_width, tensor_height, output_texture.width(), diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto index 6bca86265..b0ecb8b5a 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.proto @@ -48,4 +48,8 @@ message TensorsToImageCalculatorOptions { FloatRange input_tensor_float_range = 2; UIntRange input_tensor_uint_range = 3; } + + // Determines which output tensor to slice when there are multiple output + // tensors available (e.g. 
network has multiple heads)
+  optional int32 tensor_position = 4 [default = 0];
 }
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD
index f2afac494..1e24256d1 100644
--- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD
@@ -153,6 +153,11 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "hand_landmarks_connections",
+    hdrs = ["hand_landmarks_connections.h"],
+)
+
 # TODO: open source hand joints graph
 
 cc_library(
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h
new file mode 100644
index 000000000..510820294
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h
@@ -0,0 +1,54 @@
+/* Copyright 2023 The MediaPipe Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKS_CONNECTIONS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKS_CONNECTIONS_H_
+
+#include <array>
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace hand_landmarker {
+
+static constexpr std::array<std::array<int, 2>, 6> kHandPalmConnections{
+    {{0, 1}, {0, 5}, {9, 13}, {13, 17}, {5, 9}, {0, 17}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandThumbConnections{
+    {{1, 2}, {2, 3}, {3, 4}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandIndexFingerConnections{
+    {{5, 6}, {6, 7}, {7, 8}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandMiddleFingerConnections{
+    {{9, 10}, {10, 11}, {11, 12}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandRingFingerConnections{
+    {{13, 14}, {14, 15}, {15, 16}}};
+
+static constexpr std::array<std::array<int, 2>, 3> kHandPinkyFingerConnections{
+    {{17, 18}, {18, 19}, {19, 20}}};
+
+static constexpr std::array<std::array<int, 2>, 21> kHandConnections{
+    {{0, 1}, {0, 5}, {9, 13}, {13, 17}, {5, 9}, {0, 17}, {1, 2},
+     {2, 3}, {3, 4}, {5, 6}, {6, 7}, {7, 8}, {9, 10}, {10, 11},
+     {11, 12}, {13, 14}, {14, 15}, {15, 16}, {17, 18}, {18, 19}, {19, 20}}};
+
+}  // namespace hand_landmarker
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKS_CONNECTIONS_H_
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc
index 99faa1064..a251a0ffc 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc
@@ -16,6 +16,7 @@ limitations under the License.
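Each entry in these connection arrays is a pair of landmark indices forming an edge, following the same convention as the face connection arrays above. A small sketch of how a renderer might consume them; the point type and draw call are placeholders, not part of this change:

```cpp
#include <array>
#include <cstdio>

#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_connections.h"

// Placeholder point type and draw call standing in for a real renderer.
struct Point2D {
  float x;
  float y;
};

void DrawLine(const Point2D& a, const Point2D& b) {
  std::printf("line (%.2f, %.2f) -> (%.2f, %.2f)\n", a.x, a.y, b.x, b.y);
}

// Draws the full hand skeleton from 21 detected landmarks by walking the
// landmark-index pairs in kHandConnections. The same pattern applies to the
// per-finger arrays here and to the face connection arrays above.
void DrawHandSkeleton(const std::array<Point2D, 21>& landmarks) {
  using ::mediapipe::tasks::vision::hand_landmarker::kHandConnections;
  for (const auto& connection : kHandConnections) {
    DrawLine(landmarks[connection[0]], landmarks[connection[1]]);
  }
}
```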
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" #include +#include #include "absl/strings/str_format.h" #include "mediapipe/framework/api2/builder.h" @@ -41,6 +42,8 @@ constexpr char kConfidenceMasksTag[] = "CONFIDENCE_MASKS"; constexpr char kConfidenceMasksStreamName[] = "confidence_masks"; constexpr char kCategoryMaskTag[] = "CATEGORY_MASK"; constexpr char kCategoryMaskStreamName[] = "category_mask"; +constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kOutputSizeStreamName[] = "output_size"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -70,6 +73,7 @@ CalculatorGraphConfig CreateGraphConfig( options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); + graph.In(kOutputSizeTag).SetName(kOutputSizeStreamName); if (output_confidence_masks) { task_subgraph.Out(kConfidenceMasksTag) .SetName(kConfidenceMasksStreamName) >> @@ -85,10 +89,12 @@ CalculatorGraphConfig CreateGraphConfig( graph.Out(kImageTag); if (enable_flow_limiting) { return tasks::core::AddFlowLimiterCalculator( - graph, task_subgraph, {kImageTag, kNormRectTag}, kConfidenceMasksTag); + graph, task_subgraph, {kImageTag, kNormRectTag, kOutputSizeTag}, + kConfidenceMasksTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); + graph.In(kOutputSizeTag) >> task_subgraph.In(kOutputSizeTag); return graph.GetConfig(); } @@ -211,6 +217,13 @@ absl::StatusOr> ImageSegmenter::Create( absl::StatusOr ImageSegmenter::Segment( mediapipe::Image image, std::optional image_processing_options) { + return Segment(image, image.width(), image.height(), + std::move(image_processing_options)); +} + +absl::StatusOr ImageSegmenter::Segment( + mediapipe::Image image, int output_width, int output_height, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -225,7 +238,10 @@ absl::StatusOr ImageSegmenter::Segment( ProcessImageData( {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, {kNormRectStreamName, - MakePacket(std::move(norm_rect))}})); + MakePacket(std::move(norm_rect))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(output_width, output_height))}})); std::optional> confidence_masks; if (output_confidence_masks_) { confidence_masks = @@ -243,6 +259,14 @@ absl::StatusOr ImageSegmenter::Segment( absl::StatusOr ImageSegmenter::SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options) { + return SegmentForVideo(image, image.width(), image.height(), timestamp_ms, + image_processing_options); +} + +absl::StatusOr ImageSegmenter::SegmentForVideo( + mediapipe::Image image, int output_width, int output_height, + int64_t timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -260,6 +284,10 @@ absl::StatusOr ImageSegmenter::SegmentForVideo( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kNormRectStreamName, MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(output_width, output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); std::optional> confidence_masks; if (output_confidence_masks_) { @@ -278,6 +306,13 @@ absl::StatusOr 
ImageSegmenter::SegmentForVideo( absl::Status ImageSegmenter::SegmentAsync( Image image, int64_t timestamp_ms, std::optional image_processing_options) { + return SegmentAsync(image, image.width(), image.height(), timestamp_ms, + image_processing_options); +} + +absl::Status ImageSegmenter::SegmentAsync( + Image image, int output_width, int output_height, int64_t timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -293,6 +328,10 @@ absl::Status ImageSegmenter::SegmentAsync( .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, {kNormRectStreamName, MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kOutputSizeStreamName, + MakePacket>( + std::make_pair(output_width, output_height)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 0546cef3a..237603497 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -102,17 +102,36 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // // The image can be of any size with format RGB or RGBA. // + // The output size is the same as the input image size. + // // The optional 'image_processing_options' parameter can be used to specify // the rotation to apply to the image before performing segmentation, by // setting its 'rotation_degrees' field. Note that specifying a // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. - absl::StatusOr Segment( mediapipe::Image image, std::optional image_processing_options = std::nullopt); + // Performs image segmentation on the provided single image. + // Only use this method when the ImageSegmenter is created with the image + // running mode. + // + // The image can be of any size with format RGB or RGBA. + // + // The output width and height specify the size of the resulted mask. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + absl::StatusOr Segment( + mediapipe::Image image, int output_width, int output_height, + std::optional image_processing_options = + std::nullopt); + // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video // running mode. @@ -121,16 +140,39 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. // - // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing segmentation, by - // setting its 'rotation_degrees' field. Note that specifying a - // region-of-interest using the 'region_of_interest' field is NOT supported + // The output size is the same as the input image size. 
+ // + // The optional 'image_processing_options' parameter can be used + // to specify the rotation to apply to the image before performing + // segmentation, by setting its 'rotation_degrees' field. Note that specifying + // a region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. absl::StatusOr SegmentForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = std::nullopt); + // Performs image segmentation on the provided video frame. + // Only use this method when the ImageSegmenter is created with the video + // running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + // + // The output width and height specify the size of the resulted mask. + // + // The optional 'image_processing_options' parameter can be used + // to specify the rotation to apply to the image before performing + // segmentation, by setting its 'rotation_degrees' field. Note that specifying + // a region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + absl::StatusOr SegmentForVideo( + mediapipe::Image image, int output_width, int output_height, + int64_t timestamp_ms, + std::optional image_processing_options = + std::nullopt); + // Sends live image data to perform image segmentation, and the results will // be available via the "result_callback" provided in the // ImageSegmenterOptions. Only use this method when the ImageSegmenter is @@ -141,6 +183,8 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // sent to the image segmenter. The input timestamps must be monotonically // increasing. // + // The output size is the same as the input image size. + // // The optional 'image_processing_options' parameter can be used to specify // the rotation to apply to the image before performing segmentation, by // setting its 'rotation_degrees' field. Note that specifying a @@ -158,6 +202,36 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { std::optional image_processing_options = std::nullopt); + // Sends live image data to perform image segmentation, and the results will + // be available via the "result_callback" provided in the + // ImageSegmenterOptions. Only use this method when the ImageSegmenter is + // created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the image segmenter. The input timestamps must be monotonically + // increasing. + // + // The output width and height specify the size of the resulted mask. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // The "result_callback" prvoides + // - An ImageSegmenterResult. + // - The const reference to the corresponding input image that the image + // segmentation runs on. Note that the const reference to the image will + // no longer be valid when the callback returns. 
To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status SegmentAsync(mediapipe::Image image, int output_width, + int output_height, int64_t timestamp_ms, + std::optional + image_processing_options = std::nullopt); + // Shuts down the ImageSegmenter when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 0ae47ffd1..e80da0123 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -82,6 +82,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kSizeTag[] = "SIZE"; constexpr char kQualityScoresTag[] = "QUALITY_SCORES"; constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA"; @@ -356,6 +357,9 @@ absl::StatusOr ConvertImageToTensors( // Describes image rotation and region of image to perform detection // on. // @Optional: rect covering the whole image is used if not specified. +// OUTPUT_SIZE - std::pair @Optional +// The output size of the mask, in width and height. If not specified, the +// output size of the input image is used. // // Outputs: // CONFIDENCE_MASK - mediapipe::Image @Multiple @@ -400,11 +404,16 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { if (!options.segmenter_options().has_output_type()) { MP_RETURN_IF_ERROR(SanityCheck(sc)); } + std::optional>> output_size; + if (HasInput(sc->OriginalNode(), kOutputSizeTag)) { + output_size = graph.In(kOutputSizeTag).Cast>(); + } ASSIGN_OR_RETURN( auto output_streams, BuildSegmentationTask( options, *model_resources, graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + graph[Input::Optional(kNormRectTag)], output_size, + graph)); // TODO: remove deprecated output type support. if (options.segmenter_options().has_output_type()) { @@ -469,7 +478,8 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Source norm_rect_in, Graph& graph) { + Source norm_rect_in, + std::optional>> output_size, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -514,10 +524,14 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { image_and_tensors.tensors >> inference.In(kTensorsTag); inference.Out(kTensorsTag) >> tensor_to_images.In(kTensorsTag); - // Adds image property calculator for output size. - auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); - image_in >> image_properties.In("IMAGE"); - image_properties.Out("SIZE") >> tensor_to_images.In(kOutputSizeTag); + if (output_size.has_value()) { + *output_size >> tensor_to_images.In(kOutputSizeTag); + } else { + // Adds image property calculator for output size. + auto& image_properties = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_properties.In(kImageTag); + image_properties.Out(kSizeTag) >> tensor_to_images.In(kOutputSizeTag); + } // Exports multiple segmented masks. // TODO: remove deprecated output type support. 
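With the OUTPUT_SIZE stream and the new Segment overloads above, callers can request a mask size that differs from the input image. A rough usage sketch, assuming only the API surface visible in this patch plus the existing ImageSegmenterOptions fields; the model path is a placeholder:

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

namespace seg = ::mediapipe::tasks::vision::image_segmenter;

absl::Status RunFixedSizeSegmentation(mediapipe::Image image) {
  auto options = std::make_unique<seg::ImageSegmenterOptions>();
  options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder

  absl::StatusOr<std::unique_ptr<seg::ImageSegmenter>> segmenter =
      seg::ImageSegmenter::Create(std::move(options));
  if (!segmenter.ok()) return segmenter.status();

  // Request a 256x256 mask; this feeds the new OUTPUT_SIZE input stream added
  // to the segmenter graph above. Omitting the size falls back to the input
  // image dimensions, as before.
  absl::StatusOr<seg::ImageSegmenterResult> result =
      (*segmenter)->Segment(std::move(image), /*output_width=*/256,
                            /*output_height=*/256);
  if (!result.ok()) return result.status();

  return (*segmenter)->Close();
}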
diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index f97857ddc..241c89588 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -155,3 +155,8 @@ cc_library( "//mediapipe/tasks/cc/components/containers:landmark", ], ) + +cc_library( + name = "pose_landmarks_connections", + hdrs = ["pose_landmarks_connections.h"], +) diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h new file mode 100644 index 000000000..4b79215a4 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_connections.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The MediaPipe Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKS_CONNECTIONS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKS_CONNECTIONS_H_ + +#include + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_landmarker { + +static constexpr std::array, 34> kPoseLandmarksConnections{{ + {1, 2}, {0, 1}, {2, 3}, {3, 7}, {0, 4}, {4, 5}, {5, 6}, + {6, 8}, {9, 10}, {11, 12}, {11, 13}, {13, 15}, {15, 17}, {15, 19}, + {15, 21}, {17, 19}, {12, 14}, {14, 16}, {16, 18}, {16, 20}, {16, 22}, + {18, 20}, {11, 23}, {12, 24}, {23, 24}, {23, 25}, {24, 26}, {25, 27}, + {26, 28}, {27, 29}, {28, 30}, {29, 31}, {30, 32}, {27, 31}, +}}; + +} // namespace pose_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_POSE_LANDMARKER_POSE_LANDMARKS_CONNECTIONS_H_ diff --git a/mediapipe/tasks/ios/BUILD b/mediapipe/tasks/ios/BUILD index 29b0dd65f..14a409e72 100644 --- a/mediapipe/tasks/ios/BUILD +++ b/mediapipe/tasks/ios/BUILD @@ -66,7 +66,9 @@ strip_api_include_path_prefix( "//mediapipe/tasks/ios/components/containers:sources/MPPClassificationResult.h", "//mediapipe/tasks/ios/components/containers:sources/MPPEmbedding.h", "//mediapipe/tasks/ios/components/containers:sources/MPPEmbeddingResult.h", + "//mediapipe/tasks/ios/components/containers:sources/MPPConnection.h", "//mediapipe/tasks/ios/components/containers:sources/MPPDetection.h", + "//mediapipe/tasks/ios/components/containers:sources/MPPLandmark.h", "//mediapipe/tasks/ios/core:sources/MPPBaseOptions.h", "//mediapipe/tasks/ios/core:sources/MPPTaskOptions.h", "//mediapipe/tasks/ios/core:sources/MPPTaskResult.h", @@ -160,6 +162,8 @@ apple_static_xcframework( ":MPPCategory.h", ":MPPClassificationResult.h", ":MPPDetection.h", + ":MPPLandmark.h", + ":MPPConnection.h", ":MPPCommon.h", ":MPPTaskOptions.h", ":MPPTaskResult.h", diff --git a/mediapipe/tasks/ios/test/vision/face_detector/MPPFaceDetectorTests.mm b/mediapipe/tasks/ios/test/vision/face_detector/MPPFaceDetectorTests.mm index 1976bf603..548c4bdbf 100644 --- 
a/mediapipe/tasks/ios/test/vision/face_detector/MPPFaceDetectorTests.mm +++ b/mediapipe/tasks/ios/test/vision/face_detector/MPPFaceDetectorTests.mm @@ -25,7 +25,7 @@ static NSDictionary *const kPortraitImage = @{@"name" : @"portrait", @"type" : @"jpg", @"orientation" : @(UIImageOrientationUp)}; static NSDictionary *const kPortraitRotatedImage = - @{@"name" : @"portrait_rotated", @"type" : @"jpg", @"orientation" : @(UIImageOrientationRight)}; + @{@"name" : @"portrait_rotated", @"type" : @"jpg", @"orientation" : @(UIImageOrientationLeft)}; static NSDictionary *const kCatImage = @{@"name" : @"cat", @"type" : @"jpg"}; static NSString *const kShortRangeBlazeFaceModel = @"face_detection_short_range"; static NSArray *const kPortraitExpectedKeypoints = @[ diff --git a/mediapipe/tasks/ios/test/vision/gesture_recognizer/MPPGestureRecognizerTests.m b/mediapipe/tasks/ios/test/vision/gesture_recognizer/MPPGestureRecognizerTests.m index 8fbcb6b49..4b4eceed6 100644 --- a/mediapipe/tasks/ios/test/vision/gesture_recognizer/MPPGestureRecognizerTests.m +++ b/mediapipe/tasks/ios/test/vision/gesture_recognizer/MPPGestureRecognizerTests.m @@ -343,7 +343,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; MPPGestureRecognizer *gestureRecognizer = [self createGestureRecognizerWithOptionsSucceeds:gestureRecognizerOptions]; MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage - orientation:UIImageOrientationRight]; + orientation:UIImageOrientationLeft]; MPPGestureRecognizerResult *gestureRecognizerResult = [gestureRecognizer recognizeImage:mppImage error:nil]; diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m index c08976923..e1bd9f6c3 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m +++ b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m @@ -402,7 +402,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; ]; MPPImage *image = [self imageWithFileInfo:kBurgerRotatedImage - orientation:UIImageOrientationRight]; + orientation:UIImageOrientationLeft]; [self assertResultsOfClassifyImage:image usingImageClassifier:imageClassifier @@ -425,7 +425,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; displayName:nil] ]; MPPImage *image = [self imageWithFileInfo:kMultiObjectsRotatedImage - orientation:UIImageOrientationRight]; + orientation:UIImageOrientationLeft]; // roi around folding chair MPPImageClassifierResult *imageClassifierResult = diff --git a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m index 2ef5a0957..079682df1 100644 --- a/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m +++ b/mediapipe/tasks/ios/test/vision/object_detector/MPPObjectDetectorTests.m @@ -438,7 +438,7 @@ static NSString *const kLiveStreamTestsDictExpectationKey = @"expectation"; [[MPPObjectDetectorResult alloc] initWithDetections:detections timestampInMilliseconds:0]; MPPImage *image = [self imageWithFileInfo:kCatsAndDogsRotatedImage - orientation:UIImageOrientationRight]; + orientation:UIImageOrientationLeft]; [self assertResultsOfDetectInImage:image usingObjectDetector:objectDetector diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPImage.h b/mediapipe/tasks/ios/vision/core/sources/MPPImage.h index deffc97e2..847efc331 100644 --- 
a/mediapipe/tasks/ios/vision/core/sources/MPPImage.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPImage.h @@ -62,10 +62,10 @@ NS_SWIFT_NAME(MPImage) /** * Initializes an `MPPImage` object with the given `UIImage`. - * The orientation of the newly created `MPPImage` will be `UIImageOrientationUp`. - * Hence, if this image is used as input for any MediaPipe vision tasks, inference will be - * performed on the it without any rotation. To create an `MPPImage` with a different orientation, - * please use `[MPPImage initWithImage:orientation:error:]`. + * The orientation of the newly created `MPPImage` will be equal to the `imageOrientation` of + * `UIImage` and when sent to the vision tasks for inference, rotation will be applied accordingly. + * To create an `MPPImage` with an orientation different from its `imageOrientation`, please use + * `[MPPImage initWithImage:orientation:error:]`. * * @param image The image to use as the source. Its `CGImage` property must not be `NULL`. * @param error An optional error parameter populated when there is an error in initializing the @@ -77,14 +77,19 @@ NS_SWIFT_NAME(MPImage) - (nullable instancetype)initWithUIImage:(UIImage *)image error:(NSError **)error; /** - * Initializes an `MPPImage` object with the given `UIImabe` and orientation. + * Initializes an `MPPImage` object with the given `UIImage` and orientation. The given orientation + * will be used to calculate the rotation to be applied to the `UIImage` before inference is + * performed on it by the vision tasks. The `imageOrientation` stored in the `UIImage` is ignored + * when `MPImage` objects created by this method are sent to the vision tasks for inference. Use + * `[MPPImage initWithImage:orientation:error:]` to initialize images with the `imageOrientation` of + * `UIImage`. * * If the newly created `MPPImage` is used as input for any MediaPipe vision tasks, inference * will be performed on a copy of the image rotated according to the orientation. * * @param image The image to use as the source. Its `CGImage` property must not be `NULL`. * @param orientation The display orientation of the image. This will be stored in the property - * `orientation`. `MPPImage`. + * `orientation` `MPPImage` and will override the `imageOrientation` of the passed in `UIImage`. * @param error An optional error parameter populated when there is an error in initializing the * `MPPImage`. * diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm index cba8a63ff..ae5e1d64c 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm @@ -30,13 +30,13 @@ using ::mediapipe::tasks::core::PacketsCallback; } // namespace /** Rotation degrees for a 90 degree rotation to the right. */ -static const NSInteger kMPPOrientationDegreesRight = -90; +static const NSInteger kMPPOrientationDegreesRight = -270; /** Rotation degrees for a 180 degree rotation. */ static const NSInteger kMPPOrientationDegreesDown = -180; /** Rotation degrees for a 90 degree rotation to the left. 
*/ -static const NSInteger kMPPOrientationDegreesLeft = -270; +static const NSInteger kMPPOrientationDegreesLeft = -90; static NSString *const kTaskPrefix = @"com.mediapipe.tasks.vision"; diff --git a/mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h b/mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h index 23b423ad0..34284859f 100644 --- a/mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h +++ b/mediapipe/tasks/ios/vision/face_landmarker/sources/MPPFaceLandmarkerOptions.h @@ -30,7 +30,7 @@ NS_ASSUME_NONNULL_BEGIN * The delegate of `MPPFaceLandmarker` must adopt `MPPFaceLandmarkerLiveStreamDelegate` protocol. * The methods in this protocol are optional. */ -NS_SWIFT_NAME(FaceDetectorLiveStreamDelegate) +NS_SWIFT_NAME(FaceLandmarkerLiveStreamDelegate) @protocol MPPFaceLandmarkerLiveStreamDelegate /** diff --git a/mediapipe/tasks/ios/vision/image_segmenter/BUILD b/mediapipe/tasks/ios/vision/image_segmenter/BUILD index a0ebac2ae..54031f248 100644 --- a/mediapipe/tasks/ios/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/ios/vision/image_segmenter/BUILD @@ -35,3 +35,13 @@ objc_library( "//mediapipe/tasks/ios/vision/core:MPPRunningMode", ], ) + +objc_library( + name = "MPPImageSegmenter", + hdrs = ["sources/MPPImageSegmenterOptions.h"], + deps = [ + ":MPPImageSegmenterOptions", + ":MPPImageSegmenterResult", + "//mediapipe/tasks/ios/vision/core:MPPImage", + ], +) diff --git a/mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenter.h b/mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenter.h new file mode 100644 index 000000000..819b20129 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenter.h @@ -0,0 +1,217 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" +#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterOptions.h" +#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Class that performs segmentation on images. + * + * The API expects a TFLite model with mandatory TFLite Model Metadata. + */ +NS_SWIFT_NAME(ImageSegmenter) +@interface MPPImageSegmenter : NSObject + +/** + * Creates a new instance of `MPPImageSegmenter` from an absolute path to a TensorFlow Lite model + * file stored locally on the device and the default `MPPImageSegmenterOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * image segmenter. + * + * @return A new instance of `MPPImageSegmenter` with the given model path. `nil` if there is an + * error in initializing the image segmenter. 
+ */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPImageSegmenter` from the given `MPPImageSegmenterOptions`. + * + * @param options The options of type `MPPImageSegmenterOptions` to use for configuring the + * `MPPImageSegmenter`. + * @param error An optional error parameter populated when there is an error in initializing the + * image segmenter. + * + * @return A new instance of `MPPImageSegmenter` with the given options. `nil` if there is an error + * in initializing the image segmenter. + */ +- (nullable instancetype)initWithOptions:(MPPImageSegmenterOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs segmentation on the provided MPPImage using the whole image as region of interest. + * Rotation will be applied according to the `orientation` property of the provided `MPPImage`. Only + * use this method when the `MPPImageSegmenter` is created with `MPPRunningModeImage`. + * + * This method supports RGBA images. If your `MPPImage` has a source type of + * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer + * must have one of the following pixel format types: + * 1. kCVPixelFormatType_32BGRA + * 2. kCVPixelFormatType_32RGBA + * + * If your `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color space is + * RGB with an Alpha channel. + * + * @param image The `MPPImage` on which segmentation is to be performed. + * @param error An optional error parameter populated when there is an error in performing + * segmentation on the input image. + * + * @return An `MPPImageSegmenterResult` that contains the segmented masks. + */ +- (nullable MPPImageSegmenterResult *)segmentImage:(MPPImage *)image + error:(NSError *)error NS_SWIFT_NAME(segment(image:)); + +/** + * Performs segmentation on the provided MPPImage using the whole image as region of interest and + * invokes the given completion handler block with the response. The method returns synchronously + * once the completion handler returns. + * + * Rotation will be applied according to the `orientation` property of the provided + * `MPPImage`. Only use this method when the `MPPImageSegmenter` is created with + * `MPPRunningModeImage`. + * + * This method supports RGBA images. If your `MPPImage` has a source type of + * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer + * must have one of the following pixel format types: + * 1. kCVPixelFormatType_32BGRA + * 2. kCVPixelFormatType_32RGBA + * + * If your `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color space is + * RGB with an Alpha channel. + * + * @param image The `MPPImage` on which segmentation is to be performed. + * @param completionHandler A block to be invoked with the results of performing segmentation on the + * image. The block takes two arguments, the optional `MPPImageSegmenterResult` that contains the + * segmented masks if the segmentation was successful and an optional error populated upon failure. + * The lifetime of the returned masks is only guaranteed for the duration of the block. 
+ */ +- (void)segmentImage:(MPPImage *)image + withCompletionHandler:(void (^)(MPPImageSegmenterResult *_Nullable result, + NSError *_Nullable error))completionHandler + NS_SWIFT_NAME(segment(image:completion:)); + +/** + * Performs segmentation on the provided video frame of type `MPPImage` using the whole image as + * region of interest. + * + * Rotation will be applied according to the `orientation` property of the provided `MPPImage`. Only + * use this method when the `MPPImageSegmenter` is created with `MPPRunningModeVideo`. + * + * This method supports RGBA images. If your `MPPImage` has a source type of + * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer + * must have one of the following pixel format types: + * 1. kCVPixelFormatType_32BGRA + * 2. kCVPixelFormatType_32RGBA + * + * If your `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color space is + * RGB with an Alpha channel. + * + * @param image The `MPPImage` on which segmentation is to be performed. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing + * segmentation on the input image. + * + * @return An `MPPImageSegmenterResult` that contains the segmented masks. + */ +- (nullable MPPImageSegmenterResult *)segmentVideoFrame:(MPPImage *)image + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(segment(videoFrame:timestampInMilliseconds:)); + +/** + * Performs segmentation on the provided video frame of type `MPPImage` using the whole image as + * region of interest and invokes the given completion handler block with the response. The method + * returns synchronously once the completion handler returns. + * + * Rotation will be applied according to the `orientation` property of the provided `MPPImage`. Only + * use this method when the `MPPImageSegmenter` is created with `MPPRunningModeVideo`. + * + * This method supports RGBA images. If your `MPPImage` has a source type of + * `MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer + * must have one of the following pixel format types: + * 1. kCVPixelFormatType_32BGRA + * 2. kCVPixelFormatType_32RGBA + * + * If your `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color space is + * RGB with an Alpha channel. + * + * @param image The `MPPImage` on which segmentation is to be performed. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. + * @param completionHandler A block to be invoked with the results of performing segmentation on the + * image. The block takes two arguments, the optional `MPPImageSegmenterResult` that contains the + * segmented masks if the segmentation was successful and an optional error only populated upon + * failure. The lifetime of the returned masks is only guaranteed for the duration of the block. 
+ */ +- (void)segmentVideoFrame:(MPPImage *)image + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + withCompletionHandler:(void (^)(MPPImageSegmenterResult *_Nullable result, + NSError *_Nullable error))completionHandler + NS_SWIFT_NAME(segment(videoFrame:timestampInMilliseconds:completion:)); + +/** + * Sends live stream image data of type `MPPImage` to perform segmentation using the whole image as + * region of interest. + * + * Rotation will be applied according to the `orientation` property of the provided `MPPImage`. Only + * use this method when the `MPPImageSegmenter` is created with `MPPRunningModeLiveStream`. + * + * The object which needs to be continuously notified of the available results of image segmentation + * must conform to `MPPImageSegmenterLiveStreamDelegate` protocol and implement the + *`imageSegmenter:didFinishSegmentationWithResult:timestampInMilliseconds:error:` delegate method. + * + * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent + * to the segmenter. The input timestamps must be monotonically increasing. + * + * This method supports RGBA images. If your `MPPImage` has a source type of + *`MPPImageSourceTypePixelBuffer` or `MPPImageSourceTypeSampleBuffer`, the underlying pixel buffer + * must have one of the following pixel format types: + * 1. kCVPixelFormatType_32BGRA + * 2. kCVPixelFormatType_32RGBA + * + * If the input `MPPImage` has a source type of `MPPImageSourceTypeImage` ensure that the color + * space is RGB with an Alpha channel. + * + * If this method is used for segmenting live camera frames using `AVFoundation`, ensure that you + * request `AVCaptureVideoDataOutput` to output frames in `kCMPixelFormat_32RGBA` using its + * `videoSettings` property. + * + * @param image Live stream image data of type `MPPImage` on which segmentation is to be + * performed. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the segmenter. The input timestamps must be monotonically increasing. + * @param error An optional error parameter populated when there is an error when sending the input + * image to the graph. + * + * @return `YES` if the image was sent to the task successfully, otherwise `NO`. + */ +- (BOOL)segmentAsyncInImage:(MPPImage *)image + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(segmentAsync(image:timestampInMilliseconds:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_segmenter/utils/BUILD b/mediapipe/tasks/ios/vision/image_segmenter/utils/BUILD new file mode 100644 index 000000000..7630dd7e6 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_segmenter/utils/BUILD @@ -0,0 +1,42 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
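The segmentAsyncInImage: method above is the iOS live-stream entry point; the underlying C++ task exposes the same flow through the SegmentAsync overload introduced earlier in this patch. A hedged C++ sketch of that flow follows; running_mode and result_callback come from the existing C++ ImageSegmenterOptions rather than this patch, and the model path is a placeholder:

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

namespace seg = ::mediapipe::tasks::vision::image_segmenter;

absl::StatusOr<std::unique_ptr<seg::ImageSegmenter>> CreateLiveStreamSegmenter() {
  auto options = std::make_unique<seg::ImageSegmenterOptions>();
  options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
  options->running_mode = mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  options->result_callback = [](absl::StatusOr<seg::ImageSegmenterResult> result,
                                const mediapipe::Image& input, int64_t timestamp_ms) {
    if (!result.ok()) return;
    // Consume the masks here; copy anything that must outlive the callback,
    // matching the lifetime contract documented for the Objective-C API above.
  };
  return seg::ImageSegmenter::Create(std::move(options));
}

// Called once per camera frame with a monotonically increasing timestamp (ms);
// requests a fixed 256x256 mask via the overload introduced in this patch.
absl::Status FeedFrame(seg::ImageSegmenter& segmenter, mediapipe::Image frame,
                       int64_t timestamp_ms) {
  return segmenter.SegmentAsync(std::move(frame), /*output_width=*/256,
                                /*output_height=*/256, timestamp_ms);
}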
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPImageSegmenterOptionsHelpers", + srcs = ["sources/MPPImageSegmenterOptions+Helpers.mm"], + hdrs = ["sources/MPPImageSegmenterOptions+Helpers.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterOptions", + ], +) + +objc_library( + name = "MPPImageSegmenterResultHelpers", + srcs = ["sources/MPPImageSegmenterResult+Helpers.mm"], + hdrs = ["sources/MPPImageSegmenterResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:image", + "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenterResult", + ], +) diff --git a/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.h b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.h new file mode 100644 index 000000000..4d3b222f8 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.h @@ -0,0 +1,32 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPImageSegmenterOptions (Helpers) + +/** + * Populates the provided `CalculatorOptions` proto container with the current settings. + * + * @param optionsProto The `CalculatorOptions` proto object to copy the settings to. + */ +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.mm b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.mm new file mode 100644 index 000000000..d27bb91d7 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.mm @@ -0,0 +1,41 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using ImageSegmenterGraphOptionsProto = + ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterGraphOptions; +using SegmenterOptionsProto = ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +} // namespace + +@implementation MPPImageSegmenterOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + ImageSegmenterGraphOptionsProto *imageSegmenterGraphOptionsProto = + optionsProto->MutableExtension(ImageSegmenterGraphOptionsProto::ext); + imageSegmenterGraphOptionsProto->Clear(); + + [self.baseOptions copyToProto:imageSegmenterGraphOptionsProto->mutable_base_options() + withUseStreamMode:self.runningMode != MPPRunningModeImage]; + imageSegmenterGraphOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); +} + +@end diff --git a/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h new file mode 100644 index 000000000..503fcd1d7 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h @@ -0,0 +1,48 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_segmenter/sources/MPPImageSegmenterResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPImageSegmenterResult (Helpers) + +/** + * Creates an `MPPImageSegmenterResult` from confidence masks, category mask and quality scores + * packets. + * + * If `shouldCopyMaskPacketData` is set to `YES`, the confidence and category masks of the newly + * created `MPPImageSegmenterResult` hold references to deep-copied pixel data of the respective + * output masks. + * + * @param confidenceMasksPacket A MediaPipe packet wrapping a `std::vector<mediapipe::Image>`. + * @param categoryMaskPacket A MediaPipe packet wrapping a `mediapipe::Image`. + * @param qualityScoresPacket A MediaPipe packet wrapping a `std::vector<float>`. + * @param shouldCopyMaskPacketData A `BOOL` which indicates if the pixel data of the output masks + * must be deep copied to the newly created `MPPImageSegmenterResult`. + * + * @return An `MPPImageSegmenterResult` object that contains the image segmentation results. 
+ */ ++ (MPPImageSegmenterResult *) + imageSegmenterResultWithConfidenceMasksPacket:(const mediapipe::Packet &)confidenceMasksPacket + categoryMaskPacket:(const mediapipe::Packet &)categoryMaskPacket + qualityScoresPacket:(const mediapipe::Packet &)qualityScoresPacket + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + shouldCopyMaskPacketData:(BOOL)shouldCopyMaskPacketData; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm new file mode 100644 index 000000000..d6e3b1be8 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.mm @@ -0,0 +1,78 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_segmenter/utils/sources/MPPImageSegmenterResult+Helpers.h" + +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/packet.h" + +namespace { +using ::mediapipe::Image; +using ::mediapipe::ImageFrameSharedPtr; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPImageSegmenterResult (Helpers) + ++ (MPPImageSegmenterResult *) + imageSegmenterResultWithConfidenceMasksPacket:(const Packet &)confidenceMasksPacket + categoryMaskPacket:(const Packet &)categoryMaskPacket + qualityScoresPacket:(const Packet &)qualityScoresPacket + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + shouldCopyMaskPacketData:(BOOL)shouldCopyMaskPacketData { + NSMutableArray<MPPMask *> *confidenceMasks; + MPPMask *categoryMask; + NSMutableArray<NSNumber *> *qualityScores; + + if (confidenceMasksPacket.ValidateAsType<std::vector<Image>>().ok()) { + std::vector<Image> cppConfidenceMasks = confidenceMasksPacket.Get<std::vector<Image>>(); + confidenceMasks = [NSMutableArray arrayWithCapacity:(NSUInteger)cppConfidenceMasks.size()]; + + for (const auto &confidenceMask : cppConfidenceMasks) { + [confidenceMasks + addObject:[[MPPMask alloc] + initWithFloat32Data:(float *)confidenceMask.GetImageFrameSharedPtr() + .get() + ->PixelData() + width:confidenceMask.width() + height:confidenceMask.height() + shouldCopy:shouldCopyMaskPacketData ? YES : NO]]; + } + } + + if (categoryMaskPacket.ValidateAsType<Image>().ok()) { + const Image &cppCategoryMask = categoryMaskPacket.Get<Image>(); + categoryMask = [[MPPMask alloc] + initWithUInt8Data:(UInt8 *)cppCategoryMask.GetImageFrameSharedPtr().get()->PixelData() + width:cppCategoryMask.width() + height:cppCategoryMask.height() + shouldCopy:shouldCopyMaskPacketData ? 
YES : NO]; + } + + if (qualityScoresPacket.ValidateAsType>().ok()) { + std::vector cppQualityScores = qualityScoresPacket.Get>(); + qualityScores = [NSMutableArray arrayWithCapacity:(NSUInteger)cppQualityScores.size()]; + + for (const auto &qualityScore : cppQualityScores) { + [qualityScores addObject:[NSNumber numberWithFloat:qualityScore]]; + } + } + + return [[MPPImageSegmenterResult alloc] initWithConfidenceMasks:confidenceMasks + categoryMask:categoryMask + qualityScores:qualityScores + timestampInMilliseconds:timestampInMilliseconds]; +} + +@end diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 07106985d..bcdc0e5e5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -92,6 +92,9 @@ android_library( android_library( name = "landmark", srcs = ["Landmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ "//third_party:autovalue", "@maven//:com_google_guava_guava", @@ -101,6 +104,9 @@ android_library( android_library( name = "normalized_landmark", srcs = ["NormalizedLandmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index c3e9f2715..e23d9115d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; import java.util.Objects; +import java.util.Optional; /** * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in @@ -27,7 +28,12 @@ public abstract class Landmark { private static final float TOLERANCE = 1e-6f; public static Landmark create(float x, float y, float z) { - return new AutoValue_Landmark(x, y, z); + return new AutoValue_Landmark(x, y, z, Optional.empty(), Optional.empty()); + } + + public static Landmark create( + float x, float y, float z, Optional visibility, Optional presence) { + return new AutoValue_Landmark(x, y, z, visibility, presence); } // The x coordinates of the landmark. @@ -39,6 +45,12 @@ public abstract class Landmark { // The z coordinates of the landmark. public abstract float z(); + // Visibility of the normalized landmark. + public abstract Optional visibility(); + + // Presence of the normalized landmark. 
+ public abstract Optional presence(); + @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { @@ -57,6 +69,16 @@ public abstract class Landmark { @Override public final String toString() { - return ""; + return ""; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java index f96e434ca..50a95d565 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; import java.util.Objects; +import java.util.Optional; /** * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are @@ -28,7 +29,12 @@ public abstract class NormalizedLandmark { private static final float TOLERANCE = 1e-6f; public static NormalizedLandmark create(float x, float y, float z) { - return new AutoValue_NormalizedLandmark(x, y, z); + return new AutoValue_NormalizedLandmark(x, y, z, Optional.empty(), Optional.empty()); + } + + public static NormalizedLandmark create( + float x, float y, float z, Optional visibility, Optional presence) { + return new AutoValue_NormalizedLandmark(x, y, z, visibility, presence); } // The x coordinates of the normalized landmark. @@ -40,6 +46,12 @@ public abstract class NormalizedLandmark { // The z coordinates of the normalized landmark. public abstract float z(); + // Visibility of the normalized landmark. + public abstract Optional visibility(); + + // Presence of the normalized landmark. + public abstract Optional presence(); + @Override public final boolean equals(Object o) { if (!(o instanceof NormalizedLandmark)) { @@ -58,6 +70,16 @@ public abstract class NormalizedLandmark { @Override public final String toString() { - return ""; + return ""; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index d04fc4258..eb658c0e2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -32,6 +32,7 @@ android_library( "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + "//third_party:any_java_proto", "//third_party:autovalue", "@com_google_protobuf//:protobuf_javalite", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java index 8eec72ef9..dc2c001ba 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java @@ -54,6 +54,9 @@ public abstract class BaseOptions { */ public abstract Builder setDelegate(Delegate delegate); + /** Options for the chosen delegate. If not set, the default delegate options is used. 
*/ + public abstract Builder setDelegateOptions(DelegateOptions delegateOptions); + abstract BaseOptions autoBuild(); /** @@ -79,6 +82,23 @@ public abstract class BaseOptions { throw new IllegalArgumentException( "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); } + boolean delegateMatchesDelegateOptions = true; + if (options.delegateOptions().isPresent()) { + switch (options.delegate()) { + case CPU: + delegateMatchesDelegateOptions = + options.delegateOptions().get() instanceof DelegateOptions.CpuOptions; + break; + case GPU: + delegateMatchesDelegateOptions = + options.delegateOptions().get() instanceof DelegateOptions.GpuOptions; + break; + } + if (!delegateMatchesDelegateOptions) { + throw new IllegalArgumentException( + "Specified Delegate type does not match the provided delegate options."); + } + } return options; } } @@ -91,6 +111,67 @@ public abstract class BaseOptions { abstract Delegate delegate(); + abstract Optional delegateOptions(); + + /** Advanced config options for the used delegate. */ + public abstract static class DelegateOptions { + + /** Options for CPU. */ + @AutoValue + public abstract static class CpuOptions extends DelegateOptions { + + public static Builder builder() { + Builder builder = new AutoValue_BaseOptions_DelegateOptions_CpuOptions.Builder(); + return builder; + } + + /** Builder for {@link CpuOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + public abstract CpuOptions build(); + } + } + + /** Options for GPU. */ + @AutoValue + public abstract static class GpuOptions extends DelegateOptions { + // Load pre-compiled serialized binary cache to accelerate init process. + // Only available on Android. Kernel caching will only be enabled if this + // path is set. NOTE: binary cache usage may be skipped if valid serialized + // model, specified by "serialized_model_dir", exists. + abstract Optional cachedKernelPath(); + + // A dir to load from and save to a pre-compiled serialized model used to + // accelerate init process. + // NOTE: serialized model takes precedence over binary cache + // specified by "cached_kernel_path", which still can be used if + // serialized model is invalid or missing. + abstract Optional serializedModelDir(); + + // Unique token identifying the model. Used in conjunction with + // "serialized_model_dir". It is the caller's responsibility to ensure + // there is no clash of the tokens. + abstract Optional modelToken(); + + public static Builder builder() { + return new AutoValue_BaseOptions_DelegateOptions_GpuOptions.Builder(); + } + + /** Builder for {@link GpuOptions}. 
*/ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setCachedKernelPath(String cachedKernelPath); + + public abstract Builder setSerializedModelDir(String serializedModelDir); + + public abstract Builder setModelToken(String modelToken); + + public abstract GpuOptions build(); + } + } + } + public static Builder builder() { return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java index 3c422a8b2..ad3d01119 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -20,6 +20,8 @@ import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node; import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo; import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.protobuf.Any; import java.util.ArrayList; import java.util.List; @@ -110,10 +112,21 @@ public abstract class TaskInfo { */ CalculatorGraphConfig generateGraphConfig() { CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder(); - Node.Builder taskSubgraphBuilder = - Node.newBuilder() - .setCalculator(taskGraphName()) - .setOptions(taskOptions().convertToCalculatorOptionsProto()); + CalculatorOptions options = taskOptions().convertToCalculatorOptionsProto(); + Any anyOptions = taskOptions().convertToAnyProto(); + if (!(options == null ^ anyOptions == null)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Only one of convertTo*Proto() method should be implemented for " + + taskOptions().getClass()); + } + Node.Builder taskSubgraphBuilder = Node.newBuilder().setCalculator(taskGraphName()); + if (options != null) { + taskSubgraphBuilder.setOptions(options); + } + if (anyOptions != null) { + taskSubgraphBuilder.addNodeOptions(anyOptions); + } for (String outputStream : outputStreams()) { taskSubgraphBuilder.addOutputStream(outputStream); graphBuilder.addOutputStream(outputStream); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java index 11330ac0f..4ca258429 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java @@ -20,18 +20,26 @@ import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.tasks.core.proto.AccelerationProto; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.core.proto.ExternalFileProto; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; /** * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend - * {@link TaskOptions}. + * {@link TaskOptions} and implement exactly one of converTo*Proto() methods. */ public abstract class TaskOptions { /** * Converts a MediaPipe Tasks task-specific options to a {@link CalculatorOptions} protobuf * message. 
*/ - public abstract CalculatorOptions convertToCalculatorOptionsProto(); + public CalculatorOptions convertToCalculatorOptionsProto() { + return null; + } + + /** Converts a MediaPipe Tasks task-specific options to an proto3 {@link Any} message. */ + public Any convertToAnyProto() { + return null; + } /** * Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf @@ -61,17 +69,51 @@ public abstract class TaskOptions { accelerationBuilder.setTflite( InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite .getDefaultInstance()); + options + .delegateOptions() + .ifPresent( + delegateOptions -> + setDelegateOptions( + accelerationBuilder, + (BaseOptions.DelegateOptions.CpuOptions) delegateOptions)); break; case GPU: accelerationBuilder.setGpu( InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder() .setUseAdvancedGpuApi(true) .build()); + options + .delegateOptions() + .ifPresent( + delegateOptions -> + setDelegateOptions( + accelerationBuilder, + (BaseOptions.DelegateOptions.GpuOptions) delegateOptions)); break; } + return BaseOptionsProto.BaseOptions.newBuilder() .setModelAsset(externalFileBuilder.build()) .setAcceleration(accelerationBuilder.build()) .build(); } + + private void setDelegateOptions( + AccelerationProto.Acceleration.Builder accelerationBuilder, + BaseOptions.DelegateOptions.CpuOptions options) { + accelerationBuilder.setTflite( + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.TfLite.getDefaultInstance()); + } + + private void setDelegateOptions( + AccelerationProto.Acceleration.Builder accelerationBuilder, + BaseOptions.DelegateOptions.GpuOptions options) { + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.Builder gpuBuilder = + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.newBuilder() + .setUseAdvancedGpuApi(true); + options.cachedKernelPath().ifPresent(gpuBuilder::setCachedKernelPath); + options.serializedModelDir().ifPresent(gpuBuilder::setSerializedModelDir); + options.modelToken().ifPresent(gpuBuilder::setModelToken); + accelerationBuilder.setGpu(gpuBuilder.build()); + } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java index c91477e10..0429ecacb 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarkerResult.java @@ -53,7 +53,15 @@ public abstract class FaceLandmarkerResult implements TaskResult { faceLandmarksProto.getLandmarkList()) { faceLandmarks.add( NormalizedLandmark.create( - faceLandmarkProto.getX(), faceLandmarkProto.getY(), faceLandmarkProto.getZ())); + faceLandmarkProto.getX(), + faceLandmarkProto.getY(), + faceLandmarkProto.getZ(), + faceLandmarkProto.hasVisibility() + ? Optional.of(faceLandmarkProto.getVisibility()) + : Optional.empty(), + faceLandmarkProto.hasPresence() + ? 
Optional.of(faceLandmarkProto.getPresence()) + : Optional.empty())); } } Optional>> multiFaceBlendshapes = Optional.empty(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 467e871b2..b8b236d42 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -25,6 +25,7 @@ import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; /** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */ @AutoValue @@ -53,7 +54,15 @@ public abstract class HandLandmarkerResult implements TaskResult { handLandmarksProto.getLandmarkList()) { handLandmarks.add( NormalizedLandmark.create( - handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); + handLandmarkProto.getX(), + handLandmarkProto.getY(), + handLandmarkProto.getZ(), + handLandmarkProto.hasVisibility() + ? Optional.of(handLandmarkProto.getVisibility()) + : Optional.empty(), + handLandmarkProto.hasPresence() + ? Optional.of(handLandmarkProto.getPresence()) + : Optional.empty())); } } for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { @@ -65,7 +74,13 @@ public abstract class HandLandmarkerResult implements TaskResult { com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ())); + handWorldLandmarkProto.getZ(), + handWorldLandmarkProto.hasVisibility() + ? Optional.of(handWorldLandmarkProto.getVisibility()) + : Optional.empty(), + handWorldLandmarkProto.hasPresence() + ? Optional.of(handWorldLandmarkProto.getPresence()) + : Optional.empty())); } } for (ClassificationList handednessProto : handednessesProto) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java index 389e78266..0dde56700 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java @@ -58,7 +58,15 @@ public abstract class PoseLandmarkerResult implements TaskResult { poseLandmarksProto.getLandmarkList()) { poseLandmarks.add( NormalizedLandmark.create( - poseLandmarkProto.getX(), poseLandmarkProto.getY(), poseLandmarkProto.getZ())); + poseLandmarkProto.getX(), + poseLandmarkProto.getY(), + poseLandmarkProto.getZ(), + poseLandmarkProto.hasVisibility() + ? Optional.of(poseLandmarkProto.getVisibility()) + : Optional.empty(), + poseLandmarkProto.hasPresence() + ? Optional.of(poseLandmarkProto.getPresence()) + : Optional.empty())); } } for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { @@ -70,7 +78,13 @@ public abstract class PoseLandmarkerResult implements TaskResult { Landmark.create( poseWorldLandmarkProto.getX(), poseWorldLandmarkProto.getY(), - poseWorldLandmarkProto.getZ())); + poseWorldLandmarkProto.getZ(), + poseWorldLandmarkProto.hasVisibility() + ? Optional.of(poseWorldLandmarkProto.getVisibility()) + : Optional.empty(), + poseWorldLandmarkProto.hasPresence() + ? 
Optional.of(poseWorldLandmarkProto.getPresence()) + : Optional.empty())); } } return new AutoValue_PoseLandmarkerResult( diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/AndroidManifest.xml new file mode 100644 index 000000000..26310fc18 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD index 01e7ad0fa..ce7435d69 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD @@ -23,3 +23,5 @@ android_library( "//third_party/java/android_libs/guava_jdk5:io", ], ) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BaseOptionsTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BaseOptionsTest.java new file mode 100644 index 000000000..939ecb407 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BaseOptionsTest.java @@ -0,0 +1,159 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
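Taken together, the BaseOptions changes above give Java callers a per-delegate options hook. A minimal usage sketch of the new API, with a placeholder model path, cache directory, and token (the test file that follows exercises the same builders):

```java
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.Delegate;

final class DelegateOptionsExample {
  static BaseOptions gpuBaseOptions() {
    // "my_model.tflite" and the cache paths below are placeholders, not part of this patch.
    return BaseOptions.builder()
        .setModelAssetPath("my_model.tflite")
        .setDelegate(Delegate.GPU)
        .setDelegateOptions(
            BaseOptions.DelegateOptions.GpuOptions.builder()
                // All three fields are optional; kernel caching is only enabled
                // when a cached kernel path is provided.
                .setCachedKernelPath("/data/local/tmp/kernel_cache")
                .setSerializedModelDir("/data/local/tmp/serialized_models")
                .setModelToken("my_model_token")
                .build())
        // build() throws IllegalArgumentException if the delegate options type
        // does not match the selected delegate (e.g. GpuOptions with Delegate.CPU).
        .build();
  }

  private DelegateOptionsExample() {}
}
```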
+ +package com.google.mediapipe.tasks.core; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.tasks.core.proto.AccelerationProto; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link BaseOptions} */ +@RunWith(Suite.class) +@SuiteClasses({BaseOptionsTest.General.class, BaseOptionsTest.ConvertProtoTest.class}) +public class BaseOptionsTest { + + static final String MODEL_ASSET_PATH = "dummy_model.tflite"; + static final String SERIALIZED_MODEL_DIR = "dummy_serialized_model_dir"; + static final String MODEL_TOKEN = "dummy_model_token"; + static final String CACHED_KERNEL_PATH = "dummy_cached_kernel_path"; + + @RunWith(AndroidJUnit4.class) + public static final class General extends BaseOptionsTest { + @Test + public void succeedsWithDefaultOptions() throws Exception { + BaseOptions options = BaseOptions.builder().setModelAssetPath(MODEL_ASSET_PATH).build(); + assertThat(options.modelAssetPath().isPresent()).isTrue(); + assertThat(options.modelAssetPath().get()).isEqualTo(MODEL_ASSET_PATH); + assertThat(options.delegate()).isEqualTo(Delegate.CPU); + } + + @Test + public void succeedsWithGpuOptions() throws Exception { + BaseOptions options = + BaseOptions.builder() + .setModelAssetPath(MODEL_ASSET_PATH) + .setDelegate(Delegate.GPU) + .setDelegateOptions( + BaseOptions.DelegateOptions.GpuOptions.builder() + .setSerializedModelDir(SERIALIZED_MODEL_DIR) + .setModelToken(MODEL_TOKEN) + .setCachedKernelPath(CACHED_KERNEL_PATH) + .build()) + .build(); + assertThat( + ((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get()) + .serializedModelDir() + .get()) + .isEqualTo(SERIALIZED_MODEL_DIR); + assertThat( + ((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get()) + .modelToken() + .get()) + .isEqualTo(MODEL_TOKEN); + assertThat( + ((BaseOptions.DelegateOptions.GpuOptions) options.delegateOptions().get()) + .cachedKernelPath() + .get()) + .isEqualTo(CACHED_KERNEL_PATH); + } + + @Test + public void failsWithInvalidDelegateOptions() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + BaseOptions.builder() + .setModelAssetPath(MODEL_ASSET_PATH) + .setDelegate(Delegate.CPU) + .setDelegateOptions( + BaseOptions.DelegateOptions.GpuOptions.builder() + .setSerializedModelDir(SERIALIZED_MODEL_DIR) + .setModelToken(MODEL_TOKEN) + .build()) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Specified Delegate type does not match the provided delegate options."); + } + } + + /** A mock TaskOptions class providing access to convertBaseOptionsToProto. 
*/ + public static class MockTaskOptions extends TaskOptions { + + public MockTaskOptions(BaseOptions baseOptions) { + baseOptionsProto = convertBaseOptionsToProto(baseOptions); + } + + public BaseOptionsProto.BaseOptions getBaseOptionsProto() { + return baseOptionsProto; + } + + private BaseOptionsProto.BaseOptions baseOptionsProto; + + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + return CalculatorOptions.newBuilder().build(); + } + } + + /** Test for converting {@link BaseOptions} to {@link BaseOptionsProto} */ + @RunWith(AndroidJUnit4.class) + public static final class ConvertProtoTest extends BaseOptionsTest { + @Test + public void succeedsWithDefaultOptions() throws Exception { + BaseOptions options = + BaseOptions.builder() + .setModelAssetPath(MODEL_ASSET_PATH) + .setDelegate(Delegate.CPU) + .setDelegateOptions(BaseOptions.DelegateOptions.CpuOptions.builder().build()) + .build(); + MockTaskOptions taskOptions = new MockTaskOptions(options); + AccelerationProto.Acceleration acceleration = + taskOptions.getBaseOptionsProto().getAcceleration(); + assertThat(acceleration.hasTflite()).isTrue(); + } + + @Test + public void succeedsWithGpuOptions() throws Exception { + BaseOptions options = + BaseOptions.builder() + .setModelAssetPath(MODEL_ASSET_PATH) + .setDelegate(Delegate.GPU) + .setDelegateOptions( + BaseOptions.DelegateOptions.GpuOptions.builder() + .setModelToken(MODEL_TOKEN) + .setSerializedModelDir(SERIALIZED_MODEL_DIR) + .build()) + .build(); + MockTaskOptions taskOptions = new MockTaskOptions(options); + AccelerationProto.Acceleration acceleration = + taskOptions.getBaseOptionsProto().getAcceleration(); + assertThat(acceleration.hasTflite()).isFalse(); + assertThat(acceleration.hasGpu()).isTrue(); + assertThat(acceleration.getGpu().getUseAdvancedGpuApi()).isTrue(); + assertThat(acceleration.getGpu().hasCachedKernelPath()).isFalse(); + assertThat(acceleration.getGpu().getModelToken()).isEqualTo(MODEL_TOKEN); + assertThat(acceleration.getGpu().getSerializedModelDir()).isEqualTo(SERIALIZED_MODEL_DIR); + } + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java index ed7573b2a..20084ee7c 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textembedder/TextEmbedderTest.java @@ -140,7 +140,7 @@ public class TextEmbedderTest { TextEmbedder.cosineSimilarity( result0.embeddingResult().embeddings().get(0), result1.embeddingResult().embeddings().get(0)); - assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3565317439544432); } @Test diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java index 7adef9e27..508709ab0 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java @@ -15,6 +15,7 @@ package com.google.mediapipe.tasks.vision.poselandmarker; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static 
org.junit.Assert.assertThrows; import android.content.res.AssetManager; @@ -26,6 +27,7 @@ import com.google.common.truth.Correspondence; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; @@ -34,6 +36,7 @@ import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions; import java.io.InputStream; import java.util.Arrays; +import java.util.List; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -50,6 +53,8 @@ public class PoseLandmarkerTest { private static final String NO_POSES_IMAGE = "burger.jpg"; private static final String TAG = "Pose Landmarker Test"; private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final float VISIBILITY_TOLERANCE = 0.9f; + private static final float PRESENCE_TOLERANCE = 0.9f; private static final int IMAGE_WIDTH = 1000; private static final int IMAGE_HEIGHT = 667; @@ -70,6 +75,8 @@ public class PoseLandmarkerTest { PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + assertAllLandmarksAreVisibleAndPresent( + actualResult, VISIBILITY_TOLERANCE, PRESENCE_TOLERANCE); } @Test @@ -361,4 +368,40 @@ public class PoseLandmarkerTest { assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); } + + private static void assertAllLandmarksAreVisibleAndPresent( + PoseLandmarkerResult result, float visbilityThreshold, float presenceThreshold) { + for (int i = 0; i < result.landmarks().size(); i++) { + List landmarks = result.landmarks().get(i); + for (int j = 0; j < landmarks.size(); j++) { + NormalizedLandmark landmark = landmarks.get(j); + String landmarkMessage = "Landmark List " + i + " landmark " + j + ": " + landmark; + landmark + .visibility() + .ifPresent( + val -> + assertWithMessage(landmarkMessage).that(val).isAtLeast((visbilityThreshold))); + landmark + .presence() + .ifPresent( + val -> assertWithMessage(landmarkMessage).that(val).isAtLeast((presenceThreshold))); + } + } + for (int i = 0; i < result.worldLandmarks().size(); i++) { + List landmarks = result.worldLandmarks().get(i); + for (int j = 0; j < landmarks.size(); j++) { + Landmark landmark = landmarks.get(j); + String landmarkMessage = "World Landmark List " + i + " landmark " + j + ": " + landmark; + landmark + .visibility() + .ifPresent( + val -> + assertWithMessage(landmarkMessage).that(val).isAtLeast((visbilityThreshold))); + landmark + .presence() + .ifPresent( + val -> assertWithMessage(landmarkMessage).that(val).isAtLeast((presenceThreshold))); + } + } + } } diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 27726b707..9688ee919 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -37,7 +37,7 @@ 
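The landmarker result classes above now surface per-landmark visibility and presence as Optional values. A hedged sketch of consuming them from application code, using the accessors introduced by this patch (the threshold is illustrative, not part of the change):

```java
import com.google.mediapipe.tasks.components.containers.NormalizedLandmark;
import java.util.List;

final class LandmarkFilters {
  /** Counts landmarks whose reported visibility meets a caller-chosen threshold. */
  static int countVisible(List<NormalizedLandmark> landmarks, float minVisibility) {
    int visible = 0;
    for (NormalizedLandmark landmark : landmarks) {
      // visibility() and presence() are Optional<Float>: models that do not
      // populate these fields simply yield an empty Optional.
      if (landmark.visibility().orElse(0f) >= minVisibility) {
        visible++;
      }
    }
    return visible;
  }

  private LandmarkFilters() {}
}
```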
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' # Tolerance for embedding vector coordinate values. _EPSILON = 1e-4 # Tolerance for cosine similarity evaluation. -_SIMILARITY_TOLERANCE = 1e-6 +_SIMILARITY_TOLERANCE = 1e-3 class ModelFileType(enum.Enum): @@ -287,7 +287,7 @@ class TextEmbedderTest(parameterized.TestCase): @parameterized.parameters( # TODO: The similarity should likely be lower - (_BERT_MODEL_FILE, 0.980880), + (_BERT_MODEL_FILE, 0.98077), (_USE_MODEL_FILE, 0.780334), ) def test_embed_with_different_themes(self, model_file, expected_similarity): diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 4fde58e02..6ea207d67 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -57,6 +57,7 @@ mediapipe_files(srcs = [ "hand_landmarker.task", "left_hands.jpg", "left_hands_rotated.jpg", + "leopard_bg_removal_result_512x512.png", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", @@ -65,6 +66,7 @@ mediapipe_files(srcs = [ "mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite", "mobilenet_v2_1.0_224.tflite", "mobilenet_v3_small_100_224_embedder.tflite", + "mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt.tflite", "mozart_square.jpg", "multi_objects.jpg", "multi_objects_rotated.jpg", @@ -136,6 +138,7 @@ filegroup( "hand_landmark_lite.tflite", "left_hands.jpg", "left_hands_rotated.jpg", + "leopard_bg_removal_result_512x512.png", "mozart_square.jpg", "multi_objects.jpg", "multi_objects_rotated.jpg", diff --git a/mediapipe/tasks/web/components/containers/bounding_box.d.ts b/mediapipe/tasks/web/components/containers/bounding_box.d.ts index 77f2837d1..85811f443 100644 --- a/mediapipe/tasks/web/components/containers/bounding_box.d.ts +++ b/mediapipe/tasks/web/components/containers/bounding_box.d.ts @@ -24,4 +24,10 @@ export declare interface BoundingBox { width: number; /** The height of the bounding box, in pixels. */ height: number; + /** + * Angle of rotation of the original non-rotated box around the top left + * corner of the original non-rotated box, in clockwise degrees from the + * horizontal. 
+ */ + angle: number; } diff --git a/mediapipe/tasks/web/components/processors/detection_result.test.ts b/mediapipe/tasks/web/components/processors/detection_result.test.ts index 0fa8156ba..8e3e413e1 100644 --- a/mediapipe/tasks/web/components/processors/detection_result.test.ts +++ b/mediapipe/tasks/web/components/processors/detection_result.test.ts @@ -58,7 +58,7 @@ describe('convertFromDetectionProto()', () => { categoryName: 'foo', displayName: 'bar', }], - boundingBox: {originX: 1, originY: 2, width: 3, height: 4}, + boundingBox: {originX: 1, originY: 2, width: 3, height: 4, angle: 0}, keypoints: [{ x: 5, y: 6, @@ -85,7 +85,7 @@ describe('convertFromDetectionProto()', () => { categoryName: '', displayName: '', }], - boundingBox: {originX: 0, originY: 0, width: 0, height: 0}, + boundingBox: {originX: 0, originY: 0, width: 0, height: 0, angle: 0}, keypoints: [] }); }); diff --git a/mediapipe/tasks/web/components/processors/detection_result.ts b/mediapipe/tasks/web/components/processors/detection_result.ts index 4999ed31b..6cb5e6230 100644 --- a/mediapipe/tasks/web/components/processors/detection_result.ts +++ b/mediapipe/tasks/web/components/processors/detection_result.ts @@ -42,7 +42,8 @@ export function convertFromDetectionProto(source: DetectionProto): Detection { originX: boundingBox.getXmin() ?? 0, originY: boundingBox.getYmin() ?? 0, width: boundingBox.getWidth() ?? 0, - height: boundingBox.getHeight() ?? 0 + height: boundingBox.getHeight() ?? 0, + angle: 0.0, }; } diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 8c6aae6cf..dde98192d 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -25,9 +25,6 @@ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/regis import {WasmFileset} from './wasm_fileset'; -// None of the MP Tasks ship bundle assets. -const NO_ASSETS = undefined; - // Internal stream names for temporarily keeping memory alive, then freeing it. const FREE_MEMORY_STREAM = 'free_memory'; const UNUSED_STREAM_SUFFIX = '_unused_out'; @@ -61,7 +58,8 @@ export async function createTaskRunner( }; const instance = await createMediaPipeLib( - type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + type, fileset.wasmLoaderPath, fileset.assetLoaderPath, canvas, + fileLocator); await instance.setOptions(options); return instance; } @@ -96,65 +94,73 @@ export abstract class TaskRunner { abstract setOptions(options: TaskRunnerOptions): Promise; /** - * Applies the current set of options, including any base options that have - * not been processed by the task implementation. The options are applied - * synchronously unless a `modelAssetPath` is provided. This ensures that - * for most use cases options are applied directly and immediately affect + * Applies the current set of options, including optionally any base options + * that have not been processed by the task implementation. The options are + * applied synchronously unless a `modelAssetPath` is provided. This ensures + * that for most use cases options are applied directly and immediately affect * the next inference. + * + * @param options The options for the task. + * @param loadTfliteModel Whether to load the model specified in + * `options.baseOptions`. 
*/ - protected applyOptions(options: TaskRunnerOptions): Promise { - const baseOptions: BaseOptions = options.baseOptions || {}; + protected applyOptions(options: TaskRunnerOptions, loadTfliteModel = true): + Promise { + if (loadTfliteModel) { + const baseOptions: BaseOptions = options.baseOptions || {}; - // Validate that exactly one model is configured - if (options.baseOptions?.modelAssetBuffer && - options.baseOptions?.modelAssetPath) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || - this.baseOptions.getModelAsset()?.hasFileName() || - options.baseOptions?.modelAssetBuffer || - options.baseOptions?.modelAssetPath)) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + // Validate that exactly one model is configured + if (options.baseOptions?.modelAssetBuffer && + options.baseOptions?.modelAssetPath) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || + this.baseOptions.getModelAsset()?.hasFileName() || + options.baseOptions?.modelAssetBuffer || + options.baseOptions?.modelAssetPath)) { + throw new Error( + 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + } + + this.setAcceleration(baseOptions); + if (baseOptions.modelAssetPath) { + // We don't use `await` here since we want to apply most settings + // synchronously. + return fetch(baseOptions.modelAssetPath.toString()) + .then(response => { + if (!response.ok) { + throw new Error(`Failed to fetch model: ${ + baseOptions.modelAssetPath} (${response.status})`); + } else { + return response.arrayBuffer(); + } + }) + .then(buffer => { + try { + // Try to delete file as we cannot overwite an existing file + // using our current API. + this.graphRunner.wasmModule.FS_unlink('/model.dat'); + } catch { + } + // TODO: Consider passing the model to the graph as an + // input side packet as this might reduce copies. + this.graphRunner.wasmModule.FS_createDataFile( + '/', 'model.dat', new Uint8Array(buffer), + /* canRead= */ true, /* canWrite= */ false, + /* canOwn= */ false); + this.setExternalFile('/model.dat'); + this.refreshGraph(); + this.onGraphRefreshed(); + }); + } else { + this.setExternalFile(baseOptions.modelAssetBuffer); + } } - this.setAcceleration(baseOptions); - if (baseOptions.modelAssetPath) { - // We don't use `await` here since we want to apply most settings - // synchronously. - return fetch(baseOptions.modelAssetPath.toString()) - .then(response => { - if (!response.ok) { - throw new Error(`Failed to fetch model: ${ - baseOptions.modelAssetPath} (${response.status})`); - } else { - return response.arrayBuffer(); - } - }) - .then(buffer => { - try { - // Try to delete file as we cannot overwite an existing file using - // our current API. - this.graphRunner.wasmModule.FS_unlink('/model.dat'); - } catch { - } - // TODO: Consider passing the model to the graph as an - // input side packet as this might reduce copies. - this.graphRunner.wasmModule.FS_createDataFile( - '/', 'model.dat', new Uint8Array(buffer), - /* canRead= */ true, /* canWrite= */ false, - /* canOwn= */ false); - this.setExternalFile('/model.dat'); - this.refreshGraph(); - this.onGraphRefreshed(); - }); - } else { - // Apply the setting synchronously. 
- this.setExternalFile(baseOptions.modelAssetBuffer); - this.refreshGraph(); - this.onGraphRefreshed(); - return Promise.resolve(); - } + // If there is no model to download, we can apply the setting synchronously. + this.refreshGraph(); + this.onGraphRefreshed(); + return Promise.resolve(); } /** Appliest the current options to the MediaPipe graph. */ diff --git a/mediapipe/tasks/web/core/wasm_fileset.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts index 558aa3faf..dda466ad9 100644 --- a/mediapipe/tasks/web/core/wasm_fileset.d.ts +++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts @@ -22,4 +22,6 @@ export declare interface WasmFileset { wasmLoaderPath: string; /** The path to the Wasm binary. */ wasmBinaryPath: string; + /** The optional path to the asset loader script. */ + assetLoaderPath?: string; } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index f8f7826d0..3ed15b97d 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -70,7 +70,8 @@ export abstract class VisionTaskRunner extends TaskRunner { * @param imageStreamName the name of the input image stream. * @param normRectStreamName the name of the input normalized rect image * stream used to provide (mandatory) rotation and (optional) - * region-of-interest. + * region-of-interest. `null` if the graph does not support normalized + * rects. * @param roiAllowed Whether this task supports Region-Of-Interest * pre-processing * @@ -79,13 +80,20 @@ export abstract class VisionTaskRunner extends TaskRunner { constructor( protected override readonly graphRunner: VisionGraphRunner, private readonly imageStreamName: string, - private readonly normRectStreamName: string, + private readonly normRectStreamName: string|null, private readonly roiAllowed: boolean) { super(graphRunner); } - /** Configures the shared options of a vision task. */ - override applyOptions(options: VisionTaskOptions): Promise { + /** + * Configures the shared options of a vision task. + * + * @param options The options for the task. + * @param loadTfliteModel Whether to load the model specified in + * `options.baseOptions`. + */ + override applyOptions(options: VisionTaskOptions, loadTfliteModel = true): + Promise { if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'IMAGE'; @@ -98,7 +106,7 @@ export abstract class VisionTaskRunner extends TaskRunner { } } - return super.applyOptions(options); + return super.applyOptions(options, loadTfliteModel); } /** Sends a single image to the graph and awaits results. */ @@ -209,11 +217,13 @@ export abstract class VisionTaskRunner extends TaskRunner { imageSource: ImageSource, imageProcessingOptions: ImageProcessingOptions|undefined, timestamp: number): void { - const normalizedRect = - this.convertToNormalizedRect(imageSource, imageProcessingOptions); - this.graphRunner.addProtoToStream( - normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', - this.normRectStreamName, timestamp); + if (this.normRectStreamName) { + const normalizedRect = + this.convertToNormalizedRect(imageSource, imageProcessingOptions); + this.graphRunner.addProtoToStream( + normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect', + this.normRectStreamName, timestamp); + } this.graphRunner.addGpuBufferAsImageToStream( imageSource, this.imageStreamName, timestamp ?? 
performance.now()); this.finishProcessing(); diff --git a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts index dfe84bb17..049edefd6 100644 --- a/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts +++ b/mediapipe/tasks/web/vision/face_detector/face_detector_test.ts @@ -191,7 +191,7 @@ describe('FaceDetector', () => { categoryName: '', displayName: '', }], - boundingBox: {originX: 0, originY: 0, width: 0, height: 0}, + boundingBox: {originX: 0, originY: 0, width: 0, height: 0, angle: 0}, keypoints: [] }); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index 9c63eaba1..6437216b1 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -210,7 +210,7 @@ describe('ObjectDetector', () => { categoryName: '', displayName: '', }], - boundingBox: {originX: 0, originY: 0, width: 0, height: 0}, + boundingBox: {originX: 0, originY: 0, width: 0, height: 0, angle: 0}, keypoints: [] }); }); diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index b9fe8b0c9..ecedeedb2 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -152,6 +152,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", ], diff --git a/mediapipe/util/pose_util.cc b/mediapipe/util/pose_util.cc index 61663ba55..4a6bb6cdb 100644 --- a/mediapipe/util/pose_util.cc +++ b/mediapipe/util/pose_util.cc @@ -1,5 +1,6 @@ #include "mediapipe/util/pose_util.h" +#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" namespace { @@ -192,7 +193,7 @@ void DrawPose(const mediapipe::NormalizedLandmarkList& pose, bool flip_y, } void DrawFace(const mediapipe::NormalizedLandmarkList& face, bool flip_y, - bool draw_nose, bool color_style, bool reverse_color, + bool draw_nose, int color_style, bool reverse_color, int draw_line_width, cv::Mat* image) { const int target_width = image->cols; const int target_height = image->rows; @@ -202,16 +203,26 @@ void DrawFace(const mediapipe::NormalizedLandmarkList& face, bool flip_y, (flip_y ? 
1.0f - lm.y() : lm.y()) * target_height); } - cv::Scalar kFaceOvalColor = kWhiteColor; - cv::Scalar kLipsColor = kWhiteColor; - cv::Scalar kLeftEyeColor = kGreenColor; - cv::Scalar kLeftEyebrowColor = kGreenColor; - cv::Scalar kLeftEyeIrisColor = kGreenColor; - cv::Scalar kRightEyeColor = kRedColor; - cv::Scalar kRightEyebrowColor = kRedColor; - cv::Scalar kRightEyeIrisColor = kRedColor; - cv::Scalar kNoseColor = kWhiteColor; - if (color_style) { + cv::Scalar kFaceOvalColor; + cv::Scalar kLipsColor; + cv::Scalar kLeftEyeColor; + cv::Scalar kLeftEyebrowColor; + cv::Scalar kLeftEyeIrisColor; + cv::Scalar kRightEyeColor; + cv::Scalar kRightEyebrowColor; + cv::Scalar kRightEyeIrisColor; + cv::Scalar kNoseColor; + if (color_style == 0) { + kFaceOvalColor = kWhiteColor; + kLipsColor = kWhiteColor; + kLeftEyeColor = kGreenColor; + kLeftEyebrowColor = kGreenColor; + kLeftEyeIrisColor = kGreenColor; + kRightEyeColor = kRedColor; + kRightEyebrowColor = kRedColor; + kRightEyeIrisColor = kRedColor; + kNoseColor = kWhiteColor; + } else if (color_style == 1) { kFaceOvalColor = kWhiteColor; kLipsColor = kBlueColor; kLeftEyeColor = kCyanColor; @@ -221,6 +232,18 @@ void DrawFace(const mediapipe::NormalizedLandmarkList& face, bool flip_y, kRightEyebrowColor = kRedColor; kRightEyeIrisColor = kRedColor; kNoseColor = kYellowColor; + } else if (color_style == 2) { + kFaceOvalColor = kWhiteColor; + kLipsColor = kBlueColor; + kLeftEyeColor = kCyanColor; + kLeftEyebrowColor = kGreenColor; + kLeftEyeIrisColor = kRedColor; + kRightEyeColor = kCyanColor; + kRightEyebrowColor = kGreenColor; + kRightEyeIrisColor = kRedColor; + kNoseColor = kYellowColor; + } else { + LOG(ERROR) << "color_style not supported."; } if (reverse_color) { diff --git a/mediapipe/util/pose_util.h b/mediapipe/util/pose_util.h index d94e22cbe..da952422f 100644 --- a/mediapipe/util/pose_util.h +++ b/mediapipe/util/pose_util.h @@ -24,7 +24,7 @@ void DrawPose(const mediapipe::NormalizedLandmarkList& pose, bool flip_y, cv::Mat* image); void DrawFace(const mediapipe::NormalizedLandmarkList& face, bool flip_y, - bool draw_nose, bool color_style, bool reverse_color, + bool draw_nose, int color_style, bool reverse_color, int draw_line_width, cv::Mat* image); } // namespace mediapipe diff --git a/mediapipe/util/sequence/BUILD b/mediapipe/util/sequence/BUILD index ac7c2ba51..41611d27c 100644 --- a/mediapipe/util/sequence/BUILD +++ b/mediapipe/util/sequence/BUILD @@ -72,7 +72,6 @@ cc_test( "//mediapipe/framework/formats:location", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", - "//mediapipe/framework/port:status", "@org_tensorflow//tensorflow/core:protos_all_cc", ], ) diff --git a/mediapipe/util/sequence/README.md b/mediapipe/util/sequence/README.md index e5b5ed919..960a0d9b5 100644 --- a/mediapipe/util/sequence/README.md +++ b/mediapipe/util/sequence/README.md @@ -555,9 +555,9 @@ without timestamps, use the `context`. 
|`PREFIX/feature/dimensions`|context int list|`set_feature_dimensions` / `SetFeatureDimensions`|A list of integer dimensions for each feature.| |`PREFIX/feature/rate`|context float|`set_feature_rate` / `SetFeatureRate`|The rate that features are calculated as features per second.| |`PREFIX/feature/bytes/format`|context bytes|`set_feature_bytes_format` / `SetFeatureBytesFormat`|The encoding format if any for features stored as bytes.| -|`PREFIX/context_feature/floats`|context float list|`add_context_feature_floats` / `AddContextFeatureFloats`|A list of floats for the entire example.| -|`PREFIX/context_feature/bytes`|context bytes list|`add_context_feature_bytes` / `AddContextFeatureBytes`|A list of bytes for the entire example. Maybe be encoded.| -|`PREFIX/context_feature/ints`|context int list|`add_context_feature_ints` / `AddContextFeatureInts`|A list of ints for the entire example.| +|`PREFIX/context_feature/floats`|context float list|`set_context_feature_floats` / `AddContextFeatureFloats`|A list of floats for the entire example.| +|`PREFIX/context_feature/bytes`|context bytes list|`set_context_feature_bytes` / `AddContextFeatureBytes`|A list of bytes for the entire example. Maybe be encoded.| +|`PREFIX/context_feature/ints`|context int list|`set_context_feature_ints` / `AddContextFeatureInts`|A list of ints for the entire example.| ### Keys related to audio Audio is a special subtype of generic features with additional data about the @@ -593,6 +593,8 @@ ground truth transcripts. |-----|------|------------------------|-------------| |`text/language`|context bytes|`set_text_langage` / `SetTextLanguage`|The language for the corresponding text.| |`text/context/content`|context bytes|`set_text_context_content` / `SetTextContextContent`|Storage for large blocks of text in the context.| +|`text/context/token_id`|context int list|`set_text_context_token_id` / `SetTextContextTokenId`|Storage for large blocks of text in the context as token ids.| +|`text/context/embedding`|context float list|`set_text_context_embedding` / `SetTextContextEmbedding`|Storage for large blocks of text in the context as embeddings.| |`text/content`|feature list bytes|`add_text_content` / `AddTextContent`|One (or a few) text tokens that occur at one timestamp.| |`text/timestamp`|feature list int|`add_text_timestamp` / `AddTextTimestamp`|When a text token occurs in microseconds.| |`text/duration`|feature list int|`add_text_duration` / `SetTextDuration`|The duration in microseconds for the corresponding text tokens.| diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h index 620d6d483..e4bfcf5a2 100644 --- a/mediapipe/util/sequence/media_sequence.h +++ b/mediapipe/util/sequence/media_sequence.h @@ -634,6 +634,10 @@ PREFIXED_IMAGE(InstanceSegmentation, kInstanceSegmentationPrefix); const char kTextLanguageKey[] = "text/language"; // A large block of text that applies to the media. const char kTextContextContentKey[] = "text/context/content"; +// A large block of text that applies to the media as token ids. +const char kTextContextTokenIdKey[] = "text/context/token_id"; +// A large block of text that applies to the media as embeddings. +const char kTextContextEmbeddingKey[] = "text/context/embedding"; // Feature list keys: // The text contents for a given time. 
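The new `text/context/token_id` and `text/context/embedding` keys come with the usual macro-generated context accessors. A small C++ sketch of the intended round trip, assuming the `mediapipe::mediasequence` namespace and the element types used by the new tests further down (int64 token ids, float embeddings):

```cpp
#include <cstdint>
#include <vector>

#include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h"

// Stores whole-clip text metadata in the SequenceExample context rather than
// in the per-timestep feature lists.
void StoreTextContext(tensorflow::SequenceExample* sequence) {
  std::vector<int64_t> token_ids = {101, 2023, 102};      // illustrative ids
  std::vector<float> embedding = {0.12f, -0.34f, 0.56f};  // illustrative values
  mediapipe::mediasequence::SetTextContextTokenId(token_ids, sequence);
  mediapipe::mediasequence::SetTextContextEmbedding(embedding, sequence);
}
```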
@@ -651,6 +655,8 @@ const char kTextTokenIdKey[] = "text/token/id"; BYTES_CONTEXT_FEATURE(TextLanguage, kTextLanguageKey); BYTES_CONTEXT_FEATURE(TextContextContent, kTextContextContentKey); +VECTOR_INT64_CONTEXT_FEATURE(TextContextTokenId, kTextContextTokenIdKey); +VECTOR_FLOAT_CONTEXT_FEATURE(TextContextEmbedding, kTextContextEmbeddingKey); BYTES_FEATURE_LIST(TextContent, kTextContentKey); INT64_FEATURE_LIST(TextTimestamp, kTextTimestampKey); INT64_FEATURE_LIST(TextDuration, kTextDurationKey); diff --git a/mediapipe/util/sequence/media_sequence.py b/mediapipe/util/sequence/media_sequence.py index 1b96383d6..e87d8c21d 100644 --- a/mediapipe/util/sequence/media_sequence.py +++ b/mediapipe/util/sequence/media_sequence.py @@ -601,6 +601,10 @@ _create_image_with_prefix("instance_segmentation", INSTANCE_SEGMENTATION_PREFIX) TEXT_LANGUAGE_KEY = "text/language" # A large block of text that applies to the media. TEXT_CONTEXT_CONTENT_KEY = "text/context/content" +# A large block of text that applies to the media as token ids. +TEXT_CONTEXT_TOKEN_ID_KEY = "text/context/token_id" +# A large block of text that applies to the media as embeddings. +TEXT_CONTEXT_EMBEDDING_KEY = "text/context/embedding" # The text contents for a given time. TEXT_CONTENT_KEY = "text/content" @@ -619,6 +623,10 @@ msu.create_bytes_context_feature( "text_language", TEXT_LANGUAGE_KEY, module_dict=globals()) msu.create_bytes_context_feature( "text_context_content", TEXT_CONTEXT_CONTENT_KEY, module_dict=globals()) +msu.create_int_list_context_feature( + "text_context_token_id", TEXT_CONTEXT_TOKEN_ID_KEY, module_dict=globals()) +msu.create_float_list_context_feature( + "text_context_embedding", TEXT_CONTEXT_EMBEDDING_KEY, module_dict=globals()) msu.create_bytes_feature_list( "text_content", TEXT_CONTENT_KEY, module_dict=globals()) msu.create_int_feature_list( diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index e220eace0..17365faec 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/port/gmock.h" @@ -711,6 +712,30 @@ TEST(MediaSequenceTest, RoundTripTextContextContent) { ASSERT_FALSE(HasTextContextContent(sequence)); } +TEST(MediaSequenceTest, RoundTripTextContextTokenId) { + tensorflow::SequenceExample sequence; + ASSERT_FALSE(HasTextContextTokenId(sequence)); + std::vector vi = {47, 35}; + SetTextContextTokenId(vi, &sequence); + ASSERT_TRUE(HasTextContextTokenId(sequence)); + ASSERT_EQ(GetTextContextTokenId(sequence).size(), vi.size()); + ASSERT_EQ(GetTextContextTokenId(sequence)[1], vi[1]); + ClearTextContextTokenId(&sequence); + ASSERT_FALSE(HasTextContextTokenId(sequence)); +} + +TEST(MediaSequenceTest, RoundTripTextContextEmbedding) { + tensorflow::SequenceExample sequence; + ASSERT_FALSE(HasTextContextEmbedding(sequence)); + std::vector vi = {47., 35.}; + SetTextContextEmbedding(vi, &sequence); + ASSERT_TRUE(HasTextContextEmbedding(sequence)); + ASSERT_EQ(GetTextContextEmbedding(sequence).size(), vi.size()); + ASSERT_EQ(GetTextContextEmbedding(sequence)[1], vi[1]); + ClearTextContextEmbedding(&sequence); + ASSERT_FALSE(HasTextContextEmbedding(sequence)); +} + TEST(MediaSequenceTest, RoundTripTextContent) { tensorflow::SequenceExample sequence; std::vector text = {"test", "again"}; diff --git a/mediapipe/util/sequence/media_sequence_test.py 
b/mediapipe/util/sequence/media_sequence_test.py index 5a5c61c7f..5c4ff3827 100644 --- a/mediapipe/util/sequence/media_sequence_test.py +++ b/mediapipe/util/sequence/media_sequence_test.py @@ -129,6 +129,8 @@ class MediaSequenceTest(tf.test.TestCase): ms.add_bbox_embedding_confidence((0.47, 0.49), example) ms.set_text_language(b"test", example) ms.set_text_context_content(b"text", example) + ms.set_text_context_token_id([47, 49], example) + ms.set_text_context_embedding([0.47, 0.49], example) ms.add_text_content(b"one", example) ms.add_text_timestamp(47, example) ms.add_text_confidence(0.47, example) @@ -260,6 +262,29 @@ class MediaSequenceTest(tf.test.TestCase): self.assertFalse(ms.has_feature_dimensions(example, "1")) self.assertFalse(ms.has_feature_dimensions(example, "2")) + def test_text_context_round_trip(self): + example = tf.train.SequenceExample() + text_content = b"text content" + text_token_ids = np.array([1, 2, 3, 4]) + text_embeddings = np.array([0.1, 0.2, 0.3, 0.4]) + self.assertFalse(ms.has_text_context_embedding(example)) + self.assertFalse(ms.has_text_context_token_id(example)) + self.assertFalse(ms.has_text_context_content(example)) + ms.set_text_context_content(text_content, example) + ms.set_text_context_token_id(text_token_ids, example) + ms.set_text_context_embedding(text_embeddings, example) + self.assertEqual(text_content, ms.get_text_context_content(example)) + self.assertAllClose(text_token_ids, ms.get_text_context_token_id(example)) + self.assertAllClose(text_embeddings, ms.get_text_context_embedding(example)) + self.assertTrue(ms.has_text_context_embedding(example)) + self.assertTrue(ms.has_text_context_token_id(example)) + self.assertTrue(ms.has_text_context_content(example)) + ms.clear_text_context_content(example) + ms.clear_text_context_token_id(example) + ms.clear_text_context_embedding(example) + self.assertFalse(ms.has_text_context_embedding(example)) + self.assertFalse(ms.has_text_context_token_id(example)) + self.assertFalse(ms.has_text_context_content(example)) if __name__ == "__main__": tf.test.main() diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index 4e40975cb..c1b272b67 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -234,6 +234,11 @@ absl::Status TFLiteGPURunner::InitializeOpenCL( MP_RETURN_IF_ERROR( cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties)); + if (serialized_model_.empty() && + opencl_init_from_serialized_model_is_forced_) { + ASSIGN_OR_RETURN(serialized_model_, GetSerializedModel()); + } + // Try to initialize from serialized model first. 
if (!serialized_model_.empty()) { absl::Status init_status = InitializeOpenCLFromSerializedModel(builder); @@ -270,7 +275,6 @@ absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel( } absl::StatusOr> TFLiteGPURunner::GetSerializedModel() { - RET_CHECK(runner_) << "Runner is in invalid state."; if (serialized_model_used_) { return serialized_model_; } diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index 5eeaa230f..c64981ef8 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -62,6 +62,9 @@ class TFLiteGPURunner { void ForceOpenGL() { opengl_is_forced_ = true; } void ForceOpenCL() { opencl_is_forced_ = true; } + void ForceOpenCLInitFromSerializedModel() { + opencl_init_from_serialized_model_is_forced_ = true; + } absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id); absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id); @@ -141,6 +144,7 @@ class TFLiteGPURunner { bool opencl_is_forced_ = false; bool opengl_is_forced_ = false; + bool opencl_init_from_serialized_model_is_forced_ = false; }; } // namespace gpu diff --git a/platform_mappings b/platform_mappings new file mode 100644 index 000000000..cfe26f37b --- /dev/null +++ b/platform_mappings @@ -0,0 +1,64 @@ +# This file allows automatically mapping flags such as '--cpu' to the more +# modern Bazel platforms (https://bazel.build/concepts/platforms). + +# In particular, Bazel platforms lack support for Apple for now if no such +# mapping is put into place. It's inspired from: +# https://github.com/bazelbuild/rules_apple/issues/1764 + +platforms: + @build_bazel_apple_support//platforms:macos_x86_64 + --cpu=darwin_x86_64 + + @build_bazel_apple_support//platforms:macos_arm64 + --cpu=darwin_arm64 + + @build_bazel_apple_support//platforms:ios_i386 + --cpu=ios_i386 + + @build_bazel_apple_support//platforms:ios_x86_64 + --cpu=ios_x86_64 + + @build_bazel_apple_support//platforms:ios_sim_arm64 + --cpu=ios_sim_arm64 + + @build_bazel_apple_support//platforms:ios_armv7 + --cpu=ios_armv7 + + @build_bazel_apple_support//platforms:ios_arm64 + --cpu=ios_arm64 + + @build_bazel_apple_support//platforms:ios_arm64e + --cpu=ios_arm64e + +flags: + --cpu=darwin_x86_64 + --apple_platform_type=macos + @build_bazel_apple_support//platforms:macos_x86_64 + + --cpu=darwin_arm64 + --apple_platform_type=macos + @build_bazel_apple_support//platforms:macos_arm64 + + --cpu=ios_i386 + --apple_platform_type=ios + @build_bazel_apple_support//platforms:ios_i386 + + --cpu=ios_x86_64 + --apple_platform_type=ios + @build_bazel_apple_support//platforms:ios_x86_64 + + --cpu=ios_sim_arm64 + --apple_platform_type=ios + @build_bazel_apple_support//platforms:ios_sim_arm64 + + --cpu=ios_armv7 + --apple_platform_type=ios + @build_bazel_apple_support//platforms:ios_armv7 + + --cpu=ios_arm64 + --apple_platform_type=ios + @build_bazel_apple_support//platforms:ios_arm64 + + --cpu=ios_arm64e + --apple_platform_type=ios + @build_bazel_apple_support//platforms:ios_arm64e diff --git a/third_party/BUILD b/third_party/BUILD index 470b7ff99..971e51338 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -378,3 +378,10 @@ java_library( "@maven//:com_google_auto_value_auto_value_annotations", ], ) + +java_import( + name = "any_java_proto", + jars = [ + "@com_google_protobuf//java/core:libcore.jar", + ], +) diff --git a/third_party/com_github_glog_glog.diff b/third_party/com_github_glog_glog.diff new file mode 100644 index 000000000..15447d791 --- /dev/null 
+++ b/third_party/com_github_glog_glog.diff @@ -0,0 +1,68 @@ +diff --git a/src/logging.cc b/src/logging.cc +index 4028ccc..483e639 100644 +--- a/src/logging.cc ++++ b/src/logging.cc +@@ -1743,6 +1743,23 @@ ostream& LogMessage::stream() { + return data_->stream_; + } + ++namespace { ++#if defined(__ANDROID__) ++int AndroidLogLevel(const int severity) { ++ switch (severity) { ++ case 3: ++ return ANDROID_LOG_FATAL; ++ case 2: ++ return ANDROID_LOG_ERROR; ++ case 1: ++ return ANDROID_LOG_WARN; ++ default: ++ return ANDROID_LOG_INFO; ++ } ++} ++#endif // defined(__ANDROID__) ++} // namespace ++ + // Flush buffered message, called by the destructor, or any other function + // that needs to synchronize the log. + void LogMessage::Flush() { +@@ -1779,6 +1796,12 @@ void LogMessage::Flush() { + } + LogDestination::WaitForSinks(data_); + ++#if defined(__ANDROID__) ++ const int level = AndroidLogLevel((int)data_->severity_); ++ const std::string text = std::string(data_->message_text_); ++ __android_log_write(level, "native", text.substr(0,data_->num_chars_to_log_).c_str()); ++#endif // defined(__ANDROID__) ++ + if (append_newline) { + // Fix the ostrstream back how it was before we screwed with it. + // It's 99.44% certain that we don't need to worry about doing this. + +diff --git a/bazel/glog.bzl b/bazel/glog.bzl +index dacd934..d7b3d78 100644 +--- a/bazel/glog.bzl ++++ b/bazel/glog.bzl +@@ -53,7 +53,6 @@ def glog_library(namespace = "google", with_gflags = 1, **kwargs): + ) + + common_copts = [ +- "-std=c++14", + "-DGLOG_BAZEL_BUILD", + # Inject a C++ namespace. + "-DGOOGLE_NAMESPACE='%s'" % namespace, +@@ -145,7 +144,13 @@ def glog_library(namespace = "google", with_gflags = 1, **kwargs): + ], + }) + ++ c14_opts = ["-std=c++14"] ++ c17_opts = ["-std=c++17"] ++ + final_lib_copts = select({ ++ "@bazel_tools//src/conditions:windows": c17_opts, ++ "//conditions:default": c14_opts, ++ }) + select({ + "@bazel_tools//src/conditions:windows": common_copts + windows_only_copts, + "@bazel_tools//src/conditions:darwin": common_copts + linux_or_darwin_copts + darwin_only_copts, + "@bazel_tools//src/conditions:freebsd": common_copts + linux_or_darwin_copts + freebsd_only_copts, diff --git a/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff b/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff deleted file mode 100644 index 471cf2aa6..000000000 --- a/third_party/com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff +++ /dev/null @@ -1,52 +0,0 @@ -diff --git a/src/logging.cc b/src/logging.cc -index 0b5e6ee..be5a506 100644 ---- a/src/logging.cc -+++ b/src/logging.cc -@@ -67,6 +67,10 @@ - # include "stacktrace.h" - #endif - -+#ifdef __ANDROID__ -+#include -+#endif -+ - using std::string; - using std::vector; - using std::setw; -@@ -1279,6 +1283,23 @@ ostream& LogMessage::stream() { - return data_->stream_; - } - -+namespace { -+#if defined(__ANDROID__) -+int AndroidLogLevel(const int severity) { -+ switch (severity) { -+ case 3: -+ return ANDROID_LOG_FATAL; -+ case 2: -+ return ANDROID_LOG_ERROR; -+ case 1: -+ return ANDROID_LOG_WARN; -+ default: -+ return ANDROID_LOG_INFO; -+ } -+} -+#endif // defined(__ANDROID__) -+} // namespace -+ - // Flush buffered message, called by the destructor, or any other function - // that needs to synchronize the log. 
- void LogMessage::Flush() { -@@ -1313,6 +1334,12 @@ void LogMessage::Flush() { - } - LogDestination::WaitForSinks(data_); - -+#if defined(__ANDROID__) -+ const int level = AndroidLogLevel((int)data_->severity_); -+ const std::string text = std::string(data_->message_text_); -+ __android_log_write(level, "native", text.substr(0,data_->num_chars_to_log_).c_str()); -+#endif // defined(__ANDROID__) -+ - if (append_newline) { - // Fix the ostrstream back how it was before we screwed with it. - // It's 99.44% certain that we don't need to worry about doing this. diff --git a/third_party/com_github_glog_glog_f2cf2e1bd040fd15016af53598db0cb9b16a6655.diff b/third_party/com_github_glog_glog_f2cf2e1bd040fd15016af53598db0cb9b16a6655.diff deleted file mode 100644 index 560e83ecc..000000000 --- a/third_party/com_github_glog_glog_f2cf2e1bd040fd15016af53598db0cb9b16a6655.diff +++ /dev/null @@ -1,45 +0,0 @@ -https://github.com/google/glog/pull/342 - -diff --git a/CONTRIBUTORS b/CONTRIBUTORS -index d63f62d1..aa0dd4a8 100644 ---- a/CONTRIBUTORS -+++ b/CONTRIBUTORS -@@ -26,6 +26,7 @@ Abhishek Dasgupta - Abhishek Parmar - Andrew Schwartzmeyer - Andy Ying -+Bret McKee - Brian Silverman - Fumitoshi Ukai - Guillaume Dumont -diff --git a/src/glog/logging.h.in b/src/glog/logging.h.in -index 9968b96d..f6dccb29 100644 ---- a/src/glog/logging.h.in -+++ b/src/glog/logging.h.in -@@ -649,6 +649,10 @@ void MakeCheckOpValueString(std::ostream* os, const signed char& v); - template <> GOOGLE_GLOG_DLL_DECL - void MakeCheckOpValueString(std::ostream* os, const unsigned char& v); - -+// Provide printable value for nullptr_t -+template <> GOOGLE_GLOG_DLL_DECL -+void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& v); -+ - // Build the error message string. Specify no inlining for code size. 
- template - std::string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) -diff --git a/src/logging.cc b/src/logging.cc -index 0c86cf62..256655e5 100644 ---- a/src/logging.cc -+++ b/src/logging.cc -@@ -2163,6 +2163,11 @@ void MakeCheckOpValueString(std::ostream* os, const unsigned char& v) { - } - } - -+template <> -+void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& v) { -+ (*os) << "nullptr"; -+} -+ - void InitGoogleLogging(const char* argv0) { - glog_internal_namespace_::InitGoogleLoggingUtilities(argv0); - } diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 4b51d9de0..9f827c542 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -646,6 +646,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"], ) + http_file( + name = "com_google_mediapipe_leopard_bg_removal_result_512x512_png", + sha256 = "30be22e89fdd1d7b985294498ec67509b0caa1ca941fe291fa25f43a3873e4dd", + urls = ["https://storage.googleapis.com/mediapipe-assets/leopard_bg_removal_result_512x512.png?generation=1690239134617707"], + ) + http_file( name = "com_google_mediapipe_leopard_bg_removal_result_png", sha256 = "afd33f2058fd58d189cda86ec931647741a6139970c9bcbc637cdd151ec657c5", @@ -712,6 +718,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_ica_8bit-with-unsupported-metadata-version.tflite?generation=1661875819091013"], ) + http_file( + name = "com_google_mediapipe_mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt_tflite", + sha256 = "3c4c7e36b35fc903ecfb51b351b4849b23c57cc18d1416cf6cabaa1522d84760", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenetsweep_dptrigmqn384_unit_384_384_fp16quant_fp32input_opt.tflite?generation=1690302146106240"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v1_0_25_192_quantized_1_default_1_tflite", sha256 = "f80999b6324c6f101300c3ee38fbe7e11e74a743b5e0be7350602087fe7430a3", diff --git a/third_party/halide.BUILD b/third_party/halide.BUILD index 677fa9f38..5521f6bb9 100644 --- a/third_party/halide.BUILD +++ b/third_party/halide.BUILD @@ -42,7 +42,7 @@ cc_library( cc_library( name = "lib_halide_static", srcs = select({ - "@halide//:halide_config_windows_x86_64": [ + "@mediapipe//mediapipe:windows": [ "bin/Release/Halide.dll", "lib/Release/Halide.lib", ], diff --git a/third_party/halide/BUILD.bazel b/third_party/halide/BUILD.bazel index 8b69a2503..52fbf0a10 100644 --- a/third_party/halide/BUILD.bazel +++ b/third_party/halide/BUILD.bazel @@ -28,13 +28,13 @@ halide_library_runtimes() name = target_name, actual = select( { - ":halide_config_linux_x86_64": "@linux_halide//:%s" % target_name, - ":halide_config_macos_x86_64": "@macos_x86_64_halide//:%s" % target_name, - ":halide_config_macos_arm64": "@macos_arm_64_halide//:%s" % target_name, - ":halide_config_windows_x86_64": "@windows_halide//:%s" % target_name, - # deliberately no //condition:default clause here + "@mediapipe//mediapipe:macos_x86_64": "@macos_x86_64_halide//:%s" % target_name, + "@mediapipe//mediapipe:macos_arm64": "@macos_arm_64_halide//:%s" % target_name, + "@mediapipe//mediapipe:windows": "@windows_halide//:%s" % target_name, + # Assume Linux x86_64 by default. + # TODO: add mediapipe configs for linux to avoid assuming it's the default. 
+ "//conditions:default": "@linux_halide//:%s" % target_name, }, - no_match_error = "Compiling Halide code requires that the build host is one of Linux x86-64, Windows x86-64, macOS x86-64, or macOS arm64.", ), ) for target_name in [ diff --git a/third_party/halide/halide.bzl b/third_party/halide/halide.bzl index bbb0a1f97..147986255 100644 --- a/third_party/halide/halide.bzl +++ b/third_party/halide/halide.bzl @@ -82,22 +82,22 @@ def halide_runtime_linkopts(): # Map of halide-target-base -> config_settings _HALIDE_TARGET_CONFIG_SETTINGS_MAP = { # Android - "arm-32-android": ["@halide//:halide_config_android_arm"], - "arm-64-android": ["@halide//:halide_config_android_arm64"], - "x86-32-android": ["@halide//:halide_config_android_x86_32"], - "x86-64-android": ["@halide//:halide_config_android_x86_64"], + "arm-32-android": ["@mediapipe//mediapipe:android_arm"], + "arm-64-android": ["@mediapipe//mediapipe:android_arm64"], + "x86-32-android": ["@mediapipe//mediapipe:android_x86"], + "x86-64-android": ["@mediapipe//mediapipe:android_x86_64"], # iOS - "arm-32-ios": ["@halide//:halide_config_ios_arm"], - "arm-64-ios": ["@halide//:halide_config_ios_arm64"], + "arm-32-ios": ["@mediapipe//mediapipe:ios_armv7"], + "arm-64-ios": ["@mediapipe//mediapipe:ios_arm64", "@mediapipe//mediapipe:ios_arm64e"], # OSX (or iOS simulator) - "x86-32-osx": ["@halide//:halide_config_macos_x86_32", "@halide//:halide_config_ios_x86_32"], - "x86-64-osx": ["@halide//:halide_config_macos_x86_64", "@halide//:halide_config_ios_x86_64"], - "arm-64-osx": ["@halide//:halide_config_macos_arm64"], + "x86-32-osx": ["@mediapipe//mediapipe:ios_i386"], + "x86-64-osx": ["@mediapipe//mediapipe:macos_x86_64", "@mediapipe//mediapipe:ios_x86_64"], + "arm-64-osx": ["@mediapipe//mediapipe:macos_arm64"], # Windows - "x86-64-windows": ["@halide//:halide_config_windows_x86_64"], + "x86-64-windows": ["@mediapipe//mediapipe:windows"], # Linux - "x86-64-linux": ["@halide//:halide_config_linux_x86_64"], - # deliberately nothing here using //conditions:default + # TODO: add mediapipe configs for linux to avoid assuming it's the default. + "x86-64-linux": ["//conditions:default"], } _HALIDE_TARGET_MAP_DEFAULT = { @@ -618,19 +618,6 @@ def _standard_library_runtime_names(): return collections.uniq([_halide_library_runtime_target_name(f) for f in _standard_library_runtime_features()]) def halide_library_runtimes(compatible_with = []): - # Note that we don't use all of these combinations - # (and some are invalid), but that's ok. 
- for cpu in ["arm", "arm64", "x86_32", "x86_64"]: - for os in ["android", "linux", "windows", "ios", "macos"]: - native.config_setting( - name = "halide_config_%s_%s" % (os, cpu), - constraint_values = [ - "@platforms//os:%s" % os, - "@platforms//cpu:%s" % cpu, - ], - visibility = ["//visibility:public"], - ) - unused = [ _define_halide_library_runtime(f, compatible_with = compatible_with) for f in _standard_library_runtime_features() diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 8ef0a71a2..1aae204d7 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,72 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "0d66a26fa5ca638c54ec3e5bffb50aec74ee0880b108d4b5f7d316e9ae36cc9a", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1685638894464709"], + sha256 = "9e5f88363212ac1ad505a0b9e59e3dd34413064f3b70219ff8b0216d6a53128f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1690577772170421"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "014963d19ef6b1f25720379c3df07a6e08b24894ada4938d45b1256e97739318", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1685638897160853"], + sha256 = "8e4c7e9efcfe0d1107b40626f14070f17a817d2b830205ae642ea645fa882d28", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1690577774642876"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", - sha256 = "f03d4826c251783bfc1fb8b82b2d08c00b2e3cb2efcc606305eb210f09fc686b", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1685638899477366"], + sha256 = "9b9d1fbbead06a26461bb664189d46f0c327a1077e67f0aeeb0628d04de13a81", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1690577777075565"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", - sha256 = "36972cf62138bcb5fde37a1fecce334a86b0261eefc1f1daa17b4b8acdc784b4", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1685638901926088"], + sha256 = "44734a8fdb979eb9359de0c0282565d74cdced5d3a6687be849875e0eb11503c", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1690577779811164"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "5745360da942f3bcb585547e8720cb11f19793e68851b119b8f9ea22b120fd06", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1685638904214551"], + sha256 = "93275ebbae8dd2e9be0394391b722a0de5ac9ed51066093b1ac6ec24bebf5813", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1690577782193422"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "b6d8b03fa7fc3e969febfcb63e3db2de900f1f54b82bf2205f02d865fc4790b2", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1685638906864568"], + sha256 = "35e734890cae0c51c1ad91e3589d5777b013bcbac64a5bcbb3a67ce4a5815dd6", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1690577784996034"], ) http_file( name = 
"com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", - sha256 = "837ca361044441e6202858b4a9d94b3296c8440099b40e6dafb1efcce76a8f63", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1685638909139832"], + sha256 = "4e6cea3ae95ffac595bfc08f0dab4ff452c91434eb71f92c0dd34250a46825a1", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1690577787398460"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", - sha256 = "507f4089f4a2cf8fe7fb61f48e180f3f86d5e8057fc60ef24c77aae724eb66ba", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1685638911843312"], + sha256 = "43cfab25c1d47822015e434d726a80d84e0bfdb5e685a511ab45d8b5cbe944d3", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1690577790301890"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "82de7a40fdb14833b5ceaeb1ebf219421dbb06ba5e525204737dec196161420d", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1685638914190745"], + sha256 = "6a73602a14484297690e69d716e683341b62a5fde8f5debde78de2651cb69bbe", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1690577792657082"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "d06ac49f4c156cf0c24ef62387b13e48b67476e7f04a423889c59ee835c460f2", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1685638917012370"], + sha256 = "3431f70071f3980bf13e638551e9bb333335223e35542ee768db06501f7a26f2", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1690577795814175"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", - sha256 = "fff428ef91d8cc936f9c3ec81750f5e7ee3c20bc0c76677eb5d8d4d010d2fac0", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1685638919406810"], + sha256 = "ece9ac1f41b93340b08682514ca291431ff7084c858caf6455e65b0c6c3eb717", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1690577798226032"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", - sha256 = "f87c51b8744b0ba564ce725fc3659dba5ef90b4615ac34135ca91c6508434fe9", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1685638922016130"], + sha256 = "4d54739714db6b3d0fbdd0608c2824c4ccceaaf279aa4ba160f2eab2663b30f2", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1690577801077668"], )