Modify ClassificationAggregationCalculator to output new unified formats.

PiperOrigin-RevId: 485546547
This commit is contained in:
MediaPipe Team 2022-11-02 03:09:25 -07:00 committed by Copybara-Service
parent f1f123d255
commit 475e6b4fd5
9 changed files with 373 additions and 25 deletions

View File

@ -44,6 +44,30 @@ cc_library(
alwayslink = 1,
)
# Unit test target for the ClassificationAggregationCalculator. It links the
# calculator itself, its options proto, and the framework pieces the test uses
# to build a graph (api2 builder/port), feed packets and poll an output stream.
cc_test(
name = "classification_aggregation_calculator_test",
srcs = ["classification_aggregation_calculator_test.cc"],
deps = [
":classification_aggregation_calculator",
":classification_aggregation_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:output_stream_poller",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
mediapipe_proto_library(
name = "score_calibration_calculator_proto",
srcs = ["score_calibration_calculator.proto"],

View File

@ -31,37 +31,62 @@
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into a single ClassificationResult that has
// 3 dimensions: (classification head, classification timestamp, classification
// category).
// Aggregates ClassificationLists into either a ClassificationResult object
// representing the classification results aggregated by classifier head, or
// into an std::vector<ClassificationResult> representing the classification
// results aggregated first by timestamp then by classifier head.
//
// Inputs:
// CLASSIFICATIONS - ClassificationList
// CLASSIFICATIONS - ClassificationList @Multiple
// ClassificationList per classification head.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that a single ClassificationResult
// should aggregate. This stream is optional, and the timestamp information
// will only be populated in the ClassificationResult proto when this stream
// is connected.
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided, the TIMESTAMPED_CLASSIFICATIONS
// output is used for results. Otherwise, as no timestamp aggregation is
// required, the CLASSIFICATIONS output is used for results.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example:
// Example without timestamp aggregation:
// node {
// calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c"
// output_stream: "CLASSIFICATIONS:classifications"
// options {
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a"
// head_names: "head_name_b"
// head_names: "head_name_c"
// }
// }
// }
//
// Example with timestamp aggregation:
// node {
// calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c"
// input_stream: "TIMESTAMPS:timestamps"
// output_stream: "CLASSIFICATION_RESULT:classification_result"
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options {
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] {
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a"
// head_names: "head_name_b"
// head_names: "head_name_c"
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
"CLASSIFICATIONS"};
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
"TIMESTAMPS"};
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut);
static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
"CLASSIFICATIONS"};
static constexpr Output<std::vector<ClassificationResult>>::Optional
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
static constexpr Output<ClassificationResult>::Optional
kClassificationResultOut{"CLASSIFICATION_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
kClassificationsOut, kTimestampedClassificationsOut,
kClassificationResultOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc);
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
cached_classifications_;
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
CalculatorContext* cc);
// TODO: deprecate this function once migration is over.
ClassificationResult LegacyConvertToClassificationResult(
CalculatorContext* cc);
};
absl::Status ClassificationAggregationCalculator::UpdateContract(
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
<< "The size of classifications input streams should match the "
"size of head names specified in the calculator options";
}
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
// not connected. All dependent tasks must be updated to use these outputs
// first.
return absl::OkStatus();
}
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
[](const auto& elem) -> ClassificationList { return elem.Get(); });
cached_classifications_[cc->InputTimestamp().Value()] =
std::move(classification_lists);
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
ClassificationResult classification_result;
if (time_aggregation_enabled_) {
if (kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
}
classification_result = LegacyConvertToClassificationResult(cc);
kTimestampedClassificationsOut(cc).Send(
ConvertToTimestampedClassificationResults(cc));
} else {
classification_result = LegacyConvertToClassificationResult(cc);
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
}
kOut(cc).Send(ConvertToClassificationResult(cc));
kClassificationResultOut(cc).Send(classification_result);
RET_CHECK(cached_classifications_.empty());
return absl::OkStatus();
}
@ -136,6 +186,50 @@ ClassificationResult
ClassificationAggregationCalculator::ConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
auto& classification_lists =
cached_classifications_[cc->InputTimestamp().Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(cc->InputTimestamp().Value());
return result;
}
// Builds one ClassificationResult per timestamp listed in the TIMESTAMPS input
// packet, in the order the timestamps appear in that packet.
//
// For each timestamp, the classification lists cached under that timestamp are
// moved into the result (one `Classifications` entry per head, tagged with its
// head index and, when head names are configured, its head name) and the cache
// entry is removed. A timestamp with no cached entry yields a result that
// carries only `timestamp_ms`.
//
// `timestamp_ms` is the offset from the first listed timestamp divided by 1000
// (presumably microseconds -> milliseconds; confirm against Timestamp units).
std::vector<ClassificationResult>
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
    CalculatorContext* cc) {
  // Bind by reference instead of deep-copying the timestamp vector on every
  // call.
  // NOTE(review): relies on Get() returning a reference to the packet payload
  // that stays valid for the duration of this call - confirm.
  const auto& timestamps = kTimestampsIn(cc).Get();
  std::vector<ClassificationResult> results;
  results.reserve(timestamps.size());
  if (timestamps.empty()) {
    return results;
  }
  // Loop invariant: every offset is measured from the first listed timestamp.
  const auto first_timestamp_value = timestamps.front().Value();
  for (const auto& timestamp : timestamps) {
    ClassificationResult result;
    result.set_timestamp_ms((timestamp.Value() - first_timestamp_value) /
                            1000);
    // Single lookup: avoids default-inserting an empty entry only to erase it
    // again when nothing was cached for this timestamp.
    const auto cached = cached_classifications_.find(timestamp.Value());
    if (cached != cached_classifications_.end()) {
      int head_index = 0;
      for (auto& classification_list : cached->second) {
        auto* classifications = result.add_classifications();
        classifications->set_head_index(head_index);
        if (!head_names_.empty()) {
          classifications->set_head_name(head_names_[head_index]);
        }
        *classifications->mutable_classification_list() =
            std::move(classification_list);
        ++head_index;
      }
      cached_classifications_.erase(cached);
    }
    results.push_back(std::move(result));
  }
  return results;
}
ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
Timestamp first_timestamp(0);
std::vector<Timestamp> timestamps;
if (time_aggregation_enabled_) {
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
1000);
}
cached_classifications_.erase(timestamp.Value());
}
return result;
}

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks;
package mediapipe;
import "mediapipe/framework/calculator.proto";

View File

@ -0,0 +1,213 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::Pointwise;
// Graph-level input tags and stream names for the per-head ClassificationList
// inputs that the test feeds into the calculator (head 0 and head 1).
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
constexpr char kClassificationInput0Name[] = "classifications_0";
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
constexpr char kClassificationInput1Name[] = "classifications_1";
// Tag and stream name of the optional timestamps input.
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
// Tag and stream name of the calculator output used when the timestamps input
// is not connected. The tag is also reused to address the calculator's indexed
// CLASSIFICATIONS inputs.
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
// Tag and stream name of the calculator output used when the timestamps input
// is connected.
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications";
// Returns a ClassificationList holding a single classification whose `index`
// field is set to `class_index`. Dies if the generated text proto cannot be
// parsed.
ClassificationList MakeClassificationList(int class_index) {
  const auto text_proto = absl::StrFormat(
      R"pb(
        classification { index: %d }
      )pb",
      class_index);
  return ParseTextProtoOrDie<ClassificationList>(text_proto);
}
// Test fixture that wires a two-head ClassificationAggregationCalculator
// (head names "foo" and "bar") into a CalculatorGraph and provides helpers to
// feed it packets and read back a single output packet.
class ClassificationAggregationCalculatorTest
    : public tflite_shims::testing::Test {
 protected:
  // Builds, initializes and starts the graph.
  //
  // When `connect_timestamps` is true, the TIMESTAMPS input and the
  // TIMESTAMPED_CLASSIFICATIONS output are connected; otherwise only the
  // CLASSIFICATIONS output is connected.
  //
  // Returns a poller attached to whichever output stream was connected, or the
  // first error reported while setting up the graph.
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      bool connect_timestamps = false) {
    Graph graph;
    auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
    calculator
        .GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
        ParseTextProtoOrDie<
            mediapipe::ClassificationAggregationCalculatorOptions>(
            R"pb(head_names: "foo" head_names: "bar")pb");
    graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
        kClassificationInput0Name) >>
        calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
    graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
        kClassificationInput1Name) >>
        calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));

    // Name of the output stream the returned poller observes.
    const char* polled_stream_name = nullptr;
    if (connect_timestamps) {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          calculator.In(kTimestampsTag);
      calculator.Out(kTimestampedClassificationsTag)
              .SetName(kTimestampedClassificationsName) >>
          graph[Output<std::vector<ClassificationResult>>(
              kTimestampedClassificationsTag)];
      polled_stream_name = kTimestampedClassificationsName;
    } else {
      calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
          graph[Output<ClassificationResult>(kClassificationsTag)];
      polled_stream_name = kClassificationsName;
    }

    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    // The poller is attached after Initialize() and before StartRun().
    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                      polled_stream_name));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  // Feeds one packet per head at `timestamp`: `classifications[0]` goes to
  // head 0 and `classifications[1]` to head 1. When `aggregation_timestamps`
  // is set, a std::vector<Timestamp> built from it is additionally sent on the
  // timestamps stream at the same `timestamp`.
  //
  // Returns InvalidArgument if fewer than two classification lists are given,
  // or the first error reported while adding packets to the graph.
  absl::Status Send(
      std::vector<ClassificationList> classifications, int timestamp = 0,
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
    // Guard the fixed-index accesses below against out-of-bounds reads.
    if (classifications.size() < 2) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "Expected at least 2 classification lists (one per head), got %d.",
          classifications.size()));
    }
    const Timestamp packet_timestamp(timestamp);
    // `classifications` is a by-value parameter, so its elements can be moved
    // into the packets instead of being copied.
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kClassificationInput0Name,
        MakePacket<ClassificationList>(std::move(classifications[0]))
            .At(packet_timestamp)));
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kClassificationInput1Name,
        MakePacket<ClassificationList>(std::move(classifications[1]))
            .At(packet_timestamp)));
    if (aggregation_timestamps.has_value()) {
      std::vector<Timestamp> timestamps_to_aggregate;
      timestamps_to_aggregate.reserve(aggregation_timestamps->size());
      // Distinct loop-variable name: it must not shadow the `timestamp`
      // parameter, which stamps the packet itself.
      for (const int aggregation_timestamp : *aggregation_timestamps) {
        timestamps_to_aggregate.emplace_back(aggregation_timestamp);
      }
      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
          kTimestampsName, MakePacket<std::vector<Timestamp>>(
                               std::move(timestamps_to_aggregate))
                               .At(packet_timestamp)));
    }
    return absl::OkStatus();
  }

  // Waits for the graph to become idle, closes all input streams, and returns
  // a copy of the next packet's payload (of type T) from `poller`.
  //
  // Because this closes the inputs and then waits for the graph to finish, it
  // is meant to be called once, after all Send() calls of a test.
  //
  // Returns Internal if the poller yields no packet, or the first error
  // reported by the graph.
  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    // Copy the payload out before the graph is torn down.
    T result = packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  // Graph under test, shared by the helpers above.
  CalculatorGraph calculator_graph_;
};
// Without the timestamps input, a single round of per-head inputs must come
// out as one ClassificationResult with one entry per head, each carrying its
// head index, head name and the classification list that was fed in.
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
  const auto expected_result = ParseTextProtoOrDie<ClassificationResult>(
      R"pb(classifications {
             head_index: 0
             head_name: "foo"
             classification_list { classification { index: 0 } }
           }
           classifications {
             head_index: 1
             head_name: "bar"
             classification_list { classification { index: 1 } }
           })pb");

  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());

  const std::vector<ClassificationList> head_inputs = {
      MakeClassificationList(0), MakeClassificationList(1)};
  MP_ASSERT_OK(Send(head_inputs));

  MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
  EXPECT_THAT(result, EqualsProto(expected_result));
}
// With the timestamps input connected, inputs sent at timestamps 0 and 1000
// are held back until the timestamps packet {0, 1000} arrives, then emitted as
// two ClassificationResults (timestamp_ms 0 and 1), each with one entry per
// head.
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
  const std::vector<ClassificationResult> expected_results = {
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
        timestamp_ms: 0,
        classifications {
          head_index: 0
          head_name: "foo"
          classification_list { classification { index: 0 } }
        }
        classifications {
          head_index: 1
          head_name: "bar"
          classification_list { classification { index: 1 } }
        }
      )pb"),
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
        timestamp_ms: 1,
        classifications {
          head_index: 0
          head_name: "foo"
          classification_list { classification { index: 2 } }
        }
        classifications {
          head_index: 1
          head_name: "bar"
          classification_list { classification { index: 3 } }
        }
      )pb")};

  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));

  // First round: per-head inputs only, at the default timestamp.
  const std::vector<ClassificationList> first_inputs = {
      MakeClassificationList(0), MakeClassificationList(1)};
  MP_ASSERT_OK(Send(first_inputs));

  // Second round: per-head inputs plus the list of timestamps to aggregate.
  const std::vector<ClassificationList> second_inputs = {
      MakeClassificationList(2), MakeClassificationList(3)};
  const std::vector<int> timestamps_to_aggregate = {0, 1000};
  MP_ASSERT_OK(Send(second_inputs,
                    /*timestamp=*/1000,
                    /*aggregation_timestamps=*/timestamps_to_aggregate));

  MP_ASSERT_OK_AND_ASSIGN(auto result,
                          GetResult<std::vector<ClassificationResult>>(poller));
  EXPECT_THAT(result, Pointwise(EqualsProto(), expected_results));
}
} // namespace
} // namespace mediapipe

View File

@ -28,6 +28,7 @@ mediapipe_proto_library(
srcs = ["classifications.proto"],
deps = [
":category_proto",
"//mediapipe/framework/formats:classification_proto",
],
)

View File

@ -20,6 +20,7 @@ package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";
// TODO: deprecate this message once migration is over.
// A single classification result.
message Category {
// The index of the category in the corresponding label map, usually packed in

View File

@ -17,11 +17,13 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";
// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp.
message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores,
@ -33,9 +35,12 @@ message ClassificationEntry {
optional int64 timestamp_ms = 2;
}
// Classifications for a given classifier head.
// Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications {
// TODO: deprecate this field once migration is over.
repeated ClassificationEntry entries = 1;
// The classification results for this head.
optional mediapipe.ClassificationList classification_list = 4;
// The index of the classifier head these categories refer to. This is useful
// for multi-head models.
optional int32 head_index = 2;
@ -45,7 +50,17 @@ message Classifications {
optional string head_name = 3;
}
// Contains one set of results per classifier head.
// Classifications for a given classifier model.
message ClassificationResult {
// The classification results for each model head, i.e. one for each output
// tensor.
repeated Classifications classifications = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
}

View File

@ -286,7 +286,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
void ConfigureClassificationAggregationCalculator(
const ModelMetadataExtractor& metadata_extractor,
ClassificationAggregationCalculatorOptions* options) {
mediapipe::ClassificationAggregationCalculatorOptions* options) {
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr) {
return;
@ -494,7 +494,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
// Aggregates Classifications into a single ClassificationResult.
auto& result_aggregation =
graph.AddNode("ClassificationAggregationCalculator");
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
result_aggregation
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
.CopyFrom(options.classification_aggregation_options());
for (int i = 0; i < num_heads; ++i) {
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>

View File

@ -38,7 +38,7 @@ message ClassificationPostprocessingGraphOptions {
// Options for the ClassificationAggregationCalculator encapsulated by the
// ClassificationPostprocessing subgraph.
optional ClassificationAggregationCalculatorOptions
optional mediapipe.ClassificationAggregationCalculatorOptions
classification_aggregation_options = 2;
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).