Add CombinedPredictionCalculator.

PiperOrigin-RevId: 484301880
This commit is contained in:
MediaPipe Team 2022-10-27 11:14:42 -07:00 committed by Copybara-Service
parent ee84e447b2
commit fc1d75cc99
4 changed files with 586 additions and 0 deletions

View File

@ -93,3 +93,46 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
mediapipe_proto_library(
name = "combined_prediction_calculator_proto",
srcs = ["combined_prediction_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "combined_prediction_calculator",
srcs = ["combined_prediction_calculator.cc"],
deps = [
":combined_prediction_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)
cc_test(
name = "combined_prediction_calculator_test",
srcs = ["combined_prediction_calculator_test.cc"],
deps = [
":combined_prediction_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -0,0 +1,187 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/btree_map.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
constexpr char kPredictionTag[] = "PREDICTION";
Classification GetMaxScoringClassification(
const ClassificationList& classifications) {
Classification max_classification;
max_classification.set_score(0);
for (const auto& input : classifications.classification()) {
if (max_classification.score() < input.score()) {
max_classification = input;
}
}
return max_classification;
}
float GetScoreThreshold(
const std::string& input_label,
const absl::btree_map<std::string, float>& classwise_thresholds,
const std::string& background_label, const float default_threshold) {
float threshold = default_threshold;
auto it = classwise_thresholds.find(input_label);
if (it != classwise_thresholds.end()) {
threshold = it->second;
}
return threshold;
}
std::unique_ptr<ClassificationList> GetWinningPrediction(
const ClassificationList& classification_list,
const absl::btree_map<std::string, float>& classwise_thresholds,
const std::string& background_label, const float default_threshold) {
auto prediction_list = std::make_unique<ClassificationList>();
if (classification_list.classification().empty()) {
return prediction_list;
}
Classification& prediction = *prediction_list->add_classification();
auto argmax_prediction = GetMaxScoringClassification(classification_list);
float argmax_prediction_thresh =
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
background_label, default_threshold);
if (argmax_prediction.score() >= argmax_prediction_thresh) {
prediction.set_label(argmax_prediction.label());
prediction.set_score(argmax_prediction.score());
} else {
for (const auto& input : classification_list.classification()) {
if (input.label() == background_label) {
prediction.set_label(input.label());
prediction.set_score(input.score());
break;
}
}
}
return prediction_list;
}
} // namespace
// This calculator accepts multiple ClassificationList input streams. Each
// ClassificationList should contain classifications with labels and
// corresponding softmax scores. The calculator computes the best prediction for
// each ClassificationList input stream via argmax and thresholding. Thresholds
// for all classes can be specified in the
// `CombinedPredictionCalculatorOptions`, along with a default global
// threshold.
// Please note that for this calculator to work as designed, the class names
// other than the background class in the ClassificationList objects must be
// different, but the background class name has to be the same. This background
// label name can be set via `background_label` in
// `CombinedPredictionCalculatorOptions`.
// The ClassificationList in the PREDICTION output stream contains the label of
// the winning class and corresponding softmax score. If none of the
// ClassificationList objects has a non-background winning class, the output
// contains the background class and score of the background class in the first
// ClassificationList. If multiple ClassificationList objects have a
// non-background winning class, the output contains the winning prediction from
// the ClassificationList with the highest priority. Priority is in decreasing
// order of input streams to the graph node using this calculator.
// Input:
// At least one stream with ClassificationList.
// Output:
// PREDICTION - A ClassificationList with the winning label as the only item.
//
// Usage example:
// node {
// calculator: "CombinedPredictionCalculator"
// input_stream: "classification_list_0"
// input_stream: "classification_list_1"
// output_stream: "PREDICTION:prediction"
// options {
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
// class {
// label: "A"
// score_threshold: 0.7
// }
// default_global_threshold: 0.1
// background_label: "B"
// }
// }
// }
class CombinedPredictionCalculator : public Node {
public:
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
""};
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
for (const auto& input : options_.class_()) {
classwise_thresholds_[input.label()] = input.score_threshold();
}
classwise_thresholds_[options_.background_label()] = 0;
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
// After loop, if have winning prediction return. Otherwise empty packet.
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
auto collection = kClassificationListIn(cc);
for (int idx = 0; idx < collection.Count(); ++idx) {
const auto& packet = collection[idx];
if (packet.IsEmpty()) {
continue;
}
auto prediction = GetWinningPrediction(
packet.Get(), classwise_thresholds_, options_.background_label(),
options_.default_global_threshold());
if (prediction->classification(0).label() !=
options_.background_label()) {
kPredictionOut(cc).Send(std::move(prediction));
return absl::OkStatus();
}
if (first_winning_prediction == nullptr) {
first_winning_prediction = std::move(prediction);
}
}
if (first_winning_prediction != nullptr) {
kPredictionOut(cc).Send(std::move(first_winning_prediction));
}
return absl::OkStatus();
}
private:
CombinedPredictionCalculatorOptions options_;
absl::btree_map<std::string, float> classwise_thresholds_;
};
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message CombinedPredictionCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional CombinedPredictionCalculatorOptions ext = 483738635;
}
message Class {
optional string label = 1;
optional float score_threshold = 2;
}
// List of classes with score thresholds.
repeated Class class = 1;
// Default score threshold applied to a label.
optional float default_global_threshold = 2 [default = 0];
// Name of the background class whose input scores will be ignored while
// thresholding.
optional string background_label = 3;
}

View File

@ -0,0 +1,315 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kPredictionTag[] = "PREDICTION";
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
float drama_thresh, float llama_thresh, float bazinga_thresh,
float joy_thresh, float peace_thresh) {
constexpr absl::string_view kCalculatorProto = R"pb(
calculator: "CombinedPredictionCalculator"
input_stream: "custom_softmax_scores"
input_stream: "canned_softmax_scores"
output_stream: "PREDICTION:prediction"
options {
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
class { label: "CustomDrama" score_threshold: $0 }
class { label: "CustomLlama" score_threshold: $1 }
class { label: "CannedBazinga" score_threshold: $2 }
class { label: "CannedJoy" score_threshold: $3 }
class { label: "CannedPeace" score_threshold: $4 }
background_label: "Negative"
}
}
)pb";
auto runner = std::make_unique<CalculatorRunner>(
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
bazinga_thresh, joy_thresh, peace_thresh));
return runner;
}
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
const float negative_score, const float drama_score,
const float llama_score) {
auto custom_scores = std::make_unique<ClassificationList>();
auto custom_negative = custom_scores->add_classification();
custom_negative->set_label("Negative");
custom_negative->set_score(negative_score);
auto drama = custom_scores->add_classification();
drama->set_label("CustomDrama");
drama->set_score(drama_score);
auto llama = custom_scores->add_classification();
llama->set_label("CustomLlama");
llama->set_score(llama_score);
return custom_scores;
}
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
const float negative_score, const float bazinga_score,
const float joy_score, const float peace_score) {
auto canned_scores = std::make_unique<ClassificationList>();
auto canned_negative = canned_scores->add_classification();
canned_negative->set_label("Negative");
canned_negative->set_score(negative_score);
auto bazinga = canned_scores->add_classification();
bazinga->set_label("CannedBazinga");
bazinga->set_score(bazinga_score);
auto joy = canned_scores->add_classification();
joy->set_label("CannedJoy");
joy->set_score(joy_score);
auto peace = canned_scores->add_classification();
peace->set_label("CannedPeace");
peace->set_score(peace_score);
return canned_scores;
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomEmpty_CannedEmpty_ResultIsEmpty) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
auto canned_scores = BuildCannedScoreInput(
/*negative_score=*/0.1,
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
runner->MutableInputs()->Index(1).packets.push_back(
Adopt(canned_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), "CannedJoy");
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
auto custom_scores =
BuildCustomScoreInput(/*negative_score=*/0.1,
/*drama_score=*/0.2, /*llama_score=*/0.7);
runner->MutableInputs()->Index(0).packets.push_back(
Adopt(custom_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), "CustomLlama");
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
}
struct CombinedPredictionCalculatorTestCase {
std::string test_name;
float custom_negative_score;
float drama_score;
float llama_score;
float drama_thresh;
float llama_thresh;
float canned_negative_score;
float bazinga_score;
float joy_score;
float peace_score;
float bazinga_thresh;
float joy_thresh;
float peace_thresh;
std::string max_scoring_label;
float max_score;
};
using CombinedPredictionCalculatorTest =
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
auto runner = BuildNodeRunnerWithOptions(
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
test_case.joy_thresh, test_case.peace_thresh);
auto custom_scores =
BuildCustomScoreInput(test_case.custom_negative_score,
test_case.drama_score, test_case.llama_score);
runner->MutableInputs()->Index(0).packets.push_back(
Adopt(custom_scores.release()).At(Timestamp(1)));
auto canned_scores = BuildCannedScoreInput(
test_case.canned_negative_score, test_case.bazinga_score,
test_case.joy_score, test_case.peace_score);
runner->MutableInputs()->Index(1).packets.push_back(
Adopt(canned_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
}
INSTANTIATE_TEST_CASE_P(
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
{
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
.custom_negative_score = 0.1,
.drama_score = 0.5,
.llama_score = 0.3,
.drama_thresh = 0.25,
.llama_thresh = 0.7,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "CustomDrama",
.max_score = 0.5,
},
{
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
.custom_negative_score = 0.1,
.drama_score = 0.3,
.llama_score = 0.6,
.drama_thresh = 0.4,
.llama_thresh = 0.8,
.canned_negative_score = 0.1,
.bazinga_score = 0.4,
.joy_score = 0.3,
.peace_score = 0.2,
.bazinga_thresh = 0.0,
.joy_thresh = 0.0,
.peace_thresh = 0.0,
.max_scoring_label = "CannedBazinga",
.max_score = 0.4,
},
{
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
.custom_negative_score = 0.5,
.drama_score = 0.1,
.llama_score = 0.4,
.drama_thresh = 0.1,
.llama_thresh = 0.05,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.5,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
.custom_negative_score = 0.8,
.drama_score = 0.1,
.llama_score = 0.1,
.drama_thresh = 0.25,
.llama_thresh = 0.7,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.8,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
.custom_negative_score = 0.1,
.drama_score = 0.2,
.llama_score = 0.7,
.drama_thresh = 1.1,
.llama_thresh = 1.1,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.1,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
.custom_negative_score = 0.1,
.drama_score = 0.3,
.llama_score = 0.6,
.drama_thresh = 0.4,
.llama_thresh = 0.8,
.canned_negative_score = 0.3,
.bazinga_score = 0.2,
.joy_score = 0.3,
.peace_score = 0.2,
.bazinga_thresh = 0.5,
.joy_thresh = 0.5,
.peace_thresh = 0.5,
.max_scoring_label = "Negative",
.max_score = 0.1,
},
}),
[](const testing::TestParamInfo<
CombinedPredictionCalculatorTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace mediapipe