Add CombinedPredictionCalculator.
PiperOrigin-RevId: 484301880
This commit is contained in:
		
							parent
							
								
									ee84e447b2
								
							
						
					
					
						commit
						fc1d75cc99
					
				| 
						 | 
					@ -93,3 +93,46 @@ cc_test(
 | 
				
			||||||
        "@com_google_absl//absl/strings",
 | 
					        "@com_google_absl//absl/strings",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mediapipe_proto_library(
 | 
				
			||||||
 | 
					    name = "combined_prediction_calculator_proto",
 | 
				
			||||||
 | 
					    srcs = ["combined_prediction_calculator.proto"],
 | 
				
			||||||
 | 
					    visibility = ["//visibility:public"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_options_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_proto",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "combined_prediction_calculator",
 | 
				
			||||||
 | 
					    srcs = ["combined_prediction_calculator.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":combined_prediction_calculator_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_framework",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:collection",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:node",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:packet",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/api2:port",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:classification_cc_proto",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/container:btree",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/memory",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/status",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/strings:str_format",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					    alwayslink = 1,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_test(
 | 
				
			||||||
 | 
					    name = "combined_prediction_calculator_test",
 | 
				
			||||||
 | 
					    srcs = ["combined_prediction_calculator_test.cc"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":combined_prediction_calculator",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_framework",
 | 
				
			||||||
 | 
					        "//mediapipe/framework:calculator_runner",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:classification_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/port:gtest_main",
 | 
				
			||||||
 | 
					        "@com_google_absl//absl/strings",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,187 @@
 | 
				
			||||||
 | 
					/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/container/btree_map.h"
 | 
				
			||||||
 | 
					#include "absl/memory/memory.h"
 | 
				
			||||||
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
 | 
					#include "absl/strings/str_format.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/node.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/packet.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/port.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/collection.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/classification.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					namespace api2 {
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					constexpr char kPredictionTag[] = "PREDICTION";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Classification GetMaxScoringClassification(
 | 
				
			||||||
 | 
					    const ClassificationList& classifications) {
 | 
				
			||||||
 | 
					  Classification max_classification;
 | 
				
			||||||
 | 
					  max_classification.set_score(0);
 | 
				
			||||||
 | 
					  for (const auto& input : classifications.classification()) {
 | 
				
			||||||
 | 
					    if (max_classification.score() < input.score()) {
 | 
				
			||||||
 | 
					      max_classification = input;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return max_classification;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					float GetScoreThreshold(
 | 
				
			||||||
 | 
					    const std::string& input_label,
 | 
				
			||||||
 | 
					    const absl::btree_map<std::string, float>& classwise_thresholds,
 | 
				
			||||||
 | 
					    const std::string& background_label, const float default_threshold) {
 | 
				
			||||||
 | 
					  float threshold = default_threshold;
 | 
				
			||||||
 | 
					  auto it = classwise_thresholds.find(input_label);
 | 
				
			||||||
 | 
					  if (it != classwise_thresholds.end()) {
 | 
				
			||||||
 | 
					    threshold = it->second;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return threshold;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unique_ptr<ClassificationList> GetWinningPrediction(
 | 
				
			||||||
 | 
					    const ClassificationList& classification_list,
 | 
				
			||||||
 | 
					    const absl::btree_map<std::string, float>& classwise_thresholds,
 | 
				
			||||||
 | 
					    const std::string& background_label, const float default_threshold) {
 | 
				
			||||||
 | 
					  auto prediction_list = std::make_unique<ClassificationList>();
 | 
				
			||||||
 | 
					  if (classification_list.classification().empty()) {
 | 
				
			||||||
 | 
					    return prediction_list;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  Classification& prediction = *prediction_list->add_classification();
 | 
				
			||||||
 | 
					  auto argmax_prediction = GetMaxScoringClassification(classification_list);
 | 
				
			||||||
 | 
					  float argmax_prediction_thresh =
 | 
				
			||||||
 | 
					      GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
 | 
				
			||||||
 | 
					                        background_label, default_threshold);
 | 
				
			||||||
 | 
					  if (argmax_prediction.score() >= argmax_prediction_thresh) {
 | 
				
			||||||
 | 
					    prediction.set_label(argmax_prediction.label());
 | 
				
			||||||
 | 
					    prediction.set_score(argmax_prediction.score());
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    for (const auto& input : classification_list.classification()) {
 | 
				
			||||||
 | 
					      if (input.label() == background_label) {
 | 
				
			||||||
 | 
					        prediction.set_label(input.label());
 | 
				
			||||||
 | 
					        prediction.set_score(input.score());
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return prediction_list;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// This calculator accepts multiple ClassificationList input streams. Each
 | 
				
			||||||
 | 
					// ClassificationList should contain classifications with labels and
 | 
				
			||||||
 | 
					// corresponding softmax scores. The calculator computes the best prediction for
 | 
				
			||||||
 | 
					// each ClassificationList input stream via argmax and thresholding. Thresholds
 | 
				
			||||||
 | 
					// for all classes can be specified in the
 | 
				
			||||||
 | 
					// `CombinedPredictionCalculatorOptions`, along with a default global
 | 
				
			||||||
 | 
					// threshold.
 | 
				
			||||||
 | 
					// Please note that for this calculator to work as designed, the class names
 | 
				
			||||||
 | 
					// other than the background class in the ClassificationList objects must be
 | 
				
			||||||
 | 
					// different, but the background class name has to be the same. This background
 | 
				
			||||||
 | 
					// label name can be set via `background_label` in
 | 
				
			||||||
 | 
					// `CombinedPredictionCalculatorOptions`.
 | 
				
			||||||
 | 
					// The ClassificationList in the PREDICTION output stream contains the label of
 | 
				
			||||||
 | 
					// the winning class and corresponding softmax score. If none of the
 | 
				
			||||||
 | 
					// ClassificationList objects has a non-background winning class, the output
 | 
				
			||||||
 | 
					// contains the background class and score of the background class in the first
 | 
				
			||||||
 | 
					// ClassificationList. If multiple ClassificationList objects have a
 | 
				
			||||||
 | 
					// non-background winning class, the output contains the winning prediction from
 | 
				
			||||||
 | 
					// the ClassificationList with the highest priority. Priority is in decreasing
 | 
				
			||||||
 | 
					// order of input streams to the graph node using this calculator.
 | 
				
			||||||
 | 
					// Input:
 | 
				
			||||||
 | 
					//   At least one stream with ClassificationList.
 | 
				
			||||||
 | 
					// Output:
 | 
				
			||||||
 | 
					//  PREDICTION - A ClassificationList with the winning label as the only item.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Usage example:
 | 
				
			||||||
 | 
					// node {
 | 
				
			||||||
 | 
					//   calculator: "CombinedPredictionCalculator"
 | 
				
			||||||
 | 
					//   input_stream: "classification_list_0"
 | 
				
			||||||
 | 
					//   input_stream: "classification_list_1"
 | 
				
			||||||
 | 
					//   output_stream: "PREDICTION:prediction"
 | 
				
			||||||
 | 
					//   options {
 | 
				
			||||||
 | 
					//     [mediapipe.CombinedPredictionCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					//       class {
 | 
				
			||||||
 | 
					//         label: "A"
 | 
				
			||||||
 | 
					//         score_threshold: 0.7
 | 
				
			||||||
 | 
					//       }
 | 
				
			||||||
 | 
					//       default_global_threshold: 0.1
 | 
				
			||||||
 | 
					//       background_label: "B"
 | 
				
			||||||
 | 
					//     }
 | 
				
			||||||
 | 
					//   }
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CombinedPredictionCalculator : public Node {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
 | 
				
			||||||
 | 
					      ""};
 | 
				
			||||||
 | 
					  static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
 | 
				
			||||||
 | 
					  MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  absl::Status Open(CalculatorContext* cc) override {
 | 
				
			||||||
 | 
					    options_ = cc->Options<CombinedPredictionCalculatorOptions>();
 | 
				
			||||||
 | 
					    for (const auto& input : options_.class_()) {
 | 
				
			||||||
 | 
					      classwise_thresholds_[input.label()] = input.score_threshold();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    classwise_thresholds_[options_.background_label()] = 0;
 | 
				
			||||||
 | 
					    return absl::OkStatus();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  absl::Status Process(CalculatorContext* cc) override {
 | 
				
			||||||
 | 
					    // After loop, if have winning prediction return. Otherwise empty packet.
 | 
				
			||||||
 | 
					    std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
 | 
				
			||||||
 | 
					    auto collection = kClassificationListIn(cc);
 | 
				
			||||||
 | 
					    for (int idx = 0; idx < collection.Count(); ++idx) {
 | 
				
			||||||
 | 
					      const auto& packet = collection[idx];
 | 
				
			||||||
 | 
					      if (packet.IsEmpty()) {
 | 
				
			||||||
 | 
					        continue;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      auto prediction = GetWinningPrediction(
 | 
				
			||||||
 | 
					          packet.Get(), classwise_thresholds_, options_.background_label(),
 | 
				
			||||||
 | 
					          options_.default_global_threshold());
 | 
				
			||||||
 | 
					      if (prediction->classification(0).label() !=
 | 
				
			||||||
 | 
					          options_.background_label()) {
 | 
				
			||||||
 | 
					        kPredictionOut(cc).Send(std::move(prediction));
 | 
				
			||||||
 | 
					        return absl::OkStatus();
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if (first_winning_prediction == nullptr) {
 | 
				
			||||||
 | 
					        first_winning_prediction = std::move(prediction);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (first_winning_prediction != nullptr) {
 | 
				
			||||||
 | 
					      kPredictionOut(cc).Send(std::move(first_winning_prediction));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    return absl::OkStatus();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  CombinedPredictionCalculatorOptions options_;
 | 
				
			||||||
 | 
					  absl::btree_map<std::string, float> classwise_thresholds_;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace api2
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,41 @@
 | 
				
			||||||
 | 
					/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					syntax = "proto2";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package mediapipe;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "mediapipe/framework/calculator.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message CombinedPredictionCalculatorOptions {
 | 
				
			||||||
 | 
					  extend mediapipe.CalculatorOptions {
 | 
				
			||||||
 | 
					    optional CombinedPredictionCalculatorOptions ext = 483738635;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  message Class {
 | 
				
			||||||
 | 
					    optional string label = 1;
 | 
				
			||||||
 | 
					    optional float score_threshold = 2;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // List of classes with score thresholds.
 | 
				
			||||||
 | 
					  repeated Class class = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Default score threshold applied to a label.
 | 
				
			||||||
 | 
					  optional float default_global_threshold = 2 [default = 0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Name of the background class whose input scores will be ignored while
 | 
				
			||||||
 | 
					  // thresholding.
 | 
				
			||||||
 | 
					  optional string background_label = 3;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,315 @@
 | 
				
			||||||
 | 
					/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cmath>
 | 
				
			||||||
 | 
					#include <cstdint>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/strings/string_view.h"
 | 
				
			||||||
 | 
					#include "absl/strings/substitute.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_runner.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/classification.pb.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/status_matchers.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					constexpr char kPredictionTag[] = "PREDICTION";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
 | 
				
			||||||
 | 
					    float drama_thresh, float llama_thresh, float bazinga_thresh,
 | 
				
			||||||
 | 
					    float joy_thresh, float peace_thresh) {
 | 
				
			||||||
 | 
					  constexpr absl::string_view kCalculatorProto = R"pb(
 | 
				
			||||||
 | 
					    calculator: "CombinedPredictionCalculator"
 | 
				
			||||||
 | 
					    input_stream: "custom_softmax_scores"
 | 
				
			||||||
 | 
					    input_stream: "canned_softmax_scores"
 | 
				
			||||||
 | 
					    output_stream: "PREDICTION:prediction"
 | 
				
			||||||
 | 
					    options {
 | 
				
			||||||
 | 
					      [mediapipe.CombinedPredictionCalculatorOptions.ext] {
 | 
				
			||||||
 | 
					        class { label: "CustomDrama" score_threshold: $0 }
 | 
				
			||||||
 | 
					        class { label: "CustomLlama" score_threshold: $1 }
 | 
				
			||||||
 | 
					        class { label: "CannedBazinga" score_threshold: $2 }
 | 
				
			||||||
 | 
					        class { label: "CannedJoy" score_threshold: $3 }
 | 
				
			||||||
 | 
					        class { label: "CannedPeace" score_threshold: $4 }
 | 
				
			||||||
 | 
					        background_label: "Negative"
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  )pb";
 | 
				
			||||||
 | 
					  auto runner = std::make_unique<CalculatorRunner>(
 | 
				
			||||||
 | 
					      absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
 | 
				
			||||||
 | 
					                       bazinga_thresh, joy_thresh, peace_thresh));
 | 
				
			||||||
 | 
					  return runner;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unique_ptr<ClassificationList> BuildCustomScoreInput(
 | 
				
			||||||
 | 
					    const float negative_score, const float drama_score,
 | 
				
			||||||
 | 
					    const float llama_score) {
 | 
				
			||||||
 | 
					  auto custom_scores = std::make_unique<ClassificationList>();
 | 
				
			||||||
 | 
					  auto custom_negative = custom_scores->add_classification();
 | 
				
			||||||
 | 
					  custom_negative->set_label("Negative");
 | 
				
			||||||
 | 
					  custom_negative->set_score(negative_score);
 | 
				
			||||||
 | 
					  auto drama = custom_scores->add_classification();
 | 
				
			||||||
 | 
					  drama->set_label("CustomDrama");
 | 
				
			||||||
 | 
					  drama->set_score(drama_score);
 | 
				
			||||||
 | 
					  auto llama = custom_scores->add_classification();
 | 
				
			||||||
 | 
					  llama->set_label("CustomLlama");
 | 
				
			||||||
 | 
					  llama->set_score(llama_score);
 | 
				
			||||||
 | 
					  return custom_scores;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unique_ptr<ClassificationList> BuildCannedScoreInput(
 | 
				
			||||||
 | 
					    const float negative_score, const float bazinga_score,
 | 
				
			||||||
 | 
					    const float joy_score, const float peace_score) {
 | 
				
			||||||
 | 
					  auto canned_scores = std::make_unique<ClassificationList>();
 | 
				
			||||||
 | 
					  auto canned_negative = canned_scores->add_classification();
 | 
				
			||||||
 | 
					  canned_negative->set_label("Negative");
 | 
				
			||||||
 | 
					  canned_negative->set_score(negative_score);
 | 
				
			||||||
 | 
					  auto bazinga = canned_scores->add_classification();
 | 
				
			||||||
 | 
					  bazinga->set_label("CannedBazinga");
 | 
				
			||||||
 | 
					  bazinga->set_score(bazinga_score);
 | 
				
			||||||
 | 
					  auto joy = canned_scores->add_classification();
 | 
				
			||||||
 | 
					  joy->set_label("CannedJoy");
 | 
				
			||||||
 | 
					  joy->set_score(joy_score);
 | 
				
			||||||
 | 
					  auto peace = canned_scores->add_classification();
 | 
				
			||||||
 | 
					  peace->set_label("CannedPeace");
 | 
				
			||||||
 | 
					  peace->set_score(peace_score);
 | 
				
			||||||
 | 
					  return canned_scores;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(CombinedPredictionCalculatorPacketTest,
 | 
				
			||||||
 | 
					     CustomEmpty_CannedEmpty_ResultIsEmpty) {
 | 
				
			||||||
 | 
					  auto runner = BuildNodeRunnerWithOptions(
 | 
				
			||||||
 | 
					      /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
 | 
				
			||||||
 | 
					      /*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
 | 
				
			||||||
 | 
					  EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(CombinedPredictionCalculatorPacketTest,
 | 
				
			||||||
 | 
					     CustomEmpty_CannedNotEmpty_ResultIsCanned) {
 | 
				
			||||||
 | 
					  auto runner = BuildNodeRunnerWithOptions(
 | 
				
			||||||
 | 
					      /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
 | 
				
			||||||
 | 
					      /*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
 | 
				
			||||||
 | 
					  auto canned_scores = BuildCannedScoreInput(
 | 
				
			||||||
 | 
					      /*negative_score=*/0.1,
 | 
				
			||||||
 | 
					      /*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
 | 
				
			||||||
 | 
					  runner->MutableInputs()->Index(1).packets.push_back(
 | 
				
			||||||
 | 
					      Adopt(canned_scores.release()).At(Timestamp(1)));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto output_prediction_packets =
 | 
				
			||||||
 | 
					      runner->Outputs().Tag(kPredictionTag).packets;
 | 
				
			||||||
 | 
					  ASSERT_EQ(output_prediction_packets.size(), 1);
 | 
				
			||||||
 | 
					  Classification output_prediction =
 | 
				
			||||||
 | 
					      output_prediction_packets[0].Get<ClassificationList>().classification(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(output_prediction.label(), "CannedJoy");
 | 
				
			||||||
 | 
					  EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(CombinedPredictionCalculatorPacketTest,
 | 
				
			||||||
 | 
					     CustomNotEmpty_CannedEmpty_ResultIsCustom) {
 | 
				
			||||||
 | 
					  auto runner = BuildNodeRunnerWithOptions(
 | 
				
			||||||
 | 
					      /*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
 | 
				
			||||||
 | 
					      /*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
 | 
				
			||||||
 | 
					  auto custom_scores =
 | 
				
			||||||
 | 
					      BuildCustomScoreInput(/*negative_score=*/0.1,
 | 
				
			||||||
 | 
					                            /*drama_score=*/0.2, /*llama_score=*/0.7);
 | 
				
			||||||
 | 
					  runner->MutableInputs()->Index(0).packets.push_back(
 | 
				
			||||||
 | 
					      Adopt(custom_scores.release()).At(Timestamp(1)));
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto output_prediction_packets =
 | 
				
			||||||
 | 
					      runner->Outputs().Tag(kPredictionTag).packets;
 | 
				
			||||||
 | 
					  ASSERT_EQ(output_prediction_packets.size(), 1);
 | 
				
			||||||
 | 
					  Classification output_prediction =
 | 
				
			||||||
 | 
					      output_prediction_packets[0].Get<ClassificationList>().classification(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(output_prediction.label(), "CustomLlama");
 | 
				
			||||||
 | 
					  EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct CombinedPredictionCalculatorTestCase {
 | 
				
			||||||
 | 
					  std::string test_name;
 | 
				
			||||||
 | 
					  float custom_negative_score;
 | 
				
			||||||
 | 
					  float drama_score;
 | 
				
			||||||
 | 
					  float llama_score;
 | 
				
			||||||
 | 
					  float drama_thresh;
 | 
				
			||||||
 | 
					  float llama_thresh;
 | 
				
			||||||
 | 
					  float canned_negative_score;
 | 
				
			||||||
 | 
					  float bazinga_score;
 | 
				
			||||||
 | 
					  float joy_score;
 | 
				
			||||||
 | 
					  float peace_score;
 | 
				
			||||||
 | 
					  float bazinga_thresh;
 | 
				
			||||||
 | 
					  float joy_thresh;
 | 
				
			||||||
 | 
					  float peace_thresh;
 | 
				
			||||||
 | 
					  std::string max_scoring_label;
 | 
				
			||||||
 | 
					  float max_score;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using CombinedPredictionCalculatorTest =
 | 
				
			||||||
 | 
					    testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
 | 
				
			||||||
 | 
					  const CombinedPredictionCalculatorTestCase& test_case = GetParam();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto runner = BuildNodeRunnerWithOptions(
 | 
				
			||||||
 | 
					      test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
 | 
				
			||||||
 | 
					      test_case.joy_thresh, test_case.peace_thresh);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto custom_scores =
 | 
				
			||||||
 | 
					      BuildCustomScoreInput(test_case.custom_negative_score,
 | 
				
			||||||
 | 
					                            test_case.drama_score, test_case.llama_score);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  runner->MutableInputs()->Index(0).packets.push_back(
 | 
				
			||||||
 | 
					      Adopt(custom_scores.release()).At(Timestamp(1)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto canned_scores = BuildCannedScoreInput(
 | 
				
			||||||
 | 
					      test_case.canned_negative_score, test_case.bazinga_score,
 | 
				
			||||||
 | 
					      test_case.joy_score, test_case.peace_score);
 | 
				
			||||||
 | 
					  runner->MutableInputs()->Index(1).packets.push_back(
 | 
				
			||||||
 | 
					      Adopt(canned_scores.release()).At(Timestamp(1)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto output_prediction_packets =
 | 
				
			||||||
 | 
					      runner->Outputs().Tag(kPredictionTag).packets;
 | 
				
			||||||
 | 
					  ASSERT_EQ(output_prediction_packets.size(), 1);
 | 
				
			||||||
 | 
					  Classification output_prediction =
 | 
				
			||||||
 | 
					      output_prediction_packets[0].Get<ClassificationList>().classification(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
 | 
				
			||||||
 | 
					  EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					INSTANTIATE_TEST_CASE_P(
 | 
				
			||||||
 | 
					    CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
 | 
				
			||||||
 | 
					    testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            .test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
 | 
				
			||||||
 | 
					            .custom_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .drama_score = 0.5,
 | 
				
			||||||
 | 
					            .llama_score = 0.3,
 | 
				
			||||||
 | 
					            .drama_thresh = 0.25,
 | 
				
			||||||
 | 
					            .llama_thresh = 0.7,
 | 
				
			||||||
 | 
					            .canned_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .bazinga_score = 0.3,
 | 
				
			||||||
 | 
					            .joy_score = 0.3,
 | 
				
			||||||
 | 
					            .peace_score = 0.3,
 | 
				
			||||||
 | 
					            .bazinga_thresh = 0.7,
 | 
				
			||||||
 | 
					            .joy_thresh = 0.7,
 | 
				
			||||||
 | 
					            .peace_thresh = 0.7,
 | 
				
			||||||
 | 
					            .max_scoring_label = "CustomDrama",
 | 
				
			||||||
 | 
					            .max_score = 0.5,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            .test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
 | 
				
			||||||
 | 
					            .custom_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .drama_score = 0.3,
 | 
				
			||||||
 | 
					            .llama_score = 0.6,
 | 
				
			||||||
 | 
					            .drama_thresh = 0.4,
 | 
				
			||||||
 | 
					            .llama_thresh = 0.8,
 | 
				
			||||||
 | 
					            .canned_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .bazinga_score = 0.4,
 | 
				
			||||||
 | 
					            .joy_score = 0.3,
 | 
				
			||||||
 | 
					            .peace_score = 0.2,
 | 
				
			||||||
 | 
					            .bazinga_thresh = 0.0,
 | 
				
			||||||
 | 
					            .joy_thresh = 0.0,
 | 
				
			||||||
 | 
					            .peace_thresh = 0.0,
 | 
				
			||||||
 | 
					            .max_scoring_label = "CannedBazinga",
 | 
				
			||||||
 | 
					            .max_score = 0.4,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            .test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
 | 
				
			||||||
 | 
					            .custom_negative_score = 0.5,
 | 
				
			||||||
 | 
					            .drama_score = 0.1,
 | 
				
			||||||
 | 
					            .llama_score = 0.4,
 | 
				
			||||||
 | 
					            .drama_thresh = 0.1,
 | 
				
			||||||
 | 
					            .llama_thresh = 0.05,
 | 
				
			||||||
 | 
					            .canned_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .bazinga_score = 0.3,
 | 
				
			||||||
 | 
					            .joy_score = 0.3,
 | 
				
			||||||
 | 
					            .peace_score = 0.3,
 | 
				
			||||||
 | 
					            .bazinga_thresh = 0.7,
 | 
				
			||||||
 | 
					            .joy_thresh = 0.7,
 | 
				
			||||||
 | 
					            .peace_thresh = 0.7,
 | 
				
			||||||
 | 
					            .max_scoring_label = "Negative",
 | 
				
			||||||
 | 
					            .max_score = 0.5,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
 | 
				
			||||||
 | 
					            .custom_negative_score = 0.8,
 | 
				
			||||||
 | 
					            .drama_score = 0.1,
 | 
				
			||||||
 | 
					            .llama_score = 0.1,
 | 
				
			||||||
 | 
					            .drama_thresh = 0.25,
 | 
				
			||||||
 | 
					            .llama_thresh = 0.7,
 | 
				
			||||||
 | 
					            .canned_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .bazinga_score = 0.3,
 | 
				
			||||||
 | 
					            .joy_score = 0.3,
 | 
				
			||||||
 | 
					            .peace_score = 0.3,
 | 
				
			||||||
 | 
					            .bazinga_thresh = 0.7,
 | 
				
			||||||
 | 
					            .joy_thresh = 0.7,
 | 
				
			||||||
 | 
					            .peace_thresh = 0.7,
 | 
				
			||||||
 | 
					            .max_scoring_label = "Negative",
 | 
				
			||||||
 | 
					            .max_score = 0.8,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            .test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
 | 
				
			||||||
 | 
					            .custom_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .drama_score = 0.2,
 | 
				
			||||||
 | 
					            .llama_score = 0.7,
 | 
				
			||||||
 | 
					            .drama_thresh = 1.1,
 | 
				
			||||||
 | 
					            .llama_thresh = 1.1,
 | 
				
			||||||
 | 
					            .canned_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .bazinga_score = 0.3,
 | 
				
			||||||
 | 
					            .joy_score = 0.3,
 | 
				
			||||||
 | 
					            .peace_score = 0.3,
 | 
				
			||||||
 | 
					            .bazinga_thresh = 0.7,
 | 
				
			||||||
 | 
					            .joy_thresh = 0.7,
 | 
				
			||||||
 | 
					            .peace_thresh = 0.7,
 | 
				
			||||||
 | 
					            .max_scoring_label = "Negative",
 | 
				
			||||||
 | 
					            .max_score = 0.1,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
 | 
				
			||||||
 | 
					            .custom_negative_score = 0.1,
 | 
				
			||||||
 | 
					            .drama_score = 0.3,
 | 
				
			||||||
 | 
					            .llama_score = 0.6,
 | 
				
			||||||
 | 
					            .drama_thresh = 0.4,
 | 
				
			||||||
 | 
					            .llama_thresh = 0.8,
 | 
				
			||||||
 | 
					            .canned_negative_score = 0.3,
 | 
				
			||||||
 | 
					            .bazinga_score = 0.2,
 | 
				
			||||||
 | 
					            .joy_score = 0.3,
 | 
				
			||||||
 | 
					            .peace_score = 0.2,
 | 
				
			||||||
 | 
					            .bazinga_thresh = 0.5,
 | 
				
			||||||
 | 
					            .joy_thresh = 0.5,
 | 
				
			||||||
 | 
					            .peace_thresh = 0.5,
 | 
				
			||||||
 | 
					            .max_scoring_label = "Negative",
 | 
				
			||||||
 | 
					            .max_score = 0.1,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					    }),
 | 
				
			||||||
 | 
					    [](const testing::TestParamInfo<
 | 
				
			||||||
 | 
					        CombinedPredictionCalculatorTest::ParamType>& info) {
 | 
				
			||||||
 | 
					      return info.param.test_name;
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user