Modify ClassificationAggregationCalculator to output new unified formats.
PiperOrigin-RevId: 485546547
This commit is contained in:
parent
f1f123d255
commit
475e6b4fd5
|
@ -44,6 +44,30 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "classification_aggregation_calculator_test",
|
||||
srcs = ["classification_aggregation_calculator_test.cc"],
|
||||
deps = [
|
||||
":classification_aggregation_calculator",
|
||||
":classification_aggregation_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:output_stream_poller",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "score_calibration_calculator_proto",
|
||||
srcs = ["score_calibration_calculator.proto"],
|
||||
|
|
|
@ -31,37 +31,62 @@
|
|||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
|
||||
// Aggregates ClassificationLists into a single ClassificationResult that has
|
||||
// 3 dimensions: (classification head, classification timestamp, classification
|
||||
// category).
|
||||
// Aggregates ClassificationLists into either a ClassificationResult object
|
||||
// representing the classification results aggregated by classifier head, or
|
||||
// into an std::vector<ClassificationResult> representing the classification
|
||||
// results aggregated first by timestamp then by classifier head.
|
||||
//
|
||||
// Inputs:
|
||||
// CLASSIFICATIONS - ClassificationList
|
||||
// CLASSIFICATIONS - ClassificationList @Multiple
|
||||
// ClassificationList per classification head.
|
||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||
// The collection of the timestamps that a single ClassificationResult
|
||||
// should aggregate. This stream is optional, and the timestamp information
|
||||
// will only be populated to the ClassificationResult proto when this stream
|
||||
// is connected.
|
||||
// The collection of the timestamps that this calculator should aggregate.
|
||||
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
|
||||
// output is used for results. Otherwise as no timestamp aggregation is
|
||||
// required the CLASSIFICATIONS output is used for results.
|
||||
//
|
||||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||
// The classification results aggregated by head. Must be connected if the
|
||||
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||
// aggregation is not required.
|
||||
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||
// The classification result aggregated by timestamp, then by head. Must be
|
||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||
// timestamp aggregation is required.
|
||||
// // TODO: remove output once migration is over.
|
||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||
// The aggregated classification result.
|
||||
//
|
||||
// Example:
|
||||
// Example without timestamp aggregation:
|
||||
// node {
|
||||
// calculator: "ClassificationAggregationCalculator"
|
||||
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
||||
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
||||
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
||||
// output_stream: "CLASSIFICATIONS:classifications"
|
||||
// options {
|
||||
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
|
||||
// head_names: "head_name_a"
|
||||
// head_names: "head_name_b"
|
||||
// head_names: "head_name_c"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Example with timestamp aggregation:
|
||||
// node {
|
||||
// calculator: "ClassificationAggregationCalculator"
|
||||
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
||||
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
||||
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
||||
// input_stream: "TIMESTAMPS:timestamps"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result"
|
||||
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
|
||||
// options {
|
||||
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] {
|
||||
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
|
||||
// head_names: "head_name_a"
|
||||
// head_names: "head_name_b"
|
||||
// head_names: "head_name_c"
|
||||
|
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
|
|||
"CLASSIFICATIONS"};
|
||||
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
|
||||
"TIMESTAMPS"};
|
||||
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut);
|
||||
static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
|
||||
"CLASSIFICATIONS"};
|
||||
static constexpr Output<std::vector<ClassificationResult>>::Optional
|
||||
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
|
||||
static constexpr Output<ClassificationResult>::Optional
|
||||
kClassificationResultOut{"CLASSIFICATION_RESULT"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
|
||||
kClassificationsOut, kTimestampedClassificationsOut,
|
||||
kClassificationResultOut);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
absl::Status Open(CalculatorContext* cc);
|
||||
|
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
|
|||
cached_classifications_;
|
||||
|
||||
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
|
||||
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
|
||||
CalculatorContext* cc);
|
||||
// TODO: deprecate this function once migration is over.
|
||||
ClassificationResult LegacyConvertToClassificationResult(
|
||||
CalculatorContext* cc);
|
||||
};
|
||||
|
||||
absl::Status ClassificationAggregationCalculator::UpdateContract(
|
||||
|
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
|
|||
<< "The size of classifications input streams should match the "
|
||||
"size of head names specified in the calculator options";
|
||||
}
|
||||
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
|
||||
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
|
||||
// not connected. All dependent tasks must be updated to use these outputs
|
||||
// first.
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
|
|||
[](const auto& elem) -> ClassificationList { return elem.Get(); });
|
||||
cached_classifications_[cc->InputTimestamp().Value()] =
|
||||
std::move(classification_lists);
|
||||
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) {
|
||||
ClassificationResult classification_result;
|
||||
if (time_aggregation_enabled_) {
|
||||
if (kTimestampsIn(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
kOut(cc).Send(ConvertToClassificationResult(cc));
|
||||
classification_result = LegacyConvertToClassificationResult(cc);
|
||||
kTimestampedClassificationsOut(cc).Send(
|
||||
ConvertToTimestampedClassificationResults(cc));
|
||||
} else {
|
||||
classification_result = LegacyConvertToClassificationResult(cc);
|
||||
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
|
||||
}
|
||||
kClassificationResultOut(cc).Send(classification_result);
|
||||
RET_CHECK(cached_classifications_.empty());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -136,6 +186,50 @@ ClassificationResult
|
|||
ClassificationAggregationCalculator::ConvertToClassificationResult(
|
||||
CalculatorContext* cc) {
|
||||
ClassificationResult result;
|
||||
auto& classification_lists =
|
||||
cached_classifications_[cc->InputTimestamp().Value()];
|
||||
for (int i = 0; i < classification_lists.size(); ++i) {
|
||||
auto classifications = result.add_classifications();
|
||||
classifications->set_head_index(i);
|
||||
if (!head_names_.empty()) {
|
||||
classifications->set_head_name(head_names_[i]);
|
||||
}
|
||||
*classifications->mutable_classification_list() =
|
||||
std::move(classification_lists[i]);
|
||||
}
|
||||
cached_classifications_.erase(cc->InputTimestamp().Value());
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<ClassificationResult>
|
||||
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
|
||||
CalculatorContext* cc) {
|
||||
auto timestamps = kTimestampsIn(cc).Get();
|
||||
std::vector<ClassificationResult> results;
|
||||
results.reserve(timestamps.size());
|
||||
for (const auto& timestamp : timestamps) {
|
||||
ClassificationResult result;
|
||||
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
|
||||
auto& classification_lists = cached_classifications_[timestamp.Value()];
|
||||
for (int i = 0; i < classification_lists.size(); ++i) {
|
||||
auto classifications = result.add_classifications();
|
||||
classifications->set_head_index(i);
|
||||
if (!head_names_.empty()) {
|
||||
classifications->set_head_name(head_names_[i]);
|
||||
}
|
||||
*classifications->mutable_classification_list() =
|
||||
std::move(classification_lists[i]);
|
||||
}
|
||||
cached_classifications_.erase(timestamp.Value());
|
||||
results.push_back(std::move(result));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
ClassificationResult
|
||||
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
|
||||
CalculatorContext* cc) {
|
||||
ClassificationResult result;
|
||||
Timestamp first_timestamp(0);
|
||||
std::vector<Timestamp> timestamps;
|
||||
if (time_aggregation_enabled_) {
|
||||
|
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
|
|||
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
|
||||
1000);
|
||||
}
|
||||
cached_classifications_.erase(timestamp.Value());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks;
|
||||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
|
|
|
@ -0,0 +1,213 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
#include <optional>
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/output_stream_poller.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_macros.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::ParseTextProtoOrDie;
|
||||
using ::mediapipe::api2::Input;
|
||||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::testing::Pointwise;
|
||||
|
||||
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
|
||||
constexpr char kClassificationInput0Name[] = "classifications_0";
|
||||
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
|
||||
constexpr char kClassificationInput1Name[] = "classifications_1";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
constexpr char kTimestampsName[] = "timestamps";
|
||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||
constexpr char kClassificationsName[] = "classifications";
|
||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||
constexpr char kTimestampedClassificationsName[] =
|
||||
"timestamped_classifications";
|
||||
|
||||
// Returns a ClassificationList holding exactly one classification whose
// `index` field is set to `class_index`.
ClassificationList MakeClassificationList(int class_index) {
  ClassificationList single_entry_list;
  single_entry_list.add_classification()->set_index(class_index);
  return single_entry_list;
}
|
||||
|
||||
class ClassificationAggregationCalculatorTest
|
||||
: public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
bool connect_timestamps = false) {
|
||||
Graph graph;
|
||||
auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
|
||||
calculator
|
||||
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
|
||||
ParseTextProtoOrDie<
|
||||
mediapipe::ClassificationAggregationCalculatorOptions>(
|
||||
R"pb(head_names: "foo" head_names: "bar")pb");
|
||||
graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
|
||||
kClassificationInput0Name) >>
|
||||
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
|
||||
graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
|
||||
kClassificationInput1Name) >>
|
||||
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
|
||||
if (connect_timestamps) {
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||
kTimestampsName) >>
|
||||
calculator.In(kTimestampsTag);
|
||||
calculator.Out(kTimestampedClassificationsTag)
|
||||
.SetName(kTimestampedClassificationsName) >>
|
||||
graph[Output<std::vector<ClassificationResult>>(
|
||||
kTimestampedClassificationsTag)];
|
||||
} else {
|
||||
calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
|
||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||
if (connect_timestamps) {
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kTimestampedClassificationsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
kClassificationsName));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
|
||||
absl::Status Send(
|
||||
std::vector<ClassificationList> classifications, int timestamp = 0,
|
||||
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kClassificationInput0Name,
|
||||
MakePacket<ClassificationList>(classifications[0])
|
||||
.At(Timestamp(timestamp))));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kClassificationInput1Name,
|
||||
MakePacket<ClassificationList>(classifications[1])
|
||||
.At(Timestamp(timestamp))));
|
||||
if (aggregation_timestamps.has_value()) {
|
||||
auto packet = std::make_unique<std::vector<Timestamp>>();
|
||||
for (const auto& timestamp : *aggregation_timestamps) {
|
||||
packet->emplace_back(Timestamp(timestamp));
|
||||
}
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||
|
||||
Packet packet;
|
||||
if (!poller.Next(&packet)) {
|
||||
return absl::InternalError("Unable to get output packet");
|
||||
}
|
||||
auto result = packet.Get<T>();
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CalculatorGraph calculator_graph_;
|
||||
};
|
||||
|
||||
// With the TIMESTAMPS input left unconnected, a single ClassificationResult is
// read from the CLASSIFICATIONS output, holding one entry per head.
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto output_poller, BuildGraph());
  MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
  MP_ASSERT_OK_AND_ASSIGN(auto aggregated,
                          GetResult<ClassificationResult>(output_poller));

  const auto expected = ParseTextProtoOrDie<ClassificationResult>(
      R"pb(classifications {
             head_index: 0
             head_name: "foo"
             classification_list { classification { index: 0 } }
           }
           classifications {
             head_index: 1
             head_name: "bar"
             classification_list { classification { index: 1 } }
           })pb");
  EXPECT_THAT(aggregated, EqualsProto(expected));
}
|
||||
|
||||
// With the TIMESTAMPS input connected, classifications sent at timestamps 0
// and 1000 are aggregated into two ClassificationResults once the TIMESTAMPS
// packet listing {0, 1000} arrives; their timestamp_ms values are 0 and 1.
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto output_poller,
                          BuildGraph(/*connect_timestamps=*/true));
  // First round: classifications only, at the default timestamp.
  MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
  // Second round: classifications plus the list of timestamps to aggregate.
  MP_ASSERT_OK(Send(
      {MakeClassificationList(2), MakeClassificationList(3)},
      /*timestamp=*/1000,
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
  MP_ASSERT_OK_AND_ASSIGN(
      auto aggregated,
      GetResult<std::vector<ClassificationResult>>(output_poller));

  const std::vector<ClassificationResult> expected = {
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
        timestamp_ms: 0,
        classifications {
          head_index: 0
          head_name: "foo"
          classification_list { classification { index: 0 } }
        }
        classifications {
          head_index: 1
          head_name: "bar"
          classification_list { classification { index: 1 } }
        }
      )pb"),
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
        timestamp_ms: 1,
        classifications {
          head_index: 0
          head_name: "foo"
          classification_list { classification { index: 2 } }
        }
        classifications {
          head_index: 1
          head_name: "bar"
          classification_list { classification { index: 3 } }
        }
      )pb")};
  EXPECT_THAT(aggregated, Pointwise(EqualsProto(), expected));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -28,6 +28,7 @@ mediapipe_proto_library(
|
|||
srcs = ["classifications.proto"],
|
||||
deps = [
|
||||
":category_proto",
|
||||
"//mediapipe/framework/formats:classification_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ package mediapipe.tasks.components.containers.proto;
|
|||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||
option java_outer_classname = "CategoryProto";
|
||||
|
||||
// TODO: deprecate this message once migration is over.
|
||||
// A single classification result.
|
||||
message Category {
|
||||
// The index of the category in the corresponding label map, usually packed in
|
||||
|
|
|
@ -17,11 +17,13 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
import "mediapipe/framework/formats/classification.proto";
|
||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||
option java_outer_classname = "ClassificationsProto";
|
||||
|
||||
// TODO: deprecate this message once migration is over.
|
||||
// List of predicted categories with an optional timestamp.
|
||||
message ClassificationEntry {
|
||||
// The array of predicted categories, usually sorted by descending scores,
|
||||
|
@ -33,9 +35,12 @@ message ClassificationEntry {
|
|||
optional int64 timestamp_ms = 2;
|
||||
}
|
||||
|
||||
// Classifications for a given classifier head.
|
||||
// Classifications for a given classifier head, i.e. for a given output tensor.
|
||||
message Classifications {
|
||||
// TODO: deprecate this field once migration is over.
|
||||
repeated ClassificationEntry entries = 1;
|
||||
// The classification results for this head.
|
||||
optional mediapipe.ClassificationList classification_list = 4;
|
||||
// The index of the classifier head these categories refer to. This is useful
|
||||
// for multi-head models.
|
||||
optional int32 head_index = 2;
|
||||
|
@ -45,7 +50,17 @@ message Classifications {
|
|||
optional string head_name = 3;
|
||||
}
|
||||
|
||||
// Contains one set of results per classifier head.
|
||||
// Classifications for a given classifier model.
|
||||
message ClassificationResult {
|
||||
// The classification results for each model head, i.e. one for each output
|
||||
// tensor.
|
||||
repeated Classifications classifications = 1;
|
||||
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
// corresponding to these results.
|
||||
//
|
||||
// This is only used for classification on time series (e.g. audio
|
||||
// classification). In these use cases, the amount of data to process might
|
||||
// exceed the maximum size that the model can process: to solve this, the
|
||||
// input data is split into multiple chunks starting at different timestamps.
|
||||
optional int64 timestamp_ms = 2;
|
||||
}
|
||||
|
|
|
@ -286,7 +286,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
|||
|
||||
void ConfigureClassificationAggregationCalculator(
|
||||
const ModelMetadataExtractor& metadata_extractor,
|
||||
ClassificationAggregationCalculatorOptions* options) {
|
||||
mediapipe::ClassificationAggregationCalculatorOptions* options) {
|
||||
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
|
||||
if (output_tensors_metadata == nullptr) {
|
||||
return;
|
||||
|
@ -494,7 +494,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
// Aggregates Classifications into a single ClassificationResult.
|
||||
auto& result_aggregation =
|
||||
graph.AddNode("ClassificationAggregationCalculator");
|
||||
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
|
||||
result_aggregation
|
||||
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
|
||||
.CopyFrom(options.classification_aggregation_options());
|
||||
for (int i = 0; i < num_heads; ++i) {
|
||||
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
|
||||
|
|
|
@ -38,7 +38,7 @@ message ClassificationPostprocessingGraphOptions {
|
|||
|
||||
// Options for the ClassificationAggregationCalculator encapsulated by the
|
||||
// ClassificationPostprocessing subgraph.
|
||||
optional ClassificationAggregationCalculatorOptions
|
||||
optional mediapipe.ClassificationAggregationCalculatorOptions
|
||||
classification_aggregation_options = 2;
|
||||
|
||||
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
||||
|
|
Loading…
Reference in New Issue
Block a user