Add CombinedPredictionCalculator.
PiperOrigin-RevId: 484301880
This commit is contained in:
parent
ee84e447b2
commit
fc1d75cc99
|
@ -93,3 +93,46 @@ cc_test(
|
|||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "combined_prediction_calculator_proto",
|
||||
srcs = ["combined_prediction_calculator.proto"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "combined_prediction_calculator",
|
||||
srcs = ["combined_prediction_calculator.cc"],
|
||||
deps = [
|
||||
":combined_prediction_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:collection",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "combined_prediction_calculator_test",
|
||||
srcs = ["combined_prediction_calculator_test.cc"],
|
||||
deps = [
|
||||
":combined_prediction_calculator",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/port:gtest",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/btree_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/collection.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
constexpr char kPredictionTag[] = "PREDICTION";
|
||||
|
||||
Classification GetMaxScoringClassification(
|
||||
const ClassificationList& classifications) {
|
||||
Classification max_classification;
|
||||
max_classification.set_score(0);
|
||||
for (const auto& input : classifications.classification()) {
|
||||
if (max_classification.score() < input.score()) {
|
||||
max_classification = input;
|
||||
}
|
||||
}
|
||||
return max_classification;
|
||||
}
|
||||
|
||||
float GetScoreThreshold(
|
||||
const std::string& input_label,
|
||||
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||
const std::string& background_label, const float default_threshold) {
|
||||
float threshold = default_threshold;
|
||||
auto it = classwise_thresholds.find(input_label);
|
||||
if (it != classwise_thresholds.end()) {
|
||||
threshold = it->second;
|
||||
}
|
||||
return threshold;
|
||||
}
|
||||
|
||||
std::unique_ptr<ClassificationList> GetWinningPrediction(
|
||||
const ClassificationList& classification_list,
|
||||
const absl::btree_map<std::string, float>& classwise_thresholds,
|
||||
const std::string& background_label, const float default_threshold) {
|
||||
auto prediction_list = std::make_unique<ClassificationList>();
|
||||
if (classification_list.classification().empty()) {
|
||||
return prediction_list;
|
||||
}
|
||||
Classification& prediction = *prediction_list->add_classification();
|
||||
auto argmax_prediction = GetMaxScoringClassification(classification_list);
|
||||
float argmax_prediction_thresh =
|
||||
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
|
||||
background_label, default_threshold);
|
||||
if (argmax_prediction.score() >= argmax_prediction_thresh) {
|
||||
prediction.set_label(argmax_prediction.label());
|
||||
prediction.set_score(argmax_prediction.score());
|
||||
} else {
|
||||
for (const auto& input : classification_list.classification()) {
|
||||
if (input.label() == background_label) {
|
||||
prediction.set_label(input.label());
|
||||
prediction.set_score(input.score());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return prediction_list;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// This calculator accepts multiple ClassificationList input streams. Each
|
||||
// ClassificationList should contain classifications with labels and
|
||||
// corresponding softmax scores. The calculator computes the best prediction for
|
||||
// each ClassificationList input stream via argmax and thresholding. Thresholds
|
||||
// for all classes can be specified in the
|
||||
// `CombinedPredictionCalculatorOptions`, along with a default global
|
||||
// threshold.
|
||||
// Please note that for this calculator to work as designed, the class names
|
||||
// other than the background class in the ClassificationList objects must be
|
||||
// different, but the background class name has to be the same. This background
|
||||
// label name can be set via `background_label` in
|
||||
// `CombinedPredictionCalculatorOptions`.
|
||||
// The ClassificationList in the PREDICTION output stream contains the label of
|
||||
// the winning class and corresponding softmax score. If none of the
|
||||
// ClassificationList objects has a non-background winning class, the output
|
||||
// contains the background class and score of the background class in the first
|
||||
// ClassificationList. If multiple ClassificationList objects have a
|
||||
// non-background winning class, the output contains the winning prediction from
|
||||
// the ClassificationList with the highest priority. Priority is in decreasing
|
||||
// order of input streams to the graph node using this calculator.
|
||||
// Input:
|
||||
// At least one stream with ClassificationList.
|
||||
// Output:
|
||||
// PREDICTION - A ClassificationList with the winning label as the only item.
|
||||
//
|
||||
// Usage example:
|
||||
// node {
|
||||
// calculator: "CombinedPredictionCalculator"
|
||||
// input_stream: "classification_list_0"
|
||||
// input_stream: "classification_list_1"
|
||||
// output_stream: "PREDICTION:prediction"
|
||||
// options {
|
||||
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||
// class {
|
||||
// label: "A"
|
||||
// score_threshold: 0.7
|
||||
// }
|
||||
// default_global_threshold: 0.1
|
||||
// background_label: "B"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
class CombinedPredictionCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
|
||||
""};
|
||||
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
|
||||
for (const auto& input : options_.class_()) {
|
||||
classwise_thresholds_[input.label()] = input.score_threshold();
|
||||
}
|
||||
classwise_thresholds_[options_.background_label()] = 0;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
// After loop, if have winning prediction return. Otherwise empty packet.
|
||||
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
|
||||
auto collection = kClassificationListIn(cc);
|
||||
for (int idx = 0; idx < collection.Count(); ++idx) {
|
||||
const auto& packet = collection[idx];
|
||||
if (packet.IsEmpty()) {
|
||||
continue;
|
||||
}
|
||||
auto prediction = GetWinningPrediction(
|
||||
packet.Get(), classwise_thresholds_, options_.background_label(),
|
||||
options_.default_global_threshold());
|
||||
if (prediction->classification(0).label() !=
|
||||
options_.background_label()) {
|
||||
kPredictionOut(cc).Send(std::move(prediction));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
if (first_winning_prediction == nullptr) {
|
||||
first_winning_prediction = std::move(prediction);
|
||||
}
|
||||
}
|
||||
if (first_winning_prediction != nullptr) {
|
||||
kPredictionOut(cc).Send(std::move(first_winning_prediction));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
CombinedPredictionCalculatorOptions options_;
|
||||
absl::btree_map<std::string, float> classwise_thresholds_;
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,41 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
message CombinedPredictionCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional CombinedPredictionCalculatorOptions ext = 483738635;
|
||||
}
|
||||
|
||||
message Class {
|
||||
optional string label = 1;
|
||||
optional float score_threshold = 2;
|
||||
}
|
||||
|
||||
// List of classes with score thresholds.
|
||||
repeated Class class = 1;
|
||||
|
||||
// Default score threshold applied to a label.
|
||||
optional float default_global_threshold = 2 [default = 0];
|
||||
|
||||
// Name of the background class whose input scores will be ignored while
|
||||
// thresholding.
|
||||
optional string background_label = 3;
|
||||
}
|
|
@ -0,0 +1,315 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kPredictionTag[] = "PREDICTION";
|
||||
|
||||
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
|
||||
float drama_thresh, float llama_thresh, float bazinga_thresh,
|
||||
float joy_thresh, float peace_thresh) {
|
||||
constexpr absl::string_view kCalculatorProto = R"pb(
|
||||
calculator: "CombinedPredictionCalculator"
|
||||
input_stream: "custom_softmax_scores"
|
||||
input_stream: "canned_softmax_scores"
|
||||
output_stream: "PREDICTION:prediction"
|
||||
options {
|
||||
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||
class { label: "CustomDrama" score_threshold: $0 }
|
||||
class { label: "CustomLlama" score_threshold: $1 }
|
||||
class { label: "CannedBazinga" score_threshold: $2 }
|
||||
class { label: "CannedJoy" score_threshold: $3 }
|
||||
class { label: "CannedPeace" score_threshold: $4 }
|
||||
background_label: "Negative"
|
||||
}
|
||||
}
|
||||
)pb";
|
||||
auto runner = std::make_unique<CalculatorRunner>(
|
||||
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
|
||||
bazinga_thresh, joy_thresh, peace_thresh));
|
||||
return runner;
|
||||
}
|
||||
|
||||
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
|
||||
const float negative_score, const float drama_score,
|
||||
const float llama_score) {
|
||||
auto custom_scores = std::make_unique<ClassificationList>();
|
||||
auto custom_negative = custom_scores->add_classification();
|
||||
custom_negative->set_label("Negative");
|
||||
custom_negative->set_score(negative_score);
|
||||
auto drama = custom_scores->add_classification();
|
||||
drama->set_label("CustomDrama");
|
||||
drama->set_score(drama_score);
|
||||
auto llama = custom_scores->add_classification();
|
||||
llama->set_label("CustomLlama");
|
||||
llama->set_score(llama_score);
|
||||
return custom_scores;
|
||||
}
|
||||
|
||||
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
|
||||
const float negative_score, const float bazinga_score,
|
||||
const float joy_score, const float peace_score) {
|
||||
auto canned_scores = std::make_unique<ClassificationList>();
|
||||
auto canned_negative = canned_scores->add_classification();
|
||||
canned_negative->set_label("Negative");
|
||||
canned_negative->set_score(negative_score);
|
||||
auto bazinga = canned_scores->add_classification();
|
||||
bazinga->set_label("CannedBazinga");
|
||||
bazinga->set_score(bazinga_score);
|
||||
auto joy = canned_scores->add_classification();
|
||||
joy->set_label("CannedJoy");
|
||||
joy->set_score(joy_score);
|
||||
auto peace = canned_scores->add_classification();
|
||||
peace->set_label("CannedPeace");
|
||||
peace->set_score(peace_score);
|
||||
return canned_scores;
|
||||
}
|
||||
|
||||
TEST(CombinedPredictionCalculatorPacketTest,
|
||||
CustomEmpty_CannedEmpty_ResultIsEmpty) {
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
|
||||
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
|
||||
}
|
||||
|
||||
TEST(CombinedPredictionCalculatorPacketTest,
|
||||
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
|
||||
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
|
||||
auto canned_scores = BuildCannedScoreInput(
|
||||
/*negative_score=*/0.1,
|
||||
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
|
||||
runner->MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
|
||||
auto output_prediction_packets =
|
||||
runner->Outputs().Tag(kPredictionTag).packets;
|
||||
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||
Classification output_prediction =
|
||||
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||
|
||||
EXPECT_EQ(output_prediction.label(), "CannedJoy");
|
||||
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
|
||||
}
|
||||
|
||||
TEST(CombinedPredictionCalculatorPacketTest,
|
||||
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
|
||||
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
|
||||
auto custom_scores =
|
||||
BuildCustomScoreInput(/*negative_score=*/0.1,
|
||||
/*drama_score=*/0.2, /*llama_score=*/0.7);
|
||||
runner->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
|
||||
auto output_prediction_packets =
|
||||
runner->Outputs().Tag(kPredictionTag).packets;
|
||||
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||
Classification output_prediction =
|
||||
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||
|
||||
EXPECT_EQ(output_prediction.label(), "CustomLlama");
|
||||
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
|
||||
}
|
||||
|
||||
struct CombinedPredictionCalculatorTestCase {
|
||||
std::string test_name;
|
||||
float custom_negative_score;
|
||||
float drama_score;
|
||||
float llama_score;
|
||||
float drama_thresh;
|
||||
float llama_thresh;
|
||||
float canned_negative_score;
|
||||
float bazinga_score;
|
||||
float joy_score;
|
||||
float peace_score;
|
||||
float bazinga_thresh;
|
||||
float joy_thresh;
|
||||
float peace_thresh;
|
||||
std::string max_scoring_label;
|
||||
float max_score;
|
||||
};
|
||||
|
||||
using CombinedPredictionCalculatorTest =
|
||||
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
|
||||
|
||||
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
|
||||
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
|
||||
|
||||
auto runner = BuildNodeRunnerWithOptions(
|
||||
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
|
||||
test_case.joy_thresh, test_case.peace_thresh);
|
||||
|
||||
auto custom_scores =
|
||||
BuildCustomScoreInput(test_case.custom_negative_score,
|
||||
test_case.drama_score, test_case.llama_score);
|
||||
|
||||
runner->MutableInputs()->Index(0).packets.push_back(
|
||||
Adopt(custom_scores.release()).At(Timestamp(1)));
|
||||
|
||||
auto canned_scores = BuildCannedScoreInput(
|
||||
test_case.canned_negative_score, test_case.bazinga_score,
|
||||
test_case.joy_score, test_case.peace_score);
|
||||
runner->MutableInputs()->Index(1).packets.push_back(
|
||||
Adopt(canned_scores.release()).At(Timestamp(1)));
|
||||
|
||||
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
|
||||
|
||||
auto output_prediction_packets =
|
||||
runner->Outputs().Tag(kPredictionTag).packets;
|
||||
ASSERT_EQ(output_prediction_packets.size(), 1);
|
||||
Classification output_prediction =
|
||||
output_prediction_packets[0].Get<ClassificationList>().classification(0);
|
||||
|
||||
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
|
||||
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
|
||||
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
|
||||
{
|
||||
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.5,
|
||||
.llama_score = 0.3,
|
||||
.drama_thresh = 0.25,
|
||||
.llama_thresh = 0.7,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "CustomDrama",
|
||||
.max_score = 0.5,
|
||||
},
|
||||
{
|
||||
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.3,
|
||||
.llama_score = 0.6,
|
||||
.drama_thresh = 0.4,
|
||||
.llama_thresh = 0.8,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.4,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.2,
|
||||
.bazinga_thresh = 0.0,
|
||||
.joy_thresh = 0.0,
|
||||
.peace_thresh = 0.0,
|
||||
.max_scoring_label = "CannedBazinga",
|
||||
.max_score = 0.4,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
|
||||
.custom_negative_score = 0.5,
|
||||
.drama_score = 0.1,
|
||||
.llama_score = 0.4,
|
||||
.drama_thresh = 0.1,
|
||||
.llama_thresh = 0.05,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.5,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
|
||||
.custom_negative_score = 0.8,
|
||||
.drama_score = 0.1,
|
||||
.llama_score = 0.1,
|
||||
.drama_thresh = 0.25,
|
||||
.llama_thresh = 0.7,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.8,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.2,
|
||||
.llama_score = 0.7,
|
||||
.drama_thresh = 1.1,
|
||||
.llama_thresh = 1.1,
|
||||
.canned_negative_score = 0.1,
|
||||
.bazinga_score = 0.3,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.3,
|
||||
.bazinga_thresh = 0.7,
|
||||
.joy_thresh = 0.7,
|
||||
.peace_thresh = 0.7,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.1,
|
||||
},
|
||||
{
|
||||
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
|
||||
.custom_negative_score = 0.1,
|
||||
.drama_score = 0.3,
|
||||
.llama_score = 0.6,
|
||||
.drama_thresh = 0.4,
|
||||
.llama_thresh = 0.8,
|
||||
.canned_negative_score = 0.3,
|
||||
.bazinga_score = 0.2,
|
||||
.joy_score = 0.3,
|
||||
.peace_score = 0.2,
|
||||
.bazinga_thresh = 0.5,
|
||||
.joy_thresh = 0.5,
|
||||
.peace_thresh = 0.5,
|
||||
.max_scoring_label = "Negative",
|
||||
.max_score = 0.1,
|
||||
},
|
||||
}),
|
||||
[](const testing::TestParamInfo<
|
||||
CombinedPredictionCalculatorTest::ParamType>& info) {
|
||||
return info.param.test_name;
|
||||
});
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace mediapipe
|
Loading…
Reference in New Issue
Block a user