Modify ClassificationAggregationCalculator to output new unified formats.
PiperOrigin-RevId: 485546547
This commit is contained in:
parent
f1f123d255
commit
475e6b4fd5
|
@ -44,6 +44,30 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "classification_aggregation_calculator_test",
|
||||||
|
srcs = ["classification_aggregation_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":classification_aggregation_calculator",
|
||||||
|
":classification_aggregation_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:output_stream_poller",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "score_calibration_calculator_proto",
|
name = "score_calibration_calculator_proto",
|
||||||
srcs = ["score_calibration_calculator.proto"],
|
srcs = ["score_calibration_calculator.proto"],
|
||||||
|
|
|
@ -31,37 +31,62 @@
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
|
||||||
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||||
|
|
||||||
// Aggregates ClassificationLists into a single ClassificationResult that has
|
// Aggregates ClassificationLists into either a ClassificationResult object
|
||||||
// 3 dimensions: (classification head, classification timestamp, classification
|
// representing the classification results aggregated by classifier head, or
|
||||||
// category).
|
// into an std::vector<ClassificationResult> representing the classification
|
||||||
|
// results aggregated first by timestamp then by classifier head.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
// CLASSIFICATIONS - ClassificationList
|
// CLASSIFICATIONS - ClassificationList @Multiple
|
||||||
// ClassificationList per classification head.
|
// ClassificationList per classification head.
|
||||||
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
// TIMESTAMPS - std::vector<Timestamp> @Optional
|
||||||
// The collection of the timestamps that a single ClassificationResult
|
// The collection of the timestamps that this calculator should aggregate.
|
||||||
// should aggragate. This stream is optional, and the timestamp information
|
// This stream is optional: if provided then the TIMESTAMPED_CLASSIFICATIONS
|
||||||
// will only be populated to the ClassificationResult proto when this stream
|
// output is used for results. Otherwise as no timestamp aggregation is
|
||||||
// is connected.
|
// required the CLASSIFICATIONS output is used for results.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// CLASSIFICATION_RESULT - ClassificationResult
|
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||||
|
// The classification results aggregated by head. Must be connected if the
|
||||||
|
// TIMESTAMPS input is not connected, as it signals that timestamp
|
||||||
|
// aggregation is not required.
|
||||||
|
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
|
||||||
|
// The classification result aggregated by timestamp, then by head. Must be
|
||||||
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
|
// timestamp aggregation is required.
|
||||||
|
// // TODO: remove output once migration is over.
|
||||||
|
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||||
// The aggregated classification result.
|
// The aggregated classification result.
|
||||||
//
|
//
|
||||||
// Example:
|
// Example without timestamp aggregation:
|
||||||
|
// node {
|
||||||
|
// calculator: "ClassificationAggregationCalculator"
|
||||||
|
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
||||||
|
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
||||||
|
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
||||||
|
// output_stream: "CLASSIFICATIONS:classifications"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
|
||||||
|
// head_names: "head_name_a"
|
||||||
|
// head_names: "head_name_b"
|
||||||
|
// head_names: "head_name_c"
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Example with timestamp aggregation:
|
||||||
// node {
|
// node {
|
||||||
// calculator: "ClassificationAggregationCalculator"
|
// calculator: "ClassificationAggregationCalculator"
|
||||||
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
// input_stream: "CLASSIFICATIONS:0:stream_a"
|
||||||
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
// input_stream: "CLASSIFICATIONS:1:stream_b"
|
||||||
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
// input_stream: "CLASSIFICATIONS:2:stream_c"
|
||||||
// input_stream: "TIMESTAMPS:timestamps"
|
// input_stream: "TIMESTAMPS:timestamps"
|
||||||
// output_stream: "CLASSIFICATION_RESULT:classification_result"
|
// output_stream: "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.ClassificationAggregationCalculatorOptions.ext] {
|
// [mediapipe.ClassificationAggregationCalculatorOptions.ext] {
|
||||||
// head_names: "head_name_a"
|
// head_names: "head_name_a"
|
||||||
// head_names: "head_name_b"
|
// head_names: "head_name_b"
|
||||||
// head_names: "head_name_c"
|
// head_names: "head_name_c"
|
||||||
|
@ -74,8 +99,15 @@ class ClassificationAggregationCalculator : public Node {
|
||||||
"CLASSIFICATIONS"};
|
"CLASSIFICATIONS"};
|
||||||
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
|
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
|
||||||
"TIMESTAMPS"};
|
"TIMESTAMPS"};
|
||||||
static constexpr Output<ClassificationResult> kOut{"CLASSIFICATION_RESULT"};
|
static constexpr Output<ClassificationResult>::Optional kClassificationsOut{
|
||||||
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn, kOut);
|
"CLASSIFICATIONS"};
|
||||||
|
static constexpr Output<std::vector<ClassificationResult>>::Optional
|
||||||
|
kTimestampedClassificationsOut{"TIMESTAMPED_CLASSIFICATIONS"};
|
||||||
|
static constexpr Output<ClassificationResult>::Optional
|
||||||
|
kClassificationResultOut{"CLASSIFICATION_RESULT"};
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kTimestampsIn,
|
||||||
|
kClassificationsOut, kTimestampedClassificationsOut,
|
||||||
|
kClassificationResultOut);
|
||||||
|
|
||||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||||
absl::Status Open(CalculatorContext* cc);
|
absl::Status Open(CalculatorContext* cc);
|
||||||
|
@ -88,6 +120,11 @@ class ClassificationAggregationCalculator : public Node {
|
||||||
cached_classifications_;
|
cached_classifications_;
|
||||||
|
|
||||||
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
|
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
|
||||||
|
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
|
||||||
|
CalculatorContext* cc);
|
||||||
|
// TODO: deprecate this function once migration is over.
|
||||||
|
ClassificationResult LegacyConvertToClassificationResult(
|
||||||
|
CalculatorContext* cc);
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status ClassificationAggregationCalculator::UpdateContract(
|
absl::Status ClassificationAggregationCalculator::UpdateContract(
|
||||||
|
@ -100,6 +137,10 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
|
||||||
<< "The size of classifications input streams should match the "
|
<< "The size of classifications input streams should match the "
|
||||||
"size of head names specified in the calculator options";
|
"size of head names specified in the calculator options";
|
||||||
}
|
}
|
||||||
|
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
|
||||||
|
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
|
||||||
|
// not connected. All dependent tasks must be updated to use these outputs
|
||||||
|
// first.
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,10 +165,19 @@ absl::Status ClassificationAggregationCalculator::Process(
|
||||||
[](const auto& elem) -> ClassificationList { return elem.Get(); });
|
[](const auto& elem) -> ClassificationList { return elem.Get(); });
|
||||||
cached_classifications_[cc->InputTimestamp().Value()] =
|
cached_classifications_[cc->InputTimestamp().Value()] =
|
||||||
std::move(classification_lists);
|
std::move(classification_lists);
|
||||||
if (time_aggregation_enabled_ && kTimestampsIn(cc).IsEmpty()) {
|
ClassificationResult classification_result;
|
||||||
return absl::OkStatus();
|
if (time_aggregation_enabled_) {
|
||||||
|
if (kTimestampsIn(cc).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
classification_result = LegacyConvertToClassificationResult(cc);
|
||||||
|
kTimestampedClassificationsOut(cc).Send(
|
||||||
|
ConvertToTimestampedClassificationResults(cc));
|
||||||
|
} else {
|
||||||
|
classification_result = LegacyConvertToClassificationResult(cc);
|
||||||
|
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
|
||||||
}
|
}
|
||||||
kOut(cc).Send(ConvertToClassificationResult(cc));
|
kClassificationResultOut(cc).Send(classification_result);
|
||||||
RET_CHECK(cached_classifications_.empty());
|
RET_CHECK(cached_classifications_.empty());
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -136,6 +186,50 @@ ClassificationResult
|
||||||
ClassificationAggregationCalculator::ConvertToClassificationResult(
|
ClassificationAggregationCalculator::ConvertToClassificationResult(
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
ClassificationResult result;
|
ClassificationResult result;
|
||||||
|
auto& classification_lists =
|
||||||
|
cached_classifications_[cc->InputTimestamp().Value()];
|
||||||
|
for (int i = 0; i < classification_lists.size(); ++i) {
|
||||||
|
auto classifications = result.add_classifications();
|
||||||
|
classifications->set_head_index(i);
|
||||||
|
if (!head_names_.empty()) {
|
||||||
|
classifications->set_head_name(head_names_[i]);
|
||||||
|
}
|
||||||
|
*classifications->mutable_classification_list() =
|
||||||
|
std::move(classification_lists[i]);
|
||||||
|
}
|
||||||
|
cached_classifications_.erase(cc->InputTimestamp().Value());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ClassificationResult>
|
||||||
|
ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
|
||||||
|
CalculatorContext* cc) {
|
||||||
|
auto timestamps = kTimestampsIn(cc).Get();
|
||||||
|
std::vector<ClassificationResult> results;
|
||||||
|
results.reserve(timestamps.size());
|
||||||
|
for (const auto& timestamp : timestamps) {
|
||||||
|
ClassificationResult result;
|
||||||
|
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) / 1000);
|
||||||
|
auto& classification_lists = cached_classifications_[timestamp.Value()];
|
||||||
|
for (int i = 0; i < classification_lists.size(); ++i) {
|
||||||
|
auto classifications = result.add_classifications();
|
||||||
|
classifications->set_head_index(i);
|
||||||
|
if (!head_names_.empty()) {
|
||||||
|
classifications->set_head_name(head_names_[i]);
|
||||||
|
}
|
||||||
|
*classifications->mutable_classification_list() =
|
||||||
|
std::move(classification_lists[i]);
|
||||||
|
}
|
||||||
|
cached_classifications_.erase(timestamp.Value());
|
||||||
|
results.push_back(std::move(result));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
ClassificationResult
|
||||||
|
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
|
||||||
|
CalculatorContext* cc) {
|
||||||
|
ClassificationResult result;
|
||||||
Timestamp first_timestamp(0);
|
Timestamp first_timestamp(0);
|
||||||
std::vector<Timestamp> timestamps;
|
std::vector<Timestamp> timestamps;
|
||||||
if (time_aggregation_enabled_) {
|
if (time_aggregation_enabled_) {
|
||||||
|
@ -177,7 +271,6 @@ ClassificationAggregationCalculator::ConvertToClassificationResult(
|
||||||
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
|
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
|
||||||
1000);
|
1000);
|
||||||
}
|
}
|
||||||
cached_classifications_.erase(timestamp.Value());
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks;
|
package mediapipe;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,213 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/framework/output_stream_poller.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::ParseTextProtoOrDie;
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::mediapipe::api2::builder::Graph;
|
||||||
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
using ::testing::Pointwise;
|
||||||
|
|
||||||
|
constexpr char kClassificationInput0Tag[] = "CLASSIFICATIONS_0";
|
||||||
|
constexpr char kClassificationInput0Name[] = "classifications_0";
|
||||||
|
constexpr char kClassificationInput1Tag[] = "CLASSIFICATIONS_1";
|
||||||
|
constexpr char kClassificationInput1Name[] = "classifications_1";
|
||||||
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
constexpr char kTimestampsName[] = "timestamps";
|
||||||
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
|
constexpr char kClassificationsName[] = "classifications";
|
||||||
|
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||||
|
constexpr char kTimestampedClassificationsName[] =
|
||||||
|
"timestamped_classifications";
|
||||||
|
|
||||||
|
ClassificationList MakeClassificationList(int class_index) {
|
||||||
|
return ParseTextProtoOrDie<ClassificationList>(absl::StrFormat(
|
||||||
|
R"pb(
|
||||||
|
classification { index: %d }
|
||||||
|
)pb",
|
||||||
|
class_index));
|
||||||
|
}
|
||||||
|
|
||||||
|
class ClassificationAggregationCalculatorTest
|
||||||
|
: public tflite_shims::testing::Test {
|
||||||
|
protected:
|
||||||
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
|
bool connect_timestamps = false) {
|
||||||
|
Graph graph;
|
||||||
|
auto& calculator = graph.AddNode("ClassificationAggregationCalculator");
|
||||||
|
calculator
|
||||||
|
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>() =
|
||||||
|
ParseTextProtoOrDie<
|
||||||
|
mediapipe::ClassificationAggregationCalculatorOptions>(
|
||||||
|
R"pb(head_names: "foo" head_names: "bar")pb");
|
||||||
|
graph[Input<ClassificationList>(kClassificationInput0Tag)].SetName(
|
||||||
|
kClassificationInput0Name) >>
|
||||||
|
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 0));
|
||||||
|
graph[Input<ClassificationList>(kClassificationInput1Tag)].SetName(
|
||||||
|
kClassificationInput1Name) >>
|
||||||
|
calculator.In(absl::StrFormat("%s:%d", kClassificationsTag, 1));
|
||||||
|
if (connect_timestamps) {
|
||||||
|
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
||||||
|
kTimestampsName) >>
|
||||||
|
calculator.In(kTimestampsTag);
|
||||||
|
calculator.Out(kTimestampedClassificationsTag)
|
||||||
|
.SetName(kTimestampedClassificationsName) >>
|
||||||
|
graph[Output<std::vector<ClassificationResult>>(
|
||||||
|
kTimestampedClassificationsTag)];
|
||||||
|
} else {
|
||||||
|
calculator.Out(kClassificationsTag).SetName(kClassificationsName) >>
|
||||||
|
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
|
}
|
||||||
|
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||||
|
if (connect_timestamps) {
|
||||||
|
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||||
|
kTimestampedClassificationsName));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||||
|
return poller;
|
||||||
|
}
|
||||||
|
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||||
|
kClassificationsName));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||||
|
return poller;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Send(
|
||||||
|
std::vector<ClassificationList> classifications, int timestamp = 0,
|
||||||
|
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kClassificationInput0Name,
|
||||||
|
MakePacket<ClassificationList>(classifications[0])
|
||||||
|
.At(Timestamp(timestamp))));
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kClassificationInput1Name,
|
||||||
|
MakePacket<ClassificationList>(classifications[1])
|
||||||
|
.At(Timestamp(timestamp))));
|
||||||
|
if (aggregation_timestamps.has_value()) {
|
||||||
|
auto packet = std::make_unique<std::vector<Timestamp>>();
|
||||||
|
for (const auto& timestamp : *aggregation_timestamps) {
|
||||||
|
packet->emplace_back(Timestamp(timestamp));
|
||||||
|
}
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||||
|
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||||
|
|
||||||
|
Packet packet;
|
||||||
|
if (!poller.Next(&packet)) {
|
||||||
|
return absl::InternalError("Unable to get output packet");
|
||||||
|
}
|
||||||
|
auto result = packet.Get<T>();
|
||||||
|
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CalculatorGraph calculator_graph_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithoutTimestamps) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph());
|
||||||
|
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<ClassificationResult>(poller));
|
||||||
|
|
||||||
|
EXPECT_THAT(result,
|
||||||
|
EqualsProto(ParseTextProtoOrDie<ClassificationResult>(
|
||||||
|
R"pb(classifications {
|
||||||
|
head_index: 0
|
||||||
|
head_name: "foo"
|
||||||
|
classification_list { classification { index: 0 } }
|
||||||
|
}
|
||||||
|
classifications {
|
||||||
|
head_index: 1
|
||||||
|
head_name: "bar"
|
||||||
|
classification_list { classification { index: 1 } }
|
||||||
|
})pb")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ClassificationAggregationCalculatorTest, SucceedsWithTimestamps) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
|
||||||
|
MP_ASSERT_OK(Send({MakeClassificationList(0), MakeClassificationList(1)}));
|
||||||
|
MP_ASSERT_OK(Send(
|
||||||
|
{MakeClassificationList(2), MakeClassificationList(3)},
|
||||||
|
/*timestamp=*/1000,
|
||||||
|
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto result,
|
||||||
|
GetResult<std::vector<ClassificationResult>>(poller));
|
||||||
|
|
||||||
|
EXPECT_THAT(result,
|
||||||
|
Pointwise(EqualsProto(),
|
||||||
|
{ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||||
|
timestamp_ms: 0,
|
||||||
|
classifications {
|
||||||
|
head_index: 0
|
||||||
|
head_name: "foo"
|
||||||
|
classification_list { classification { index: 0 } }
|
||||||
|
}
|
||||||
|
classifications {
|
||||||
|
head_index: 1
|
||||||
|
head_name: "bar"
|
||||||
|
classification_list { classification { index: 1 } }
|
||||||
|
}
|
||||||
|
)pb"),
|
||||||
|
ParseTextProtoOrDie<ClassificationResult>(R"pb(
|
||||||
|
timestamp_ms: 1,
|
||||||
|
classifications {
|
||||||
|
head_index: 0
|
||||||
|
head_name: "foo"
|
||||||
|
classification_list { classification { index: 2 } }
|
||||||
|
}
|
||||||
|
classifications {
|
||||||
|
head_index: 1
|
||||||
|
head_name: "bar"
|
||||||
|
classification_list { classification { index: 3 } }
|
||||||
|
}
|
||||||
|
)pb")}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -28,6 +28,7 @@ mediapipe_proto_library(
|
||||||
srcs = ["classifications.proto"],
|
srcs = ["classifications.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
":category_proto",
|
":category_proto",
|
||||||
|
"//mediapipe/framework/formats:classification_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ package mediapipe.tasks.components.containers.proto;
|
||||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||||
option java_outer_classname = "CategoryProto";
|
option java_outer_classname = "CategoryProto";
|
||||||
|
|
||||||
|
// TODO: deprecate this message once migration is over.
|
||||||
// A single classification result.
|
// A single classification result.
|
||||||
message Category {
|
message Category {
|
||||||
// The index of the category in the corresponding label map, usually packed in
|
// The index of the category in the corresponding label map, usually packed in
|
||||||
|
|
|
@ -17,11 +17,13 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.containers.proto;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
|
import "mediapipe/framework/formats/classification.proto";
|
||||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||||
option java_outer_classname = "ClassificationsProto";
|
option java_outer_classname = "ClassificationsProto";
|
||||||
|
|
||||||
|
// TODO: deprecate this message once migration is over.
|
||||||
// List of predicted categories with an optional timestamp.
|
// List of predicted categories with an optional timestamp.
|
||||||
message ClassificationEntry {
|
message ClassificationEntry {
|
||||||
// The array of predicted categories, usually sorted by descending scores,
|
// The array of predicted categories, usually sorted by descending scores,
|
||||||
|
@ -33,9 +35,12 @@ message ClassificationEntry {
|
||||||
optional int64 timestamp_ms = 2;
|
optional int64 timestamp_ms = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Classifications for a given classifier head.
|
// Classifications for a given classifier head, i.e. for a given output tensor.
|
||||||
message Classifications {
|
message Classifications {
|
||||||
|
// TODO: deprecate this field once migration is over.
|
||||||
repeated ClassificationEntry entries = 1;
|
repeated ClassificationEntry entries = 1;
|
||||||
|
// The classification results for this head.
|
||||||
|
optional mediapipe.ClassificationList classification_list = 4;
|
||||||
// The index of the classifier head these categories refer to. This is useful
|
// The index of the classifier head these categories refer to. This is useful
|
||||||
// for multi-head models.
|
// for multi-head models.
|
||||||
optional int32 head_index = 2;
|
optional int32 head_index = 2;
|
||||||
|
@ -45,7 +50,17 @@ message Classifications {
|
||||||
optional string head_name = 3;
|
optional string head_name = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains one set of results per classifier head.
|
// Classifications for a given classifier model.
|
||||||
message ClassificationResult {
|
message ClassificationResult {
|
||||||
|
// The classification results for each model head, i.e. one for each output
|
||||||
|
// tensor.
|
||||||
repeated Classifications classifications = 1;
|
repeated Classifications classifications = 1;
|
||||||
|
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
// corresponding to these results.
|
||||||
|
//
|
||||||
|
// This is only used for classification on time series (e.g. audio
|
||||||
|
// classification). In these use cases, the amount of data to process might
|
||||||
|
// exceed the maximum size that the model can process: to solve this, the
|
||||||
|
// input data is split into multiple chunks starting at different timestamps.
|
||||||
|
optional int64 timestamp_ms = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -286,7 +286,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
||||||
|
|
||||||
void ConfigureClassificationAggregationCalculator(
|
void ConfigureClassificationAggregationCalculator(
|
||||||
const ModelMetadataExtractor& metadata_extractor,
|
const ModelMetadataExtractor& metadata_extractor,
|
||||||
ClassificationAggregationCalculatorOptions* options) {
|
mediapipe::ClassificationAggregationCalculatorOptions* options) {
|
||||||
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
|
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
|
||||||
if (output_tensors_metadata == nullptr) {
|
if (output_tensors_metadata == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
@ -494,7 +494,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
// Aggregates Classifications into a single ClassificationResult.
|
// Aggregates Classifications into a single ClassificationResult.
|
||||||
auto& result_aggregation =
|
auto& result_aggregation =
|
||||||
graph.AddNode("ClassificationAggregationCalculator");
|
graph.AddNode("ClassificationAggregationCalculator");
|
||||||
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
|
result_aggregation
|
||||||
|
.GetOptions<mediapipe::ClassificationAggregationCalculatorOptions>()
|
||||||
.CopyFrom(options.classification_aggregation_options());
|
.CopyFrom(options.classification_aggregation_options());
|
||||||
for (int i = 0; i < num_heads; ++i) {
|
for (int i = 0; i < num_heads; ++i) {
|
||||||
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
|
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
|
||||||
|
|
|
@ -38,7 +38,7 @@ message ClassificationPostprocessingGraphOptions {
|
||||||
|
|
||||||
// Options for the ClassificationAggregationCalculator encapsulated by the
|
// Options for the ClassificationAggregationCalculator encapsulated by the
|
||||||
// ClassificationPostprocessing subgraph.
|
// ClassificationPostprocessing subgraph.
|
||||||
optional ClassificationAggregationCalculatorOptions
|
optional mediapipe.ClassificationAggregationCalculatorOptions
|
||||||
classification_aggregation_options = 2;
|
classification_aggregation_options = 2;
|
||||||
|
|
||||||
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
||||||
|
|
Loading…
Reference in New Issue
Block a user