Internal change

PiperOrigin-RevId: 485553891
This commit is contained in:
MediaPipe Team 2022-11-02 03:48:02 -07:00 committed by Copybara-Service
parent 475e6b4fd5
commit aab5f84aae
3 changed files with 387 additions and 23 deletions

View File

@ -78,6 +78,14 @@ constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
// Struct holding the different output streams produced by the graph.
//
// NOTE: the field order matters, since the struct is brace-initialized
// positionally where the streams are connected.
struct ClassificationPostprocessingOutputStreams {
// Stream taken from the CLASSIFICATION_RESULT output of the aggregation node.
// The graph documentation marks this output as deprecated.
Source<ClassificationResult> classification_result;
// Stream taken from the CLASSIFICATIONS output: results aggregated by head,
// used when no timestamps are provided.
Source<ClassificationResult> classifications;
// Stream taken from the TIMESTAMPED_CLASSIFICATIONS output: results
// aggregated by timestamp then by head, used when timestamps are provided.
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
// Performs sanity checks on provided ClassifierOptions.
absl::Status SanityCheckClassifierOptions(
@ -378,12 +386,23 @@ absl::Status ConfigureClassificationPostprocessingGraph(
// TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of timestamps that a single ClassificationResult should
// aggregate. This is mostly useful for classifiers working on time series,
// e.g. audio or video classification.
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided, the TIMESTAMPED_CLASSIFICATIONS
// output is used for results. Otherwise, as no timestamp aggregation is
// required, the CLASSIFICATIONS output is used for results.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results.
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// TODO: remove the CLASSIFICATION_RESULT output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
@ -394,28 +413,39 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
mediapipe::SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto classification_result_out,
auto output_streams,
BuildClassificationPostprocessing(
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
classification_result_out >>
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
graph[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)];
return graph.GetConfig();
}
private:
// Adds an on-device classification postprocessing graph into the provided
// builder::Graph instance. The classification postprocessing graph takes
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
// stream containing the output classification results (ClassificationResult).
// tensors (std::vector<mediapipe::Tensor>) and optional timestamps
// (std::vector<Timestamp>) as input and returns, alongside the deprecated
// CLASSIFICATION_RESULT stream, two output streams:
// - classification results aggregated by classifier head as a
// ClassificationResult proto, used when no timestamps are passed in
// the graph,
// - classification results aggregated by timestamp then by classifier head
// as a std::vector<ClassificationResult>, used when timestamps are passed
// in the graph.
//
// options: the on-device ClassificationPostprocessingGraphOptions.
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that a single ClassificationResult should aggregate.
// timestamps that should be used to aggregate classification results.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>>
absl::StatusOr<ClassificationPostprocessingOutputStreams>
BuildClassificationPostprocessing(
const proto::ClassificationPostprocessingGraphOptions& options,
Source<std::vector<Tensor>> tensors_in,
@ -505,8 +535,15 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
timestamps_in >> result_aggregation.In(kTimestampsTag);
// Connects output.
return result_aggregation[Output<ClassificationResult>(
kClassificationResultTag)];
ClassificationPostprocessingOutputStreams output_streams{
/*classification_result=*/result_aggregation
[Output<ClassificationResult>(kClassificationResultTag)],
/*classifications=*/
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
/*timestamped_classifications=*/
result_aggregation[Output<std::vector<ClassificationResult>>(
kTimestampedClassificationsTag)]};
return output_streams;
}
};

View File

@ -45,12 +45,22 @@ namespace processors {
// TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of timestamps that a single ClassificationResult should
// aggregate. This is mostly useful for classifiers working on time series,
// e.g. audio or video classification.
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided, the TIMESTAMPED_CLASSIFICATIONS
// output is used for results. Otherwise, as no timestamp aggregation is
// required, the CLASSIFICATIONS output is used for results.
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results.
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_CLASSIFICATIONS - std::vector<ClassificationResult> @Optional
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// TODO: remove the CLASSIFICATION_RESULT output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
absl::Status ConfigureClassificationPostprocessingGraph(
const tasks::core::ModelResources& model_resources,
const proto::ClassifierOptions& classifier_options,

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
@ -64,6 +65,7 @@ using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;
using ::testing::HasSubstr;
using ::testing::Pointwise;
using ::testing::proto::Approximately;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -86,6 +88,11 @@ constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationResultName[] = "classification_result";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
constexpr char kTimestampedClassificationsName[] =
"timestamped_classifications";
// Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
@ -413,6 +420,316 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
}
// Test fixture that wraps a single ClassificationPostprocessingGraph node into
// a CalculatorGraph, feeds it tensors (and optionally aggregation timestamps)
// and polls either its CLASSIFICATIONS or its TIMESTAMPED_CLASSIFICATIONS
// output.
class PostprocessingTest : public tflite_shims::testing::Test {
 protected:
  // Builds, initializes and starts the graph under test.
  //
  // model_name: name of the test model used to configure the postprocessing
  //   graph options.
  // options: the ClassifierOptions used to configure the postprocessing graph.
  // connect_timestamps: when true, the TIMESTAMPS input and the
  //   TIMESTAMPED_CLASSIFICATIONS output are connected; otherwise only the
  //   CLASSIFICATIONS output is connected.
  //
  // Returns a poller attached to whichever output stream was connected, or the
  // first error encountered while building or starting the graph.
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      absl::string_view model_name, const proto::ClassifierOptions& options,
      bool connect_timestamps = false) {
    ASSIGN_OR_RETURN(auto model_resources,
                     CreateModelResourcesForModel(model_name));

    Graph graph;
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "ClassificationPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
        *model_resources, options,
        &postprocessing
             .GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
        postprocessing.In(kTensorsTag);
    if (connect_timestamps) {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          postprocessing.In(kTimestampsTag);
      postprocessing.Out(kTimestampedClassificationsTag)
              .SetName(kTimestampedClassificationsName) >>
          graph[Output<std::vector<ClassificationResult>>(
              kTimestampedClassificationsTag)];
    } else {
      postprocessing.Out(kClassificationsTag).SetName(kClassificationsName) >>
          graph[Output<ClassificationResult>(kClassificationsTag)];
    }

    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    // Exactly one output stream was connected above: poll that one.
    const char* const output_stream_name =
        connect_timestamps ? kTimestampedClassificationsName
                           : kClassificationsName;
    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                      output_stream_name));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  // Appends a tensor of shape {1, tensor.size()} to the pending input tensors
  // and copies the provided values into its CPU buffer.
  //
  // tensor: the values to store in the new tensor.
  // element_type: the element type of the new tensor.
  // quantization_parameters: optional quantization parameters of the tensor.
  template <typename T>
  void AddTensor(
      const std::vector<T>& tensor, const Tensor::ElementType& element_type,
      const Tensor::QuantizationParameters& quantization_parameters = {}) {
    tensors_->emplace_back(element_type,
                           Tensor::Shape{1, static_cast<int>(tensor.size())},
                           quantization_parameters);
    auto view = tensors_->back().GetCpuWriteView();
    T* buffer = view.buffer<T>();
    std::copy(tensor.begin(), tensor.end(), buffer);
  }

  // Sends the tensors accumulated through AddTensor() to the graph at the
  // provided timestamp, then starts a new empty batch of tensors.
  //
  // aggregation_timestamps: if set, these values are wrapped into Timestamp
  //   objects and sent to the timestamps input stream at the same timestamp.
  // timestamp: the timestamp of the packet(s) sent to the graph.
  absl::Status Run(
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
      int timestamp = 0) {
    // Swap in a fresh batch *before* sending, so that the fixture remains
    // usable (tensors_ is never null) even if adding the packet fails.
    auto pending_tensors = absl::make_unique<std::vector<Tensor>>();
    pending_tensors.swap(tensors_);
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kTensorsName,
        Adopt(pending_tensors.release()).At(Timestamp(timestamp))));
    if (aggregation_timestamps.has_value()) {
      auto packet = absl::make_unique<std::vector<Timestamp>>();
      packet->reserve(aggregation_timestamps->size());
      // Distinct name on purpose: `timestamp` is the packet timestamp.
      for (const int aggregation_timestamp : *aggregation_timestamps) {
        packet->emplace_back(Timestamp(aggregation_timestamp));
      }
      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
    }
    return absl::OkStatus();
  }

  // Waits for the graph to become idle, closes all its input streams, then
  // reads the next packet from the poller and returns its payload as a T once
  // the graph is done. Returns an internal error if no packet is available.
  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());

    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    auto result = packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  // The graph under test.
  CalculatorGraph calculator_graph_;
  // Tensors accumulated by AddTensor() until the next call to Run().
  std::unique_ptr<std::vector<Tensor>> tensors_ =
      absl::make_unique<std::vector<Tensor>>();
};
// Checks postprocessing for a quantized model that carries no metadata: the
// expected results hold indices and scores only.
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
  // Graph setup: at most 3 results, with a score threshold of 0.5.
  proto::ClassifierOptions classifier_options;
  classifier_options.set_max_results(3);
  classifier_options.set_score_threshold(0.5);
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_poller,
      BuildGraph(kQuantizedImageClassifierWithoutMetadata, classifier_options));

  // Raw quantized scores: all zeros except for classes 1 and 2.
  std::vector<uint8> quantized_scores(kMobileNetNumClasses, 0);
  quantized_scores[1] = 18;
  quantized_scores[2] = 16;

  // Feed the tensor through the graph and fetch the result.
  AddTensor(quantized_scores, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto actual,
                          GetResult<ClassificationResult>(output_poller));

  // Only classes 1 and 2 are expected, highest score first.
  const ClassificationResult expected =
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
classifications {
head_index: 0
classification_list {
classification { index: 1 score: 0.8 }
classification { index: 2 score: 0.6 }
}
}
)pb");
  EXPECT_THAT(actual, EqualsProto(expected));
}
// Checks postprocessing for a quantized model with metadata: the expected
// results carry the head name and a label for each class.
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
  // Graph setup: at most 3 results.
  proto::ClassifierOptions classifier_options;
  classifier_options.set_max_results(3);
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_poller,
      BuildGraph(kQuantizedImageClassifierWithMetadata, classifier_options));

  // Raw quantized scores: increasing values for classes 1 to 4, zero elsewhere.
  std::vector<uint8> quantized_scores(kMobileNetNumClasses, 0);
  quantized_scores[1] = 12;
  quantized_scores[2] = 14;
  quantized_scores[3] = 16;
  quantized_scores[4] = 18;

  // Feed the tensor through the graph and fetch the result.
  AddTensor(quantized_scores, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto actual,
                          GetResult<ClassificationResult>(output_poller));

  // The three best classes are expected, highest score first.
  const ClassificationResult expected =
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
classifications {
head_index: 0
head_name: "probability"
classification_list {
classification { index: 4 score: 0.8 label: "tiger shark" }
classification { index: 3 score: 0.6 label: "great white shark" }
classification { index: 2 score: 0.4 label: "goldfish" }
}
}
)pb");
  EXPECT_THAT(actual, EqualsProto(expected));
}
// Checks postprocessing for a quantized model that ships with a (dummy) score
// calibration: same inputs as the metadata test, but different expected scores.
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
  // Graph setup: at most 3 results.
  proto::ClassifierOptions classifier_options;
  classifier_options.set_max_results(3);
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_poller,
      BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration,
                 classifier_options));

  // Raw quantized scores: increasing values for classes 1 to 4, zero elsewhere.
  std::vector<uint8> quantized_scores(kMobileNetNumClasses, 0);
  quantized_scores[1] = 12;
  quantized_scores[2] = 14;
  quantized_scores[3] = 16;
  quantized_scores[4] = 18;

  // Feed the tensor through the graph and fetch the result.
  AddTensor(quantized_scores, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto actual,
                          GetResult<ClassificationResult>(output_poller));

  // The three best classes are expected, with their calibrated scores.
  const ClassificationResult expected =
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
classifications {
head_index: 0
head_name: "probability"
classification_list {
classification { index: 4 score: 0.6899744811 label: "tiger shark" }
classification {
index: 3
score: 0.6456563062
label: "great white shark"
}
classification { index: 2 score: 0.5986876601 label: "goldfish" }
}
}
)pb");
  EXPECT_THAT(actual, EqualsProto(expected));
}
// Checks postprocessing for a float model with two classification heads: one
// input tensor is sent per head and one `classifications` entry is expected
// per head.
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
  // Graph setup: at most 2 results per head.
  proto::ClassifierOptions classifier_options;
  classifier_options.set_max_results(2);
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_poller,
      BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata,
                 classifier_options));

  // Scores for the first head: non-zero for classes 1 to 3 only.
  std::vector<float> head_0_scores(kTwoHeadsNumClasses[0], 0);
  head_0_scores[1] = 0.2;
  head_0_scores[2] = 0.4;
  head_0_scores[3] = 0.6;
  // Scores for the second head: same pattern.
  std::vector<float> head_1_scores(kTwoHeadsNumClasses[1], 0);
  head_1_scores[1] = 0.2;
  head_1_scores[2] = 0.4;
  head_1_scores[3] = 0.6;

  // Feed both tensors in a single packet and fetch the result.
  AddTensor(head_0_scores, Tensor::ElementType::kFloat32);
  AddTensor(head_1_scores, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto actual,
                          GetResult<ClassificationResult>(output_poller));

  // The two best classes of each head are expected, highest score first.
  const ClassificationResult expected =
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
classifications {
head_index: 0
head_name: "yamnet_classification"
classification_list {
classification { index: 3 score: 0.6 label: "Narration, monologue" }
classification { index: 2 score: 0.4 label: "Conversation" }
}
}
classifications {
head_index: 1
head_name: "bird_classification"
classification_list {
classification { index: 3 score: 0.6 label: "Azara\'s Spinetail" }
classification { index: 2 score: 0.4 label: "House Sparrow" }
}
}
)pb");
  EXPECT_THAT(actual, EqualsProto(expected));
}
// Checks timestamp aggregation: two tensor packets are sent at timestamps 0
// and 1000, and the aggregation timestamps {0, 1000} are sent with the second
// one. A single vector holding one ClassificationResult per aggregated
// timestamp is expected on the TIMESTAMPED_CLASSIFICATIONS output.
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
  // Graph setup: at most 2 results, timestamps input connected.
  proto::ClassifierOptions classifier_options;
  classifier_options.set_max_results(2);
  MP_ASSERT_OK_AND_ASSIGN(
      auto output_poller,
      BuildGraph(kQuantizedImageClassifierWithMetadata, classifier_options,
                 /*connect_timestamps=*/true));

  // Raw quantized scores for the first packet: classes 1 to 3.
  std::vector<uint8> first_scores(kMobileNetNumClasses, 0);
  first_scores[1] = 12;
  first_scores[2] = 14;
  first_scores[3] = 16;
  // Raw quantized scores for the second packet: classes 5 to 7.
  std::vector<uint8> second_scores(kMobileNetNumClasses, 0);
  second_scores[5] = 12;
  second_scores[6] = 14;
  second_scores[7] = 16;

  // First packet: default timestamp (0), no aggregation timestamps.
  AddTensor(first_scores, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  // Second packet: timestamp 1000, requesting aggregation of both timestamps.
  AddTensor(second_scores, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run(
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
      /*timestamp=*/1000));
  MP_ASSERT_OK_AND_ASSIGN(
      auto actual, GetResult<std::vector<ClassificationResult>>(output_poller));

  // One result per aggregated timestamp is expected, in order. Timestamps 0
  // and 1000 show up as timestamp_ms 0 and 1 (presumably a microsecond to
  // millisecond conversion; confirm against the aggregation calculator).
  const std::vector<ClassificationResult> expected = {
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 0
classifications {
head_index: 0
head_name: "probability"
classification_list {
classification {
index: 3
score: 0.6
label: "great white shark"
}
classification { index: 2 score: 0.4 label: "goldfish" }
}
})pb"),
      ParseTextProtoOrDie<ClassificationResult>(R"pb(
timestamp_ms: 1
classifications {
head_index: 0
head_name: "probability"
classification_list {
classification { index: 7 score: 0.6 label: "stingray" }
classification { index: 6 score: 0.4 label: "electric ray" }
}
})pb")};
  EXPECT_THAT(actual, Pointwise(EqualsProto(), expected));
}
// TODO: remove these tests once migration is over.
class LegacyPostprocessingTest : public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const proto::ClassifierOptions& options,
@ -496,7 +813,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
absl::make_unique<std::vector<Tensor>>();
};
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(3);
@ -525,7 +842,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
})pb"));
}
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(3);
@ -568,7 +885,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
})pb"));
}
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(3);
@ -614,7 +931,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
})pb"));
}
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(2);
@ -674,7 +991,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
})pb"));
}
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(2);