diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 40b805d3b..0fb62afaf 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; +constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; + +// Struct holding the different output streams produced by the graph. +struct ClassificationPostprocessingOutputStreams { + Source classification_result; + Source classifications; + Source> timestamped_classifications; +}; // Performs sanity checks on provided ClassifierOptions. absl::Status SanityCheckClassifierOptions( @@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph( // TENSORS - std::vector // The output tensors of an InferenceCalculator. // TIMESTAMPS - std::vector @Optional -// The collection of timestamps that a single ClassificationResult should -// aggregate. This is mostly useful for classifiers working on time series, -// e.g. audio or video classification. +// The collection of the timestamps that this calculator should aggregate. +// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS +// output is used for results. Otherwise as no timestamp aggregation is +// required the CLASSIFICATIONS output is used for results. +// // Outputs: -// CLASSIFICATION_RESULT - ClassificationResult -// The output aggregated classification results. +// CLASSIFICATIONS - ClassificationResult @Optional +// The classification results aggregated by head. Must be connected if the +// TIMESTAMPS input is not connected, as it signals that timestamp +// aggregation is not required. +// TIMESTAMPED_CLASSIFICATIONS - std::vector @Optional +// The classification result aggregated by timestamp, then by head. Must be +// connected if the TIMESTAMPS input is connected, as it signals that +// timestamp aggregation is required. +// // TODO: remove output once migration is over. +// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional +// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { mediapipe::SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( - auto classification_result_out, + auto output_streams, BuildClassificationPostprocessing( sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - classification_result_out >> + output_streams.classification_result >> graph[Output(kClassificationResultTag)]; + output_streams.classifications >> + graph[Output(kClassificationsTag)]; + output_streams.timestamped_classifications >> + graph[Output>( + kTimestampedClassificationsTag)]; return graph.GetConfig(); } private: // Adds an on-device classification postprocessing graph into the provided // builder::Graph instance. The classification postprocessing graph takes - // tensors (std::vector) as input and returns one output - // stream containing the output classification results (ClassificationResult). + // tensors (std::vector) and optional timestamps + // (std::vector) as input and returns two output streams: + // - classification results aggregated by classifier head as a + // ClassificationResult proto, used when no timestamps are passed in + // the graph, + // - classification results aggregated by timestamp then by classifier head + // as a std::vector, used when timestamps are passed + // in the graph. // // options: the on-device ClassificationPostprocessingGraphOptions. // tensors_in: (std::vector>) tensors to postprocess. // timestamps_in: (std::vector) optional collection of - // timestamps that a single ClassificationResult should aggregate. + // timestamps that should be used to aggregate classification results. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr> + absl::StatusOr BuildClassificationPostprocessing( const proto::ClassificationPostprocessingGraphOptions& options, Source> tensors_in, @@ -505,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { timestamps_in >> result_aggregation.In(kTimestampsTag); // Connects output. - return result_aggregation[Output( - kClassificationResultTag)]; + ClassificationPostprocessingOutputStreams output_streams{ + /*classification_result=*/result_aggregation + [Output(kClassificationResultTag)], + /*classifications=*/ + result_aggregation[Output(kClassificationsTag)], + /*timestamped_classifications=*/ + result_aggregation[Output>( + kTimestampedClassificationsTag)]}; + return output_streams; } }; diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index be166982d..48575ceb0 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -45,12 +45,22 @@ namespace processors { // TENSORS - std::vector // The output tensors of an InferenceCalculator. // TIMESTAMPS - std::vector @Optional -// The collection of timestamps that a single ClassificationResult should -// aggregate. This is mostly useful for classifiers working on time series, -// e.g. audio or video classification. +// The collection of the timestamps that this calculator should aggregate. +// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS +// output is used for results. Otherwise as no timestamp aggregation is +// required the CLASSIFICATIONS output is used for results. // Outputs: -// CLASSIFICATION_RESULT - ClassificationResult -// The output aggregated classification results. +// CLASSIFICATIONS - ClassificationResult @Optional +// The classification results aggregated by head. Must be connected if the +// TIMESTAMPS input is not connected, as it signals that timestamp +// aggregation is not required. +// TIMESTAMPED_CLASSIFICATIONS - std::vector @Optional +// The classification result aggregated by timestamp, then by head. Must be +// connected if the TIMESTAMPS input is connected, as it signals that +// timestamp aggregation is required. +// // TODO: remove output once migration is over. +// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional +// The aggregated classification result. absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index bb03e2530..d4728e725 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mediapipe/framework/output_stream_poller.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" @@ -64,6 +65,7 @@ using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::testing::HasSubstr; +using ::testing::Pointwise; using ::testing::proto::Approximately; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; @@ -86,6 +88,11 @@ constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationResultName[] = "classification_result"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kClassificationsName[] = "classifications"; +constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; +constexpr char kTimestampedClassificationsName[] = + "timestamped_classifications"; // Helper function to get ModelResources. absl::StatusOr> CreateModelResourcesForModel( @@ -413,6 +420,316 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { } class PostprocessingTest : public tflite_shims::testing::Test { + protected: + absl::StatusOr BuildGraph( + absl::string_view model_name, const proto::ClassifierOptions& options, + bool connect_timestamps = false) { + ASSIGN_OR_RETURN(auto model_resources, + CreateModelResourcesForModel(model_name)); + + Graph graph; + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( + *model_resources, options, + &postprocessing + .GetOptions())); + graph[Input>(kTensorsTag)].SetName(kTensorsName) >> + postprocessing.In(kTensorsTag); + if (connect_timestamps) { + graph[Input>(kTimestampsTag)].SetName( + kTimestampsName) >> + postprocessing.In(kTimestampsTag); + postprocessing.Out(kTimestampedClassificationsTag) + .SetName(kTimestampedClassificationsName) >> + graph[Output>( + kTimestampedClassificationsTag)]; + } else { + postprocessing.Out(kClassificationsTag).SetName(kClassificationsName) >> + graph[Output(kClassificationsTag)]; + } + + MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); + if (connect_timestamps) { + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kTimestampedClassificationsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kClassificationsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + + template + void AddTensor( + const std::vector& tensor, const Tensor::ElementType& element_type, + const Tensor::QuantizationParameters& quantization_parameters = {}) { + tensors_->emplace_back(element_type, + Tensor::Shape{1, static_cast(tensor.size())}, + quantization_parameters); + auto view = tensors_->back().GetCpuWriteView(); + T* buffer = view.buffer(); + std::copy(tensor.begin(), tensor.end(), buffer); + } + + absl::Status Run( + std::optional> aggregation_timestamps = std::nullopt, + int timestamp = 0) { + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); + // Reset tensors for future calls. + tensors_ = absl::make_unique>(); + if (aggregation_timestamps.has_value()) { + auto packet = absl::make_unique>(); + for (const auto& timestamp : *aggregation_timestamps) { + packet->emplace_back(Timestamp(timestamp)); + } + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); + } + return absl::OkStatus(); + } + + template + absl::StatusOr GetResult(OutputStreamPoller& poller) { + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); + MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); + + Packet packet; + if (!poller.Next(&packet)) { + return absl::InternalError("Unable to get output packet"); + } + auto result = packet.Get(); + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); + return result; + } + + private: + CalculatorGraph calculator_graph_; + std::unique_ptr> tensors_ = + absl::make_unique>(); +}; + +TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { + // Build graph. + proto::ClassifierOptions options; + options.set_max_results(3); + options.set_score_threshold(0.5); + MP_ASSERT_OK_AND_ASSIGN( + auto poller, + BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); + // Build input tensors. + std::vector tensor(kMobileNetNumClasses, 0); + tensor[1] = 18; + tensor[2] = 16; + + // Send tensors and get results. + AddTensor(tensor, Tensor::ElementType::kUInt8, + /*quantization_parameters=*/{0.1, 10}); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult(poller)); + + // Validate results. + EXPECT_THAT(results, + EqualsProto(ParseTextProtoOrDie(R"pb( + classifications { + head_index: 0 + classification_list { + classification { index: 1 score: 0.8 } + classification { index: 2 score: 0.6 } + } + } + )pb"))); +} + +TEST_F(PostprocessingTest, SucceedsWithMetadata) { + // Build graph. + proto::ClassifierOptions options; + options.set_max_results(3); + MP_ASSERT_OK_AND_ASSIGN( + auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); + // Build input tensors. + std::vector tensor(kMobileNetNumClasses, 0); + tensor[1] = 12; + tensor[2] = 14; + tensor[3] = 16; + tensor[4] = 18; + + // Send tensors and get results. + AddTensor(tensor, Tensor::ElementType::kUInt8, + /*quantization_parameters=*/{0.1, 10}); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult(poller)); + + // Validate results. + EXPECT_THAT( + results, EqualsProto(ParseTextProtoOrDie(R"pb( + classifications { + head_index: 0 + head_name: "probability" + classification_list { + classification { index: 4 score: 0.8 label: "tiger shark" } + classification { index: 3 score: 0.6 label: "great white shark" } + classification { index: 2 score: 0.4 label: "goldfish" } + } + } + )pb"))); +} + +TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { + // Build graph. + proto::ClassifierOptions options; + options.set_max_results(3); + MP_ASSERT_OK_AND_ASSIGN( + auto poller, + BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); + // Build input tensors. + std::vector tensor(kMobileNetNumClasses, 0); + tensor[1] = 12; + tensor[2] = 14; + tensor[3] = 16; + tensor[4] = 18; + + // Send tensors and get results. + AddTensor(tensor, Tensor::ElementType::kUInt8, + /*quantization_parameters=*/{0.1, 10}); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult(poller)); + + // Validate results. + EXPECT_THAT( + results, EqualsProto(ParseTextProtoOrDie(R"pb( + classifications { + head_index: 0 + head_name: "probability" + classification_list { + classification { index: 4 score: 0.6899744811 label: "tiger shark" } + classification { + index: 3 + score: 0.6456563062 + label: "great white shark" + } + classification { index: 2 score: 0.5986876601 label: "goldfish" } + } + } + )pb"))); +} + +TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { + // Build graph. + proto::ClassifierOptions options; + options.set_max_results(2); + MP_ASSERT_OK_AND_ASSIGN( + auto poller, + BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); + // Build input tensors. + std::vector tensor_0(kTwoHeadsNumClasses[0], 0); + tensor_0[1] = 0.2; + tensor_0[2] = 0.4; + tensor_0[3] = 0.6; + std::vector tensor_1(kTwoHeadsNumClasses[1], 0); + tensor_1[1] = 0.2; + tensor_1[2] = 0.4; + tensor_1[3] = 0.6; + + // Send tensors and get results. + AddTensor(tensor_0, Tensor::ElementType::kFloat32); + AddTensor(tensor_1, Tensor::ElementType::kFloat32); + MP_ASSERT_OK(Run()); + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult(poller)); + + // Validate results. + EXPECT_THAT( + results, EqualsProto(ParseTextProtoOrDie(R"pb( + classifications { + head_index: 0 + head_name: "yamnet_classification" + classification_list { + classification { index: 3 score: 0.6 label: "Narration, monologue" } + classification { index: 2 score: 0.4 label: "Conversation" } + } + } + classifications { + head_index: 1 + head_name: "bird_classification" + classification_list { + classification { index: 3 score: 0.6 label: "Azara\'s Spinetail" } + classification { index: 2 score: 0.4 label: "House Sparrow" } + } + } + )pb"))); +} + +TEST_F(PostprocessingTest, SucceedsWithTimestamps) { + // Build graph. + proto::ClassifierOptions options; + options.set_max_results(2); + MP_ASSERT_OK_AND_ASSIGN( + auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, + /*connect_timestamps=*/true)); + // Build input tensors. + std::vector tensor_0(kMobileNetNumClasses, 0); + tensor_0[1] = 12; + tensor_0[2] = 14; + tensor_0[3] = 16; + std::vector tensor_1(kMobileNetNumClasses, 0); + tensor_1[5] = 12; + tensor_1[6] = 14; + tensor_1[7] = 16; + + // Send tensors and get results. + AddTensor(tensor_0, Tensor::ElementType::kUInt8, + /*quantization_parameters=*/{0.1, 10}); + MP_ASSERT_OK(Run()); + AddTensor(tensor_1, Tensor::ElementType::kUInt8, + /*quantization_parameters=*/{0.1, 10}); + MP_ASSERT_OK(Run( + /*aggregation_timestamps=*/std::optional>({0, 1000}), + /*timestamp=*/1000)); + + MP_ASSERT_OK_AND_ASSIGN(auto results, + GetResult>(poller)); + + // Validate results. + EXPECT_THAT( + results, + Pointwise( + EqualsProto(), + {ParseTextProtoOrDie(R"pb( + timestamp_ms: 0 + classifications { + head_index: 0 + head_name: "probability" + classification_list { + classification { + index: 3 + score: 0.6 + label: "great white shark" + } + classification { index: 2 score: 0.4 label: "goldfish" } + } + })pb"), + ParseTextProtoOrDie(R"pb( + timestamp_ms: 1 + classifications { + head_index: 0 + head_name: "probability" + classification_list { + classification { index: 7 score: 0.6 label: "stingray" } + classification { index: 6 score: 0.4 label: "electric ray" } + } + })pb")})); +} + +// TODO: remove these tests once migration is over. +class LegacyPostprocessingTest : public tflite_shims::testing::Test { protected: absl::StatusOr BuildGraph( absl::string_view model_name, const proto::ClassifierOptions& options, @@ -496,7 +813,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { absl::make_unique>(); }; -TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { +TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { // Build graph. proto::ClassifierOptions options; options.set_max_results(3); @@ -525,7 +842,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { })pb")); } -TEST_F(PostprocessingTest, SucceedsWithMetadata) { +TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { // Build graph. proto::ClassifierOptions options; options.set_max_results(3); @@ -568,7 +885,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { })pb")); } -TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { +TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { // Build graph. proto::ClassifierOptions options; options.set_max_results(3); @@ -614,7 +931,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { })pb")); } -TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { +TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { // Build graph. proto::ClassifierOptions options; options.set_max_results(2); @@ -674,7 +991,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { })pb")); } -TEST_F(PostprocessingTest, SucceedsWithTimestamps) { +TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { // Build graph. proto::ClassifierOptions options; options.set_max_results(2);