Modify ClassificationAggregationCalculator to output new unified formats.

PiperOrigin-RevId: 485546547
This commit is contained in:
MediaPipe Team 2022-11-02 03:09:25 -07:00 committed by Copybara-Service
parent f1f123d255
commit 475e6b4fd5
9 changed files with 373 additions and 25 deletions

View File

@ -44,6 +44,30 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "classification_aggregation_calculator_test",
srcs = ["classification_aggregation_calculator_test.cc"],
deps = [
":classification_aggregation_calculator",
":classification_aggregation_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:output_stream_poller",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "score_calibration_calculator_proto", name = "score_calibration_calculator_proto",
srcs = ["score_calibration_calculator.proto"], srcs = ["score_calibration_calculator.proto"],

View File

@ -31,37 +31,62 @@
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications; using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into a single ClassificationResult that has // Aggregates ClassificationLists into either a ClassificationResult object
// 3 dimensions: (classification head, classification timestamp, classification // representing the classification results aggregated by classifier head, or
// category). // into an std::vector<ClassificationResult> representing the classification
// results aggregated first by timestamp then by classifier head.
// //
// Inputs: // Inputs:
// CLASSIFICATIONS - ClassificationList // CLASSIFICATIONS - ClassificationList @Multiple
// ClassificationList per classification head. // ClassificationList per classification head.
// TIMESTAMPS - std::vector<Timestamp> @Optional // TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that a single ClassificationResult // The collection of the timestamps that this calculator should aggregate.
// should aggragate. This stream is optional, and the timestamp information // This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
// will only be populated to the ClassificationResult proto when this stream // output is used for results. Otherwise as no timestamp aggregation is
// is connected. // required the CLASSIFICATIONS output is used for results.
// //
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result. // The aggregated classification result.
// //
// Example: // Example without timestamp aggregation:
// node {
// calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c"
// output_stream: "CLASSIFICATIONS:classifications"
// options {
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a"
// head_names: "head_name_b"
// head_names: "head_name_c"
// }
// }
// }
//
// Example with timestamp aggregation:
// node { // node {
// calculator: "ClassificationAggregationCalculator" // calculator: "ClassificationAggregationCalculator"
// input_stream: "CLASSIFICATIONS:0:stream_a" // input_stream: "CLASSIFICATIONS:0:stream_a"
// input_stream: "CLASSIFICATIONS:1:stream_b" // input_stream: "CLASSIFICATIONS:1:stream_b"
// input_stream: "CLASSIFICATIONS:2:stream_c" // input_stream: "CLASSIFICATIONS:2:stream_c"
// input_stream: "TIMESTAMPS:timestamps" // input_stream: "TIMESTAMPS:timestamps"
// output_stream: "CLASSIFICATION_RESULT:classification_result" // output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
// options { // options {
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] { // [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
// head_names: "head_name_a" // head_names: "head_name_a"
// head_names: "head_name_b" // head_names: "head_name_b"
// head_names: "head_name_c" // head_names: "head_name_c"
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
"CLASSIFICATIONS"}; "CLASSIFICATIONS"};
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{ static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
"TIMESTAMPS"}; "TIMESTAMPS"};
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"}; static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut); "CLASSIFICATIONS"};
static constexpr Output<std::vector<ClassificationResult>>::Optional
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
static constexpr Output<ClassificationResult>::Optional
kClassificationResultOut{"CLASSIFICATION_RESULT"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
kClassificationsOut, kTimestampedClassificationsOut,
kClassificationResultOut);
static absl::Status UpdateContract(CalculatorContract* cc); static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc); absl::Status Open(CalculatorContext* cc);
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
cached_classifications_; cached_classifications_;
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
CalculatorContext* cc);
// TODO: deprecate this function once migration is over.
ClassificationResult LegacyConvertToClassificationResult(
CalculatorContext* cc);
}; };
absl::Status ClassificationAggregationCalculator::UpdateContract( absl::Status ClassificationAggregationCalculator::UpdateContract(
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
<< "The size of classifications input streams should match the " << "The size of classifications input streams should match the "
"size of head names specified in the calculator options"; "size of head names specified in the calculator options";
} }
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
// not connected. All dependent tasks must be updated to use these outputs
// first.
return absl::OkStatus(); return absl::OkStatus();
} }
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
[](const auto& elem) -> ClassificationList { return elem.Get(); }); [](const auto& elem) -> ClassificationList { return elem.Get(); });
cached_classifications_[cc->InputTimestamp().Value()] = cached_classifications_[cc->InputTimestamp().Value()] =
std::move(classification_lists); std::move(classification_lists);
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) { ClassificationResult classification_result;
return absl::OkStatus(); if (time_aggregation_enabled_) {
if (kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
}
classification_result = LegacyConvertToClassificationResult(cc);
kTimestampedClassificationsOut(cc).Send(
ConvertToTimestampedClassificationResults(cc));
} else {
classification_result = LegacyConvertToClassificationResult(cc);
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
} }
kOut(cc).Send(ConvertToClassificationResult(cc)); kClassificationResultOut(cc).Send(classification_result);
RET_CHECK(cached_classifications_.empty()); RET_CHECK(cached_classifications_.empty());
return absl::OkStatus(); return absl::OkStatus();
} }
@ -136,6 +186,50 @@ ClassificationResult
ClassificationAggregationCalculator::ConvertToClassificationResult( ClassificationAggregationCalculator::ConvertToClassificationResult(
CalculatorContext* cc) { CalculatorContext* cc) {
ClassificationResult result; ClassificationResult result;
auto& classification_lists =
cached_classifications_[cc->InputTimestamp().Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(cc->InputTimestamp().Value());
return result;
}
std::vector<ClassificationResult>
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
CalculatorContext* cc) {
auto timestamps = kTimestampsIn(cc).Get();
std::vector<ClassificationResult> results;
results.reserve(timestamps.size());
for (const auto& timestamp : timestamps) {
ClassificationResult result;
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
auto& classification_lists = cached_classifications_[timestamp.Value()];
for (int i = 0; i < classification_lists.size(); ++i) {
auto classifications = result.add_classifications();
classifications->set_head_index(i);
if (!head_names_.empty()) {
classifications->set_head_name(head_names_[i]);
}
*classifications->mutable_classification_list() =
std::move(classification_lists[i]);
}
cached_classifications_.erase(timestamp.Value());
results.push_back(std::move(result));
}
return results;
}
ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
Timestamp first_timestamp(0); Timestamp first_timestamp(0);
std::vector<Timestamp> timestamps; std::vector<Timestamp> timestamps;
if (time_aggregation_enabled_) { if (time_aggregation_enabled_) {
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
1000); 1000);
} }
cached_classifications_.erase(timestamp.Value());
} }
return result; return result;
} }

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";

View File

@ -0,0 +1,213 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::Pointwise;
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
constexpr char kClassificationInput0Name[] = "classifications_0";
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
constexpr char kClassificationInput1Name[] = "classifications_1";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications";
ClassificationList MakeClassificationList(int class_index) {
return ParseTextProtoOrDie<ClassificationList>(absl::StrFormat(
R"pb(
classification { index: %d }
)pb",
class_index));
}
class ClassificationAggregationCalculatorTest
: public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
bool connect_timestamps = false) {
Graph graph;
auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
calculator
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
ParseTextProtoOrDie<
mediapipe::ClassificationAggregationCalculatorOptions>(
R"pb(head_names: "foo" head_names: "bar")pb");
graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
kClassificationInput0Name) >>
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
kClassificationInput1Name) >>
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
calculator.In(kTimestampsTag);
calculator.Out(kTimestampedClassificationsTag)
.SetName(kTimestampedClassificationsName) >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
} else {
calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
graph[Output<ClassificationResult>(kClassificationsTag)];
}
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
if (connect_timestamps) {
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kTimestampedClassificationsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kClassificationsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
absl::Status Send(
std::vector<ClassificationList> classifications, int timestamp = 0,
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kClassificationInput0Name,
MakePacket<ClassificationList>(classifications[0])
.At(Timestamp(timestamp))));
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kClassificationInput1Name,
MakePacket<ClassificationList>(classifications[1])
.At(Timestamp(timestamp))));
if (aggregation_timestamps.has_value()) {
auto packet = std::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
template <typename T>
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<T>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
};
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
EXPECT_THAT(result,
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 1 } }
})pb")));
}
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
MP_ASSERT_OK(Send(
{MakeClassificationList(2), MakeClassificationList(3)},
/*timestamp=*/1000,
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
MP_ASSERT_OK_AND_ASSIGN(auto result,
GetResult<std::vector<ClassificationResult>>(poller));
EXPECT_THAT(result,
Pointwise(EqualsProto(),
{ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 0 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 1 } }
}
)pb"),
ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 1,
classifications {
head_index: 0
head_name: "foo"
classification_list { classification { index: 2 } }
}
classifications {
head_index: 1
head_name: "bar"
classification_list { classification { index: 3 } }
}
)pb")}));
}
} // namespace
} // namespace mediapipe

View File

@ -28,6 +28,7 @@ mediapipe_proto_library(
srcs = ["classifications.proto"], srcs = ["classifications.proto"],
deps = [ deps = [
":category_proto", ":category_proto",
"//mediapipe/framework/formats:classification_proto",
], ],
) )

View File

@ -20,6 +20,7 @@ package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto"; option java_outer_classname = "CategoryProto";
// TODO: deprecate this message once migration is over.
// A single classification result. // A single classification result.
message Category { message Category {
// The index of the category in the corresponding label map, usually packed in // The index of the category in the corresponding label map, usually packed in

View File

@ -17,11 +17,13 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto"; import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto"; option java_outer_classname = "ClassificationsProto";
// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp. // List of predicted categories with an optional timestamp.
message ClassificationEntry { message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores, // The array of predicted categories, usually sorted by descending scores,
@ -33,9 +35,12 @@ message ClassificationEntry {
optional int64 timestamp_ms = 2; optional int64 timestamp_ms = 2;
} }
// Classifications for a given classifier head. // Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications { message Classifications {
// TODO: deprecate this field once migration is over.
repeated ClassificationEntry entries = 1; repeated ClassificationEntry entries = 1;
// The classification results for this head.
optional mediapipe.ClassificationList classification_list = 4;
// The index of the classifier head these categories refer to. This is useful // The index of the classifier head these categories refer to. This is useful
// for multi-head models. // for multi-head models.
optional int32 head_index = 2; optional int32 head_index = 2;
@ -45,7 +50,17 @@ message Classifications {
optional string head_name = 3; optional string head_name = 3;
} }
// Contains one set of results per classifier head. // Classifications for a given classifier model.
message ClassificationResult { message ClassificationResult {
// The classification results for each model head, i.e. one for each output
// tensor.
repeated Classifications classifications = 1; repeated Classifications classifications = 1;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
optional int64 timestamp_ms = 2;
} }

View File

@ -286,7 +286,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
void ConfigureClassificationAggregationCalculator( void ConfigureClassificationAggregationCalculator(
const ModelMetadataExtractor& metadata_extractor, const ModelMetadataExtractor& metadata_extractor,
ClassificationAggregationCalculatorOptions* options) { mediapipe::ClassificationAggregationCalculatorOptions* options) {
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr) { if (output_tensors_metadata == nullptr) {
return; return;
@ -494,7 +494,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
// Aggregates Classifications into a single ClassificationResult. // Aggregates Classifications into a single ClassificationResult.
auto& result_aggregation = auto& result_aggregation =
graph.AddNode("ClassificationAggregationCalculator"); graph.AddNode("ClassificationAggregationCalculator");
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>() result_aggregation
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
.CopyFrom(options.classification_aggregation_options()); .CopyFrom(options.classification_aggregation_options());
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >> tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>

View File

@ -38,7 +38,7 @@ message ClassificationPostprocessingGraphOptions {
// Options for the ClassificationAggregationCalculator encapsulated by the // Options for the ClassificationAggregationCalculator encapsulated by the
// ClassificationPostprocessing subgraph. // ClassificationPostprocessing subgraph.
optional ClassificationAggregationCalculatorOptions optional mediapipe.ClassificationAggregationCalculatorOptions
classification_aggregation_options = 2; classification_aggregation_options = 2;
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32). // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).