diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 7d01e4dfe..a688f291a 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -44,6 +44,30 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "classification_aggregation_calculator_test", + srcs = ["classification_aggregation_calculator_test.cc"], + deps = [ + ":classification_aggregation_calculator", + ":classification_aggregation_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:output_stream_poller", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) + mediapipe_proto_library( name = "score_calibration_calculator_proto", srcs = ["score_calibration_calculator.proto"], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index e1f69e607..1a83fdad2 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -31,37 +31,62 @@ namespace mediapipe { namespace api2 { -using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::Classifications; -// Aggregates ClassificationLists into a single ClassificationResult that has -// 3 dimensions: (classification head, classification timestamp, classification -// category). +// Aggregates ClassificationLists into either a ClassificationResult object +// representing the classification results aggregated by classifier head, or +// into an std::vector representing the classification +// results aggregated first by timestamp then by classifier head. // // Inputs: -// CLASSIFICATIONS - ClassificationList +// CLASSIFICATIONS - ClassificationList @Multiple // ClassificationList per classification head. // TIMESTAMPS - std::vector @Optional -// The collection of the timestamps that a single ClassificationResult -// should aggragate. This stream is optional, and the timestamp information -// will only be populated to the ClassificationResult proto when this stream -// is connected. +// The collection of the timestamps that this calculator should aggregate. +// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS +// output is used for results. Otherwise as no timestamp aggregation is +// required the CLASSIFICATIONS output is used for results. // // Outputs: -// CLASSIFICATION_RESULT - ClassificationResult +// CLASSIFICATIONS - ClassificationResult @Optional +// The classification results aggregated by head. Must be connected if the +// TIMESTAMPS input is not connected, as it signals that timestamp +// aggregation is not required. +// TIMESTAMPED_CLASSIFICATIONS - std::vector @Optional +// The classification result aggregated by timestamp, then by head. Must be +// connected if the TIMESTAMPS input is connected, as it signals that +// timestamp aggregation is required. +// // TODO: remove output once migration is over. +// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional // The aggregated classification result. // -// Example: +// Example without timestamp aggregation: +// node { +// calculator: "ClassificationAggregationCalculator" +// input_stream: "CLASSIFICATIONS:0:stream_a" +// input_stream: "CLASSIFICATIONS:1:stream_b" +// input_stream: "CLASSIFICATIONS:2:stream_c" +// output_stream: "CLASSIFICATIONS:classifications" +// options { +// [mediapipe.ClassificationAggregationCalculatorOptions.ext] { +// head_names: "head_name_a" +// head_names: "head_name_b" +// head_names: "head_name_c" +// } +// } +// } +// +// Example with timestamp aggregation: // node { // calculator: "ClassificationAggregationCalculator" // input_stream: "CLASSIFICATIONS:0:stream_a" // input_stream: "CLASSIFICATIONS:1:stream_b" // input_stream: "CLASSIFICATIONS:2:stream_c" // input_stream: "TIMESTAMPS:timestamps" -// output_stream: "CLASSIFICATION_RESULT:classification_result" +// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications" // options { -// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] { +// [mediapipe.ClassificationAggregationCalculatorOptions.ext] { // head_names: "head_name_a" // head_names: "head_name_b" // head_names: "head_name_c" @@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node { "CLASSIFICATIONS"}; static constexpr Input>::Optional kTimestampsIn{ "TIMESTAMPS"}; - static constexpr Output kOut{"CLASSIFICATION_RESULT"}; - MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut); + static constexpr Output::Optional kClassificationsOut{ + "CLASSIFICATIONS"}; + static constexpr Output>::Optional + kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"}; + static constexpr Output::Optional + kClassificationResultOut{"CLASSIFICATION_RESULT"}; + MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, + kClassificationsOut, kTimestampedClassificationsOut, + kClassificationResultOut); static absl::Status UpdateContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc); @@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node { cached_classifications_; ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); + std::vector ConvertToTimestampedClassificationResults( + CalculatorContext* cc); + // TODO: deprecate this function once migration is over. + ClassificationResult LegacyConvertToClassificationResult( + CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } + // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if + // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is + // not connected. All dependent tasks must be updated to use these outputs + // first. return absl::OkStatus(); } @@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process( [](const auto& elem) -> ClassificationList { return elem.Get(); }); cached_classifications_[cc->InputTimestamp().Value()] = std::move(classification_lists); - if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) { - return absl::OkStatus(); + ClassificationResult classification_result; + if (time_aggregation_enabled_) { + if (kTimestampsIn(cc).IsEmpty()) { + return absl::OkStatus(); + } + classification_result = LegacyConvertToClassificationResult(cc); + kTimestampedClassificationsOut(cc).Send( + ConvertToTimestampedClassificationResults(cc)); + } else { + classification_result = LegacyConvertToClassificationResult(cc); + kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } - kOut(cc).Send(ConvertToClassificationResult(cc)); + kClassificationResultOut(cc).Send(classification_result); RET_CHECK(cached_classifications_.empty()); return absl::OkStatus(); } @@ -136,6 +186,50 @@ ClassificationResult ClassificationAggregationCalculator::ConvertToClassificationResult( CalculatorContext* cc) { ClassificationResult result; + auto& classification_lists = + cached_classifications_[cc->InputTimestamp().Value()]; + for (int i = 0; i < classification_lists.size(); ++i) { + auto classifications = result.add_classifications(); + classifications->set_head_index(i); + if (!head_names_.empty()) { + classifications->set_head_name(head_names_[i]); + } + *classifications->mutable_classification_list() = + std::move(classification_lists[i]); + } + cached_classifications_.erase(cc->InputTimestamp().Value()); + return result; +} + +std::vector +ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( + CalculatorContext* cc) { + auto timestamps = kTimestampsIn(cc).Get(); + std::vector results; + results.reserve(timestamps.size()); + for (const auto& timestamp : timestamps) { + ClassificationResult result; + result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000); + auto& classification_lists = cached_classifications_[timestamp.Value()]; + for (int i = 0; i < classification_lists.size(); ++i) { + auto classifications = result.add_classifications(); + classifications->set_head_index(i); + if (!head_names_.empty()) { + classifications->set_head_name(head_names_[i]); + } + *classifications->mutable_classification_list() = + std::move(classification_lists[i]); + } + cached_classifications_.erase(timestamp.Value()); + results.push_back(std::move(result)); + } + return results; +} + +ClassificationResult +ClassificationAggregationCalculator::LegacyConvertToClassificationResult( + CalculatorContext* cc) { + ClassificationResult result; Timestamp first_timestamp(0); std::vector timestamps; if (time_aggregation_enabled_) { @@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult( entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / 1000); } - cached_classifications_.erase(timestamp.Value()); } return result; } diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto index c2a74a48a..e2ed1788e 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc new file mode 100644 index 000000000..1bc8cafd6 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc @@ -0,0 +1,213 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/output_stream_poller.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::ParseTextProtoOrDie; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::testing::Pointwise; + +constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0"; +constexpr char kClassificationInput0Name[] = "classifications_0"; +constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1"; +constexpr char kClassificationInput1Name[] = "classifications_1"; +constexpr char kTimestampsTag[] = "TIMESTAMPS"; +constexpr char kTimestampsName[] = "timestamps"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; +constexpr char kClassificationsName[] = "classifications"; +constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; +constexpr char kTimestampedClassificationsName[] = + "timestamped_classifications"; + +ClassificationList MakeClassificationList(int class_index) { + return ParseTextProtoOrDie(absl::StrFormat( + R"pb( + classification { index: %d } + )pb", + class_index)); +} + +class ClassificationAggregationCalculatorTest + : public tflite_shims::testing::Test { + protected: + absl::StatusOr BuildGraph( + bool connect_timestamps = false) { + Graph graph; + auto& calculator = graph.AddNode("ClassificationAggregationCalculator"); + calculator + .GetOptions() = + ParseTextProtoOrDie< + mediapipe::ClassificationAggregationCalculatorOptions>( + R"pb(head_names: "foo" head_names: "bar")pb"); + graph[Input(kClassificationInput0Tag)].SetName( + kClassificationInput0Name) >> + calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0)); + graph[Input(kClassificationInput1Tag)].SetName( + kClassificationInput1Name) >> + calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1)); + if (connect_timestamps) { + graph[Input>(kTimestampsTag)].SetName( + kTimestampsName) >> + calculator.In(kTimestampsTag); + calculator.Out(kTimestampedClassificationsTag) + .SetName(kTimestampedClassificationsName) >> + graph[Output>( + kTimestampedClassificationsTag)]; + } else { + calculator.Out(kClassificationsTag).SetName(kClassificationsName) >> + graph[Output(kClassificationsTag)]; + } + + MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); + if (connect_timestamps) { + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kTimestampedClassificationsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + kClassificationsName)); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + + absl::Status Send( + std::vector classifications, int timestamp = 0, + std::optional> aggregation_timestamps = std::nullopt) { + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kClassificationInput0Name, + MakePacket(classifications[0]) + .At(Timestamp(timestamp)))); + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kClassificationInput1Name, + MakePacket(classifications[1]) + .At(Timestamp(timestamp)))); + if (aggregation_timestamps.has_value()) { + auto packet = std::make_unique>(); + for (const auto& timestamp : *aggregation_timestamps) { + packet->emplace_back(Timestamp(timestamp)); + } + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); + } + return absl::OkStatus(); + } + + template + absl::StatusOr GetResult(OutputStreamPoller& poller) { + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); + MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); + + Packet packet; + if (!poller.Next(&packet)) { + return absl::InternalError("Unable to get output packet"); + } + auto result = packet.Get(); + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); + return result; + } + + private: + CalculatorGraph calculator_graph_; +}; + +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) { + MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph()); + MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); + MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult(poller)); + + EXPECT_THAT(result, + EqualsProto(ParseTextProtoOrDie( + R"pb(classifications { + head_index: 0 + head_name: "foo" + classification_list { classification { index: 0 } } + } + classifications { + head_index: 1 + head_name: "bar" + classification_list { classification { index: 1 } } + })pb"))); +} + +TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) { + MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true)); + MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)})); + MP_ASSERT_OK(Send( + {MakeClassificationList(2), MakeClassificationList(3)}, + /*timestamp=*/1000, + /*aggregation_timestamps=*/std::optional>({0, 1000}))); + MP_ASSERT_OK_AND_ASSIGN(auto result, + GetResult>(poller)); + + EXPECT_THAT(result, + Pointwise(EqualsProto(), + {ParseTextProtoOrDie(R"pb( + timestamp_ms: 0, + classifications { + head_index: 0 + head_name: "foo" + classification_list { classification { index: 0 } } + } + classifications { + head_index: 1 + head_name: "bar" + classification_list { classification { index: 1 } } + } + )pb"), + ParseTextProtoOrDie(R"pb( + timestamp_ms: 1, + classifications { + head_index: 0 + head_name: "foo" + classification_list { classification { index: 2 } } + } + classifications { + head_index: 1 + head_name: "bar" + classification_list { classification { index: 3 } } + } + )pb")})); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 633b5b369..7b455c0c4 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -28,6 +28,7 @@ mediapipe_proto_library( srcs = ["classifications.proto"], deps = [ ":category_proto", + "//mediapipe/framework/formats:classification_proto", ], ) diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto index 2ba760e99..412e71428 100644 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -20,6 +20,7 @@ package mediapipe.tasks.components.containers.proto; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "CategoryProto"; +// TODO: deprecate this message once migration is over. // A single classification result. message Category { // The index of the category in the corresponding label map, usually packed in diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index 712607fa6..f098ed0e4 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -17,11 +17,13 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; +import "mediapipe/framework/formats/classification.proto"; import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; +// TODO: deprecate this message once migration is over. // List of predicted categories with an optional timestamp. message ClassificationEntry { // The array of predicted categories, usually sorted by descending scores, @@ -33,9 +35,12 @@ message ClassificationEntry { optional int64 timestamp_ms = 2; } -// Classifications for a given classifier head. +// Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { + // TODO: deprecate this field once migration is over. repeated ClassificationEntry entries = 1; + // The classification results for this head. + optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. This is useful // for multi-head models. optional int32 head_index = 2; @@ -45,7 +50,17 @@ message Classifications { optional string head_name = 3; } -// Contains one set of results per classifier head. +// Classifications for a given classifier model. message ClassificationResult { + // The classification results for each model head, i.e. one for each output + // tensor. repeated Classifications classifications = 1; + // The optional timestamp (in milliseconds) of the start of the chunk of data + // corresponding to these results. + // + // This is only used for classification on time series (e.g. audio + // classification). In these use cases, the amount of data to process might + // exceed the maximum size that the model can process: to solve this, the + // input data is split into multiple chunks starting at different timestamps. + optional int64 timestamp_ms = 2; } diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index b4fbf9669..40b805d3b 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -286,7 +286,7 @@ absl::Status ConfigureScoreCalibrationIfAny( void ConfigureClassificationAggregationCalculator( const ModelMetadataExtractor& metadata_extractor, - ClassificationAggregationCalculatorOptions* options) { + mediapipe::ClassificationAggregationCalculatorOptions* options) { auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); if (output_tensors_metadata == nullptr) { return; @@ -494,7 +494,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Aggregates Classifications into a single ClassificationResult. auto& result_aggregation = graph.AddNode("ClassificationAggregationCalculator"); - result_aggregation.GetOptions() + result_aggregation + .GetOptions() .CopyFrom(options.classification_aggregation_options()); for (int i = 0; i < num_heads; ++i) { tensors_to_classification_nodes[i]->Out(kClassificationsTag) >> diff --git a/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto index 1de788eab..84ba95222 100644 --- a/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto @@ -38,7 +38,7 @@ message ClassificationPostprocessingGraphOptions { // Options for the ClassificationAggregationCalculator encapsulated by the // ClassificationPostprocessing subgraph. - optional ClassificationAggregationCalculatorOptions + optional mediapipe.ClassificationAggregationCalculatorOptions classification_aggregation_options = 2; // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).