Add CombinedPredictionCalculator.
PiperOrigin-RevId: 484301880
This commit is contained in:
parent
ee84e447b2
commit
fc1d75cc99
|
@ -93,3 +93,46 @@ cc_test(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "combined_prediction_calculator_proto",
|
||||||
|
srcs = ["combined_prediction_calculator.proto"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "combined_prediction_calculator",
|
||||||
|
srcs = ["combined_prediction_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
":combined_prediction_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:collection",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"@com_google_absl//absl/container:btree",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "combined_prediction_calculator_test",
|
||||||
|
srcs = ["combined_prediction_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":combined_prediction_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,187 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/btree_map.h"
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/collection.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kPredictionTag[] = "PREDICTION";
|
||||||
|
|
||||||
|
Classification GetMaxScoringClassification(
|
||||||
|
const ClassificationList& classifications) {
|
||||||
|
Classification max_classification;
|
||||||
|
max_classification.set_score(0);
|
||||||
|
for (const auto& input : classifications.classification()) {
|
||||||
|
if (max_classification.score() < input.score()) {
|
||||||
|
max_classification = input;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max_classification;
|
||||||
|
}
|
||||||
|
|
||||||
|
float GetScoreThreshold(
|
||||||
|
const std::string& input_label,
|
||||||
|
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||||
|
const std::string& background_label, const float default_threshold) {
|
||||||
|
float threshold = default_threshold;
|
||||||
|
auto it = classwise_thresholds.find(input_label);
|
||||||
|
if (it != classwise_thresholds.end()) {
|
||||||
|
threshold = it->second;
|
||||||
|
}
|
||||||
|
return threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ClassificationList> GetWinningPrediction(
|
||||||
|
const ClassificationList& classification_list,
|
||||||
|
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||||
|
const std::string& background_label, const float default_threshold) {
|
||||||
|
auto prediction_list = std::make_unique<ClassificationList>();
|
||||||
|
if (classification_list.classification().empty()) {
|
||||||
|
return prediction_list;
|
||||||
|
}
|
||||||
|
Classification& prediction = *prediction_list->add_classification();
|
||||||
|
auto argmax_prediction = GetMaxScoringClassification(classification_list);
|
||||||
|
float argmax_prediction_thresh =
|
||||||
|
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
|
||||||
|
background_label, default_threshold);
|
||||||
|
if (argmax_prediction.score() >= argmax_prediction_thresh) {
|
||||||
|
prediction.set_label(argmax_prediction.label());
|
||||||
|
prediction.set_score(argmax_prediction.score());
|
||||||
|
} else {
|
||||||
|
for (const auto& input : classification_list.classification()) {
|
||||||
|
if (input.label() == background_label) {
|
||||||
|
prediction.set_label(input.label());
|
||||||
|
prediction.set_score(input.score());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prediction_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// This calculator accepts multiple ClassificationList input streams. Each
|
||||||
|
// ClassificationList should contain classifications with labels and
|
||||||
|
// corresponding softmax scores. The calculator computes the best prediction for
|
||||||
|
// each ClassificationList input stream via argmax and thresholding. Thresholds
|
||||||
|
// for all classes can be specified in the
|
||||||
|
// `CombinedPredictionCalculatorOptions`, along with a default global
|
||||||
|
// threshold.
|
||||||
|
// Please note that for this calculator to work as designed, the class names
|
||||||
|
// other than the background class in the ClassificationList objects must be
|
||||||
|
// different, but the background class name has to be the same. This background
|
||||||
|
// label name can be set via `background_label` in
|
||||||
|
// `CombinedPredictionCalculatorOptions`.
|
||||||
|
// The ClassificationList in the PREDICTION output stream contains the label of
|
||||||
|
// the winning class and corresponding softmax score. If none of the
|
||||||
|
// ClassificationList objects has a non-background winning class, the output
|
||||||
|
// contains the background class and score of the background class in the first
|
||||||
|
// ClassificationList. If multiple ClassificationList objects have a
|
||||||
|
// non-background winning class, the output contains the winning prediction from
|
||||||
|
// the ClassificationList with the highest priority. Priority is in decreasing
|
||||||
|
// order of input streams to the graph node using this calculator.
|
||||||
|
// Input:
|
||||||
|
// At least one stream with ClassificationList.
|
||||||
|
// Output:
|
||||||
|
// PREDICTION - A ClassificationList with the winning label as the only item.
|
||||||
|
//
|
||||||
|
// Usage example:
|
||||||
|
// node {
|
||||||
|
// calculator: "CombinedPredictionCalculator"
|
||||||
|
// input_stream: "classification_list_0"
|
||||||
|
// input_stream: "classification_list_1"
|
||||||
|
// output_stream: "PREDICTION:prediction"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||||
|
// class {
|
||||||
|
// label: "A"
|
||||||
|
// score_threshold: 0.7
|
||||||
|
// }
|
||||||
|
// default_global_threshold: 0.1
|
||||||
|
// background_label: "B"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
class CombinedPredictionCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
|
||||||
|
""};
|
||||||
|
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
|
||||||
|
for (const auto& input : options_.class_()) {
|
||||||
|
classwise_thresholds_[input.label()] = input.score_threshold();
|
||||||
|
}
|
||||||
|
classwise_thresholds_[options_.background_label()] = 0;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
|
// After loop, if have winning prediction return. Otherwise empty packet.
|
||||||
|
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
|
||||||
|
auto collection = kClassificationListIn(cc);
|
||||||
|
for (int idx = 0; idx < collection.Count(); ++idx) {
|
||||||
|
const auto& packet = collection[idx];
|
||||||
|
if (packet.IsEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto prediction = GetWinningPrediction(
|
||||||
|
packet.Get(), classwise_thresholds_, options_.background_label(),
|
||||||
|
options_.default_global_threshold());
|
||||||
|
if (prediction->classification(0).label() !=
|
||||||
|
options_.background_label()) {
|
||||||
|
kPredictionOut(cc).Send(std::move(prediction));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
if (first_winning_prediction == nullptr) {
|
||||||
|
first_winning_prediction = std::move(prediction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (first_winning_prediction != nullptr) {
|
||||||
|
kPredictionOut(cc).Send(std::move(first_winning_prediction));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CombinedPredictionCalculatorOptions options_;
|
||||||
|
absl::btree_map<std::string, float> classwise_thresholds_;
|
||||||
|
};
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,41 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message CombinedPredictionCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional CombinedPredictionCalculatorOptions ext = 483738635;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Class {
|
||||||
|
optional string label = 1;
|
||||||
|
optional float score_threshold = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// List of classes with score thresholds.
|
||||||
|
repeated Class class = 1;
|
||||||
|
|
||||||
|
// Default score threshold applied to a label.
|
||||||
|
optional float default_global_threshold = 2 [default = 0];
|
||||||
|
|
||||||
|
// Name of the background class whose input scores will be ignored while
|
||||||
|
// thresholding.
|
||||||
|
optional string background_label = 3;
|
||||||
|
}
|
|
@ -0,0 +1,315 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kPredictionTag[] = "PREDICTION";
|
||||||
|
|
||||||
|
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
|
||||||
|
float drama_thresh, float llama_thresh, float bazinga_thresh,
|
||||||
|
float joy_thresh, float peace_thresh) {
|
||||||
|
constexpr absl::string_view kCalculatorProto = R"pb(
|
||||||
|
calculator: "CombinedPredictionCalculator"
|
||||||
|
input_stream: "custom_softmax_scores"
|
||||||
|
input_stream: "canned_softmax_scores"
|
||||||
|
output_stream: "PREDICTION:prediction"
|
||||||
|
options {
|
||||||
|
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||||
|
class { label: "CustomDrama" score_threshold: $0 }
|
||||||
|
class { label: "CustomLlama" score_threshold: $1 }
|
||||||
|
class { label: "CannedBazinga" score_threshold: $2 }
|
||||||
|
class { label: "CannedJoy" score_threshold: $3 }
|
||||||
|
class { label: "CannedPeace" score_threshold: $4 }
|
||||||
|
background_label: "Negative"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb";
|
||||||
|
auto runner = std::make_unique<CalculatorRunner>(
|
||||||
|
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
|
||||||
|
bazinga_thresh, joy_thresh, peace_thresh));
|
||||||
|
return runner;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
|
||||||
|
const float negative_score, const float drama_score,
|
||||||
|
const float llama_score) {
|
||||||
|
auto custom_scores = std::make_unique<ClassificationList>();
|
||||||
|
auto custom_negative = custom_scores->add_classification();
|
||||||
|
custom_negative->set_label("Negative");
|
||||||
|
custom_negative->set_score(negative_score);
|
||||||
|
auto drama = custom_scores->add_classification();
|
||||||
|
drama->set_label("CustomDrama");
|
||||||
|
drama->set_score(drama_score);
|
||||||
|
auto llama = custom_scores->add_classification();
|
||||||
|
llama->set_label("CustomLlama");
|
||||||
|
llama->set_score(llama_score);
|
||||||
|
return custom_scores;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
|
||||||
|
const float negative_score, const float bazinga_score,
|
||||||
|
const float joy_score, const float peace_score) {
|
||||||
|
auto canned_scores = std::make_unique<ClassificationList>();
|
||||||
|
auto canned_negative = canned_scores->add_classification();
|
||||||
|
canned_negative->set_label("Negative");
|
||||||
|
canned_negative->set_score(negative_score);
|
||||||
|
auto bazinga = canned_scores->add_classification();
|
||||||
|
bazinga->set_label("CannedBazinga");
|
||||||
|
bazinga->set_score(bazinga_score);
|
||||||
|
auto joy = canned_scores->add_classification();
|
||||||
|
joy->set_label("CannedJoy");
|
||||||
|
joy->set_score(joy_score);
|
||||||
|
auto peace = canned_scores->add_classification();
|
||||||
|
peace->set_label("CannedPeace");
|
||||||
|
peace->set_score(peace_score);
|
||||||
|
return canned_scores;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CombinedPredictionCalculatorPacketTest,
|
||||||
|
CustomEmpty_CannedEmpty_ResultIsEmpty) {
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
|
||||||
|
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CombinedPredictionCalculatorPacketTest,
|
||||||
|
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
|
||||||
|
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
|
||||||
|
auto canned_scores = BuildCannedScoreInput(
|
||||||
|
/*negative_score=*/0.1,
|
||||||
|
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
|
||||||
|
runner->MutableInputs()->Index(1).packets.push_back(
|
||||||
|
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
|
||||||
|
auto output_prediction_packets =
|
||||||
|
runner->Outputs().Tag(kPredictionTag).packets;
|
||||||
|
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||||
|
Classification output_prediction =
|
||||||
|
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_prediction.label(), "CannedJoy");
|
||||||
|
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CombinedPredictionCalculatorPacketTest,
|
||||||
|
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
|
||||||
|
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||||
|
auto custom_scores =
|
||||||
|
BuildCustomScoreInput(/*negative_score=*/0.1,
|
||||||
|
/*drama_score=*/0.2, /*llama_score=*/0.7);
|
||||||
|
runner->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
|
||||||
|
auto output_prediction_packets =
|
||||||
|
runner->Outputs().Tag(kPredictionTag).packets;
|
||||||
|
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||||
|
Classification output_prediction =
|
||||||
|
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_prediction.label(), "CustomLlama");
|
||||||
|
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CombinedPredictionCalculatorTestCase {
|
||||||
|
std::string test_name;
|
||||||
|
float custom_negative_score;
|
||||||
|
float drama_score;
|
||||||
|
float llama_score;
|
||||||
|
float drama_thresh;
|
||||||
|
float llama_thresh;
|
||||||
|
float canned_negative_score;
|
||||||
|
float bazinga_score;
|
||||||
|
float joy_score;
|
||||||
|
float peace_score;
|
||||||
|
float bazinga_thresh;
|
||||||
|
float joy_thresh;
|
||||||
|
float peace_thresh;
|
||||||
|
std::string max_scoring_label;
|
||||||
|
float max_score;
|
||||||
|
};
|
||||||
|
|
||||||
|
using CombinedPredictionCalculatorTest =
|
||||||
|
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
|
||||||
|
|
||||||
|
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
|
||||||
|
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
|
||||||
|
|
||||||
|
auto runner = BuildNodeRunnerWithOptions(
|
||||||
|
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
|
||||||
|
test_case.joy_thresh, test_case.peace_thresh);
|
||||||
|
|
||||||
|
auto custom_scores =
|
||||||
|
BuildCustomScoreInput(test_case.custom_negative_score,
|
||||||
|
test_case.drama_score, test_case.llama_score);
|
||||||
|
|
||||||
|
runner->MutableInputs()->Index(0).packets.push_back(
|
||||||
|
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
auto canned_scores = BuildCannedScoreInput(
|
||||||
|
test_case.canned_negative_score, test_case.bazinga_score,
|
||||||
|
test_case.joy_score, test_case.peace_score);
|
||||||
|
runner->MutableInputs()->Index(1).packets.push_back(
|
||||||
|
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||||
|
|
||||||
|
auto output_prediction_packets =
|
||||||
|
runner->Outputs().Tag(kPredictionTag).packets;
|
||||||
|
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||||
|
Classification output_prediction =
|
||||||
|
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
|
||||||
|
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
|
||||||
|
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
|
||||||
|
{
|
||||||
|
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.5,
|
||||||
|
.llama_score = 0.3,
|
||||||
|
.drama_thresh = 0.25,
|
||||||
|
.llama_thresh = 0.7,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "CustomDrama",
|
||||||
|
.max_score = 0.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.3,
|
||||||
|
.llama_score = 0.6,
|
||||||
|
.drama_thresh = 0.4,
|
||||||
|
.llama_thresh = 0.8,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.4,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.2,
|
||||||
|
.bazinga_thresh = 0.0,
|
||||||
|
.joy_thresh = 0.0,
|
||||||
|
.peace_thresh = 0.0,
|
||||||
|
.max_scoring_label = "CannedBazinga",
|
||||||
|
.max_score = 0.4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.5,
|
||||||
|
.drama_score = 0.1,
|
||||||
|
.llama_score = 0.4,
|
||||||
|
.drama_thresh = 0.1,
|
||||||
|
.llama_thresh = 0.05,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
|
||||||
|
.custom_negative_score = 0.8,
|
||||||
|
.drama_score = 0.1,
|
||||||
|
.llama_score = 0.1,
|
||||||
|
.drama_thresh = 0.25,
|
||||||
|
.llama_thresh = 0.7,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.8,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.2,
|
||||||
|
.llama_score = 0.7,
|
||||||
|
.drama_thresh = 1.1,
|
||||||
|
.llama_thresh = 1.1,
|
||||||
|
.canned_negative_score = 0.1,
|
||||||
|
.bazinga_score = 0.3,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.3,
|
||||||
|
.bazinga_thresh = 0.7,
|
||||||
|
.joy_thresh = 0.7,
|
||||||
|
.peace_thresh = 0.7,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
|
||||||
|
.custom_negative_score = 0.1,
|
||||||
|
.drama_score = 0.3,
|
||||||
|
.llama_score = 0.6,
|
||||||
|
.drama_thresh = 0.4,
|
||||||
|
.llama_thresh = 0.8,
|
||||||
|
.canned_negative_score = 0.3,
|
||||||
|
.bazinga_score = 0.2,
|
||||||
|
.joy_score = 0.3,
|
||||||
|
.peace_score = 0.2,
|
||||||
|
.bazinga_thresh = 0.5,
|
||||||
|
.joy_thresh = 0.5,
|
||||||
|
.peace_thresh = 0.5,
|
||||||
|
.max_scoring_label = "Negative",
|
||||||
|
.max_score = 0.1,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
[](const testing::TestParamInfo<
|
||||||
|
CombinedPredictionCalculatorTest::ParamType>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user