diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD
index 08f7f45d0..8c2c2e593 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD
@@ -93,3 +93,46 @@ cc_test(
         "@com_google_absl//absl/strings",
     ],
 )
+
+# Options proto (CombinedPredictionCalculatorOptions) for
+# CombinedPredictionCalculator.
+mediapipe_proto_library(
+    name = "combined_prediction_calculator_proto",
+    srcs = ["combined_prediction_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
+)
+
+# Calculator implementation. The calculator is only reachable through its
+# MEDIAPIPE_REGISTER_NODE registration, so `alwayslink = 1` is needed to keep
+# its object file in binaries that never reference its symbols directly.
+cc_library(
+    name = "combined_prediction_calculator",
+    srcs = ["combined_prediction_calculator.cc"],
+    deps = [
+        ":combined_prediction_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:collection",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+    alwayslink = 1,
+)
+
+# Unit tests that drive the calculator through CalculatorRunner.
+cc_test(
+    name = "combined_prediction_calculator_test",
+    srcs = ["combined_prediction_calculator_test.cc"],
+    deps = [
+        ":combined_prediction_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/port:gtest",
+        "//mediapipe/framework/port:gtest_main",
+        "@com_google_absl//absl/strings",
+    ],
+)
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc
new file mode 100644
index 000000000..c7147ea6e
--- /dev/null
+++ 
b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc @@ -0,0 +1,187 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +constexpr char kPredictionTag[] = "PREDICTION"; + +Classification GetMaxScoringClassification( + const ClassificationList& classifications) { + Classification max_classification; + max_classification.set_score(0); + for (const auto& input : classifications.classification()) { + if (max_classification.score() < input.score()) { + max_classification = input; + } + } + return max_classification; +} + +float GetScoreThreshold( + const std::string& input_label, + const absl::btree_map& classwise_thresholds, + const std::string& background_label, const float 
default_threshold) { + float threshold = default_threshold; + auto it = classwise_thresholds.find(input_label); + if (it != classwise_thresholds.end()) { + threshold = it->second; + } + return threshold; +} + +std::unique_ptr GetWinningPrediction( + const ClassificationList& classification_list, + const absl::btree_map& classwise_thresholds, + const std::string& background_label, const float default_threshold) { + auto prediction_list = std::make_unique(); + if (classification_list.classification().empty()) { + return prediction_list; + } + Classification& prediction = *prediction_list->add_classification(); + auto argmax_prediction = GetMaxScoringClassification(classification_list); + float argmax_prediction_thresh = + GetScoreThreshold(argmax_prediction.label(), classwise_thresholds, + background_label, default_threshold); + if (argmax_prediction.score() >= argmax_prediction_thresh) { + prediction.set_label(argmax_prediction.label()); + prediction.set_score(argmax_prediction.score()); + } else { + for (const auto& input : classification_list.classification()) { + if (input.label() == background_label) { + prediction.set_label(input.label()); + prediction.set_score(input.score()); + break; + } + } + } + return prediction_list; +} + +} // namespace + +// This calculator accepts multiple ClassificationList input streams. Each +// ClassificationList should contain classifications with labels and +// corresponding softmax scores. The calculator computes the best prediction for +// each ClassificationList input stream via argmax and thresholding. Thresholds +// for all classes can be specified in the +// `CombinedPredictionCalculatorOptions`, along with a default global +// threshold. +// Please note that for this calculator to work as designed, the class names +// other than the background class in the ClassificationList objects must be +// different, but the background class name has to be the same. 
This background +// label name can be set via `background_label` in +// `CombinedPredictionCalculatorOptions`. +// The ClassificationList in the PREDICTION output stream contains the label of +// the winning class and corresponding softmax score. If none of the +// ClassificationList objects has a non-background winning class, the output +// contains the background class and score of the background class in the first +// ClassificationList. If multiple ClassificationList objects have a +// non-background winning class, the output contains the winning prediction from +// the ClassificationList with the highest priority. Priority is in decreasing +// order of input streams to the graph node using this calculator. +// Input: +// At least one stream with ClassificationList. +// Output: +// PREDICTION - A ClassificationList with the winning label as the only item. +// +// Usage example: +// node { +// calculator: "CombinedPredictionCalculator" +// input_stream: "classification_list_0" +// input_stream: "classification_list_1" +// output_stream: "PREDICTION:prediction" +// options { +// [mediapipe.CombinedPredictionCalculatorOptions.ext] { +// class { +// label: "A" +// score_threshold: 0.7 +// } +// default_global_threshold: 0.1 +// background_label: "B" +// } +// } +// } + +class CombinedPredictionCalculator : public Node { + public: + static constexpr Input::Multiple kClassificationListIn{ + ""}; + static constexpr Output kPredictionOut{"PREDICTION"}; + MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut); + + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + for (const auto& input : options_.class_()) { + classwise_thresholds_[input.label()] = input.score_threshold(); + } + classwise_thresholds_[options_.background_label()] = 0; + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // After loop, if have winning prediction return. Otherwise empty packet. 
+ std::unique_ptr first_winning_prediction = nullptr; + auto collection = kClassificationListIn(cc); + for (int idx = 0; idx < collection.Count(); ++idx) { + const auto& packet = collection[idx]; + if (packet.IsEmpty()) { + continue; + } + auto prediction = GetWinningPrediction( + packet.Get(), classwise_thresholds_, options_.background_label(), + options_.default_global_threshold()); + if (prediction->classification(0).label() != + options_.background_label()) { + kPredictionOut(cc).Send(std::move(prediction)); + return absl::OkStatus(); + } + if (first_winning_prediction == nullptr) { + first_winning_prediction = std::move(prediction); + } + } + if (first_winning_prediction != nullptr) { + kPredictionOut(cc).Send(std::move(first_winning_prediction)); + } + return absl::OkStatus(); + } + + private: + CombinedPredictionCalculatorOptions options_; + absl::btree_map classwise_thresholds_; +}; + +MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto new file mode 100644 index 000000000..730e7dd78 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+// Options for CombinedPredictionCalculator: per-class score thresholds, a
+// default threshold for classes without one, and the background label.
+message CombinedPredictionCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional CombinedPredictionCalculatorOptions ext = 483738635;
+  }
+
+  // Score threshold for a single class.
+  message Class {
+    // Label of the class the threshold applies to.
+    optional string label = 1;
+    // Minimum score (inclusive) the class needs when it is the top-scoring
+    // entry of an input list in order to be reported as the prediction.
+    optional float score_threshold = 2;
+  }
+
+  // List of classes with score thresholds.
+  repeated Class class = 1;
+
+  // Default score threshold applied to a label that has no entry in `class`.
+  optional float default_global_threshold = 2 [default = 0];
+
+  // Name of the background class whose input scores will be ignored while
+  // thresholding. The calculator always uses threshold 0 for this label, even
+  // if `class` lists it, and reports it when the top-scoring class of an input
+  // misses its threshold.
+  optional string background_label = 3;
+}
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc
new file mode 100644
index 000000000..ecf49795b
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc
@@ -0,0 +1,315 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +constexpr char kPredictionTag[] = "PREDICTION"; + +std::unique_ptr BuildNodeRunnerWithOptions( + float drama_thresh, float llama_thresh, float bazinga_thresh, + float joy_thresh, float peace_thresh) { + constexpr absl::string_view kCalculatorProto = R"pb( + calculator: "CombinedPredictionCalculator" + input_stream: "custom_softmax_scores" + input_stream: "canned_softmax_scores" + output_stream: "PREDICTION:prediction" + options { + [mediapipe.CombinedPredictionCalculatorOptions.ext] { + class { label: "CustomDrama" score_threshold: $0 } + class { label: "CustomLlama" score_threshold: $1 } + class { label: "CannedBazinga" score_threshold: $2 } + class { label: "CannedJoy" score_threshold: $3 } + class { label: "CannedPeace" score_threshold: $4 } + background_label: "Negative" + } + } + )pb"; + auto runner = std::make_unique( + absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh, + bazinga_thresh, joy_thresh, peace_thresh)); + return runner; +} + +std::unique_ptr BuildCustomScoreInput( + const float negative_score, const float drama_score, + const float llama_score) { + auto custom_scores = std::make_unique(); + auto custom_negative = custom_scores->add_classification(); + custom_negative->set_label("Negative"); + custom_negative->set_score(negative_score); + auto drama = custom_scores->add_classification(); + drama->set_label("CustomDrama"); + drama->set_score(drama_score); + auto llama = custom_scores->add_classification(); + llama->set_label("CustomLlama"); + 
llama->set_score(llama_score); + return custom_scores; +} + +std::unique_ptr BuildCannedScoreInput( + const float negative_score, const float bazinga_score, + const float joy_score, const float peace_score) { + auto canned_scores = std::make_unique(); + auto canned_negative = canned_scores->add_classification(); + canned_negative->set_label("Negative"); + canned_negative->set_score(negative_score); + auto bazinga = canned_scores->add_classification(); + bazinga->set_label("CannedBazinga"); + bazinga->set_score(bazinga_score); + auto joy = canned_scores->add_classification(); + joy->set_label("CannedJoy"); + joy->set_score(joy_score); + auto peace = canned_scores->add_classification(); + peace->set_label("CannedPeace"); + peace->set_score(peace_score); + return canned_scores; +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomEmpty_CannedEmpty_ResultIsEmpty) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0, + /*joy_thresh=*/0.0, /*peace_thresh=*/0.0); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty()); +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomEmpty_CannedNotEmpty_ResultIsCanned) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9, + /*joy_thresh=*/0.5, /*peace_thresh=*/0.8); + auto canned_scores = BuildCannedScoreInput( + /*negative_score=*/0.1, + /*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2); + runner->MutableInputs()->Index(1).packets.push_back( + Adopt(canned_scores.release()).At(Timestamp(1))); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + 
EXPECT_EQ(output_prediction.label(), "CannedJoy"); + EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4); +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomNotEmpty_CannedEmpty_ResultIsCustom) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0, + /*joy_thresh=*/0.0, /*peace_thresh=*/0.0); + auto custom_scores = + BuildCustomScoreInput(/*negative_score=*/0.1, + /*drama_score=*/0.2, /*llama_score=*/0.7); + runner->MutableInputs()->Index(0).packets.push_back( + Adopt(custom_scores.release()).At(Timestamp(1))); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), "CustomLlama"); + EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4); +} + +struct CombinedPredictionCalculatorTestCase { + std::string test_name; + float custom_negative_score; + float drama_score; + float llama_score; + float drama_thresh; + float llama_thresh; + float canned_negative_score; + float bazinga_score; + float joy_score; + float peace_score; + float bazinga_thresh; + float joy_thresh; + float peace_thresh; + std::string max_scoring_label; + float max_score; +}; + +using CombinedPredictionCalculatorTest = + testing::TestWithParam; + +TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) { + const CombinedPredictionCalculatorTestCase& test_case = GetParam(); + + auto runner = BuildNodeRunnerWithOptions( + test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh, + test_case.joy_thresh, test_case.peace_thresh); + + auto custom_scores = + BuildCustomScoreInput(test_case.custom_negative_score, + test_case.drama_score, test_case.llama_score); + + runner->MutableInputs()->Index(0).packets.push_back( + 
Adopt(custom_scores.release()).At(Timestamp(1))); + + auto canned_scores = BuildCannedScoreInput( + test_case.canned_negative_score, test_case.bazinga_score, + test_case.joy_score, test_case.peace_score); + runner->MutableInputs()->Index(1).packets.push_back( + Adopt(canned_scores.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label); + EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4); +} + +INSTANTIATE_TEST_CASE_P( + CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest, + testing::ValuesIn({ + { + .test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.5, + .llama_score = 0.3, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "CustomDrama", + .max_score = 0.5, + }, + { + .test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.1, + .bazinga_score = 0.4, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.0, + .joy_thresh = 0.0, + .peace_thresh = 0.0, + .max_scoring_label = "CannedBazinga", + .max_score = 0.4, + }, + { + .test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh", + .custom_negative_score = 0.5, + .drama_score = 0.1, + .llama_score = 0.4, + .drama_thresh = 0.1, + .llama_thresh = 0.05, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + 
.peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.5, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh", + .custom_negative_score = 0.8, + .drama_score = 0.1, + .llama_score = 0.1, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.8, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2", + .custom_negative_score = 0.1, + .drama_score = 0.2, + .llama_score = 0.7, + .drama_thresh = 1.1, + .llama_thresh = 1.1, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.3, + .bazinga_score = 0.2, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.5, + .joy_thresh = 0.5, + .peace_thresh = 0.5, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + }), + [](const testing::TestParamInfo< + CombinedPredictionCalculatorTest::ParamType>& info) { + return info.param.test_name; + }); + +} // namespace + +} // namespace mediapipe