mediapipe/mediapipe/tasks/cc/components/classification_postprocessing.cc
MediaPipe Team f8af41b1eb Internal change
PiperOrigin-RevId: 477538515
2022-09-28 21:32:36 +00:00

513 lines
22 KiB
C++

/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"

#include <stdint.h>

#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace {
using ::mediapipe::Tensor;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::Timestamp;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::ClassifierOptions;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit;
using ::tflite::TensorMetadata;
// Label map of a single classification head, keyed by class index.
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
// A vector-of-tensors stream that is either a graph Source or the output of a
// graph node; used to chain the postprocessing stages together.
using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
// Score threshold used when neither the model metadata nor the
// ClassifierOptions provide one. It is the lowest finite float, presumably so
// that no result is filtered out by score.
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
// Stream tags used to wire the calculators of the postprocessing subgraph.
constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Validates user-provided ClassifierOptions before they are used to configure
// the postprocessing graph.
//
// Returns an InvalidArgument status (kInvalidArgumentError payload) when:
// - `max_results` is 0 (checked and reported first), or
// - both `category_allowlist` and `category_denylist` are non-empty.
// Returns OK otherwise.
absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) {
  const bool requests_no_results = options.max_results() == 0;
  if (requests_no_results) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid `max_results` option: value must be != 0.",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  const bool has_allowlist = options.category_allowlist_size() > 0;
  const bool has_denylist = options.category_denylist_size() > 0;
  if (!(has_allowlist && has_denylist)) {
    return absl::OkStatus();
  }
  return CreateStatusWithPayload(
      absl::StatusCode::kInvalidArgument,
      "`category_allowlist` and `category_denylist` are mutually "
      "exclusive options.",
      MediaPipeTasksStatus::kInvalidArgumentError);
}
// Summary of the classification heads exposed by a TFLite model.
struct ClassificationHeadsProperties {
  // Number of classification heads, i.e. the number of output tensors of the
  // model's single subgraph.
  int num_heads = 0;
  // Whether the output tensors are quantized (UINT8). Models mixing quantized
  // and float outputs are rejected, so this applies to every head.
  bool quantized = false;
};
// Inspects the model's output tensors to determine how many classification
// heads it has and whether they are quantized.
//
// Each of the following returns an InvalidArgument status:
// - the model does not have exactly one subgraph (kInvalidArgumentError);
// - an output tensor is neither FLOAT32 nor UINT8
//   (kInvalidOutputTensorTypeError);
// - only some of the output tensors are UINT8
//   (kInvalidOutputTensorTypeError);
// - output tensor metadata is present but its entry count differs from the
//   number of output tensors (kMetadataInconsistencyError).
absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
    const ModelResources& model_resources) {
  const tflite::Model& model = *model_resources.GetTfLiteModel();
  const auto* subgraphs = model.subgraphs();
  if (subgraphs->size() != 1) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Classification tflite models are "
                                   "assumed to have a single subgraph.",
                                   MediaPipeTasksStatus::kInvalidArgumentError);
  }
  const auto* subgraph = subgraphs->Get(0);
  const auto* output_indices = subgraph->outputs();
  const auto* all_tensors = subgraph->tensors();
  const int num_outputs = output_indices->size();

  // Validate each output's type while counting the quantized ones.
  int num_quantized = 0;
  for (int i = 0; i < num_outputs; ++i) {
    const tflite::TensorType type =
        all_tensors->Get(output_indices->Get(i))->type();
    const bool is_quantized = type == tflite::TensorType_UINT8;
    if (!is_quantized && type != tflite::TensorType_FLOAT32) {
      return CreateStatusWithPayload(
          absl::StatusCode::kInvalidArgument,
          absl::StrFormat("Expected output tensor at index %d to have type "
                          "UINT8 or FLOAT32, found %s instead.",
                          i, tflite::EnumNameTensorType(type)),
          MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
    }
    if (is_quantized) {
      ++num_quantized;
    }
  }

  const bool all_or_none_quantized =
      num_quantized == 0 || num_quantized == num_outputs;
  if (!all_or_none_quantized) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected either all or none of the output tensors to be "
            "quantized, but found %d quantized outputs for %d total outputs.",
            num_quantized, num_outputs),
        MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
  }

  // Output tensor metadata is optional, but if present it must describe
  // every output tensor.
  const auto* output_tensors_metadata =
      model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
  if (output_tensors_metadata != nullptr &&
      num_outputs != output_tensors_metadata->size()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (%d) and "
                        "output tensors metadata (%d).",
                        num_outputs, output_tensors_metadata->size()),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }

  ClassificationHeadsProperties properties;
  properties.num_heads = num_outputs;
  properties.quantized = num_quantized > 0;
  return properties;
}
// Returns the label map for one classification head, built from the files
// associated with `tensor_metadata`:
// - the first associated TENSOR_AXIS_LABELS file provides the label names;
// - the first TENSOR_AXIS_LABELS file looked up for `locale`, if any,
//   provides the display names.
// Returns an empty map (not an error) when no labels file is associated.
// Errors from reading the associated files or from building the label map
// are propagated.
absl::StatusOr<LabelItems> GetLabelItemsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata, absl::string_view locale) {
  const std::string labels_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
  if (labels_filename.empty()) {
    // Labels are optional for classification models.
    return LabelItems();
  }
  ASSIGN_OR_RETURN(absl::string_view labels_contents,
                   metadata_extractor.GetAssociatedFile(labels_filename));

  const std::string localized_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
          locale);
  absl::string_view display_names_contents;
  if (!localized_filename.empty()) {
    ASSIGN_OR_RETURN(display_names_contents,
                     metadata_extractor.GetAssociatedFile(localized_filename));
  }
  return mediapipe::BuildLabelMapFromFiles(labels_contents,
                                           display_names_contents);
}
// Returns the global score threshold declared by a ScoreThresholdingOptions
// process unit in `tensor_metadata`, or kDefaultScoreThreshold when the
// metadata declares no such process unit. Errors from the process unit lookup
// are propagated.
absl::StatusOr<float> GetScoreThreshold(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(const ProcessUnit* process_unit,
                   metadata_extractor.FindFirstProcessUnit(
                       tensor_metadata,
                       tflite::ProcessUnitOptions_ScoreThresholdingOptions));
  if (process_unit != nullptr) {
    const auto* thresholding_options =
        process_unit->options_as_ScoreThresholdingOptions();
    return thresholding_options->global_score_threshold();
  }
  return kDefaultScoreThreshold;
}
// Resolves the category allowlist, or otherwise the denylist, from `options`
// into the set of label-map indices whose `name` matches a listed category.
// SanityCheckClassifierOptions rejects setting both lists.
//
// Unknown category names are silently ignored, and duplicates collapse into
// the set. If several labels share a name, the smallest index is used.
//
// Returns an empty set when neither list is set. Returns InvalidArgument
// (kMetadataMissingLabelsError payload) when a list is set but `label_items`
// is empty.
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
    const ClassifierOptions& options, const LabelItems& label_items) {
  absl::flat_hash_set<int> category_indices;
  // Exit early if no denylist/allowlist.
  if (options.category_denylist_size() == 0 &&
      options.category_allowlist_size() == 0) {
    return category_indices;
  }
  if (label_items.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Using `category_allowlist` or `category_denylist` requires labels to "
        "be present in the TFLite Model Metadata but none was found.",
        MediaPipeTasksStatus::kMetadataMissingLabelsError);
  }
  // Index the labels by name once, instead of rescanning the whole label map
  // for every requested category (O(labels + categories) rather than
  // O(labels * categories)). Iterating the map's own entries also avoids
  // calling `at()` on keys that may be missing when the label indices are not
  // contiguous; `at()` on a protobuf Map hard-fails on a missing key.
  absl::flat_hash_map<std::string, int> index_by_name;
  index_by_name.reserve(label_items.size());
  for (const auto& entry : label_items) {
    const int index = static_cast<int>(entry.first);
    auto emplaced = index_by_name.emplace(entry.second.name(), index);
    // Map iteration order is unspecified: keep the smallest index for
    // duplicated names, matching a scan in ascending index order.
    if (!emplaced.second && index < emplaced.first->second) {
      emplaced.first->second = index;
    }
  }
  const auto& category_list = options.category_allowlist_size() > 0
                                  ? options.category_allowlist()
                                  : options.category_denylist();
  for (const auto& category_name : category_list) {
    const auto it = index_by_name.find(category_name);
    // Unknown categories are ignored.
    if (it != index_by_name.end()) {
      category_indices.insert(it->second);
    }
  }
  return category_indices;
}
// Adds score calibration settings for output tensor `tensor_index` to
// `options->score_calibration_options` (keyed by `tensor_index`) when that
// tensor's metadata contains a ScoreCalibrationOptions process unit. Leaves
// `options` unchanged if the tensor has no metadata or no such process unit.
//
// Returns NotFound (kMetadataAssociatedFileNotFoundError payload) when the
// process unit is present but no TENSOR_AXIS_SCORE_CALIBRATION associated file
// is. Errors from reading that file or from ConfigureScoreCalibration() are
// propagated.
absl::Status ConfigureScoreCalibrationIfAny(
    const ModelMetadataExtractor& metadata_extractor, int tensor_index,
    ClassificationPostprocessingOptions* options) {
  const auto* tensor_metadata =
      metadata_extractor.GetOutputTensorMetadata(tensor_index);
  if (tensor_metadata == nullptr) {
    return absl::OkStatus();
  }
  ASSIGN_OR_RETURN(const ProcessUnit* calibration_unit,
                   metadata_extractor.FindFirstProcessUnit(
                       *tensor_metadata,
                       tflite::ProcessUnitOptions_ScoreCalibrationOptions));
  if (calibration_unit == nullptr) {
    return absl::OkStatus();
  }
  // The calibration parameters live in a separate associated file.
  const std::string params_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          *tensor_metadata,
          tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
  if (params_filename.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        "Found ScoreCalibrationOptions but missing required associated "
        "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
        MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
  }
  ASSIGN_OR_RETURN(absl::string_view params_file,
                   metadata_extractor.GetAssociatedFile(params_filename));
  const auto* calibration_options =
      calibration_unit->options_as_ScoreCalibrationOptions();
  ScoreCalibrationCalculatorOptions calculator_options;
  MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
      calibration_options->score_transformation(),
      calibration_options->default_score(), params_file, &calculator_options));
  (*options->mutable_score_calibration_options())[tensor_index] =
      std::move(calculator_options);
  return absl::OkStatus();
}
// Fills `calculator_options` for output tensor `tensor_index` from the user's
// ClassifierOptions and, when available, that tensor's metadata:
// - the label map and score threshold come from the metadata. Both are
//   optional: without metadata the label map is empty and the threshold is
//   kDefaultScoreThreshold;
// - the allowlist or denylist becomes `allow_classes` or `ignore_classes`;
// - a user-specified `score_threshold` overrides the metadata one;
// - `max_results` becomes `top_k`, or -1 (all results) when it is not
//   positive;
// - results are always sorted by descending score.
// Errors from the metadata lookups or from category resolution are
// propagated before `calculator_options` is modified.
absl::Status ConfigureTensorsToClassificationCalculator(
    const ClassifierOptions& options,
    const ModelMetadataExtractor& metadata_extractor, int tensor_index,
    TensorsToClassificationCalculatorOptions* calculator_options) {
  LabelItems label_items;
  float score_threshold = kDefaultScoreThreshold;
  const auto* tensor_metadata =
      metadata_extractor.GetOutputTensorMetadata(tensor_index);
  if (tensor_metadata != nullptr) {
    ASSIGN_OR_RETURN(label_items,
                     GetLabelItemsIfAny(metadata_extractor, *tensor_metadata,
                                        options.display_names_locale()));
    ASSIGN_OR_RETURN(score_threshold,
                     GetScoreThreshold(metadata_extractor, *tensor_metadata));
  }
  if (options.has_score_threshold()) {
    score_threshold = options.score_threshold();
  }
  const int top_k = options.max_results() > 0 ? options.max_results() : -1;

  // Must run before `label_items` is moved into the calculator options below.
  ASSIGN_OR_RETURN(auto category_indices,
                   GetAllowOrDenyCategoryIndicesIfAny(options, label_items));
  if (!category_indices.empty()) {
    if (options.category_allowlist_size() > 0) {
      calculator_options->mutable_allow_classes()->Assign(
          category_indices.begin(), category_indices.end());
    } else {
      calculator_options->mutable_ignore_classes()->Assign(
          category_indices.begin(), category_indices.end());
    }
  }
  calculator_options->set_min_score_threshold(score_threshold);
  calculator_options->set_top_k(top_k);
  *calculator_options->mutable_label_items() = std::move(label_items);
  calculator_options->set_sort_by_descending_score(true);
  return absl::OkStatus();
}
// Appends to `options->head_names` one entry per output tensor metadata entry,
// in metadata order, taken from that entry's `name`. Leaves `options`
// unchanged when the model has no output tensor metadata.
//
// `name` is an optional field of the TensorMetadata flatbuffer table, and its
// accessor returns nullptr when the field is absent. Such entries yield an
// empty head name instead of a null dereference, so the i-th head name still
// corresponds to the i-th metadata entry.
void ConfigureClassificationAggregationCalculator(
    const ModelMetadataExtractor& metadata_extractor,
    ClassificationAggregationCalculatorOptions* options) {
  const auto* output_tensors_metadata =
      metadata_extractor.GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr) {
    return;
  }
  for (const auto& metadata : *output_tensors_metadata) {
    const auto* name = metadata->name();
    options->add_head_names(name != nullptr ? name->str() : std::string());
  }
}
} // namespace
// Fills `options` for the classification postprocessing subgraph from the
// model and the user's ClassifierOptions:
// 1. validates `classifier_options`;
// 2. determines the number of heads and whether they are quantized;
// 3. for each head, adds optional score calibration settings and a
//    TensorsToClassificationCalculatorOptions entry;
// 4. configures the aggregation calculator with the head names;
// 5. records whether the outputs are quantized.
// The first error encountered is returned.
absl::Status ConfigureClassificationPostprocessing(
    const ModelResources& model_resources,
    const ClassifierOptions& classifier_options,
    ClassificationPostprocessingOptions* options) {
  MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
  ASSIGN_OR_RETURN(const auto heads_properties,
                   GetClassificationHeadsProperties(model_resources));
  const ModelMetadataExtractor& metadata_extractor =
      *model_resources.GetMetadataExtractor();
  for (int head = 0; head < heads_properties.num_heads; ++head) {
    MP_RETURN_IF_ERROR(
        ConfigureScoreCalibrationIfAny(metadata_extractor, head, options));
    MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
        classifier_options, metadata_extractor, head,
        options->add_tensors_to_classifications_options()));
  }
  ConfigureClassificationAggregationCalculator(
      metadata_extractor,
      options->mutable_classification_aggregation_options());
  options->set_has_quantized_outputs(heads_properties.quantized);
  return absl::OkStatus();
}
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
// raw tensors into ClassificationResult objects.
// - Accepts CPU input tensors.
//
// Inputs:
// TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of timestamps that a single ClassificationResult should
// aggregate. This is mostly useful for classifiers working on time series,
// e.g. audio or video classification.
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results.
//
// The subgraph is driven by ClassificationPostprocessingOptions. Depending on
// those options it chains: optional dequantization, splitting of the tensor
// vector into one tensor per head, optional per-head score calibration, one
// TensorsToClassificationCalculator per head, and a final
// ClassificationAggregationCalculator.
//
// The recommended way of using this subgraph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessing()' function. See header file
// for more details.
class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
public:
// Builds the subgraph config from the ClassificationPostprocessingOptions
// attached to this subgraph node. The TENSORS and TIMESTAMPS graph inputs are
// fed to BuildClassificationPostprocessing(), and its result is exposed as the
// CLASSIFICATION_RESULT graph output. Returns an error if the options are
// invalid (see BuildClassificationPostprocessing()).
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(
auto classification_result_out,
BuildClassificationPostprocessing(
sc->Options<ClassificationPostprocessingOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
classification_result_out >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
return graph.GetConfig();
}
private:
// Adds an on-device classification postprocessing subgraph into the provided
// builder::Graph instance. The classification postprocessing subgraph takes
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
// stream containing the output classification results (ClassificationResult).
//
// options: the on-device ClassificationPostprocessingOptions.
// tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that a single ClassificationResult should aggregate.
// graph: the mediapipe builder::Graph instance to be updated.
//
// Returns InvalidArgument if `options` contains no
// TensorsToClassificationCalculatorOptions (i.e. zero heads).
absl::StatusOr<Source<ClassificationResult>>
BuildClassificationPostprocessing(
const ClassificationPostprocessingOptions& options,
Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
// One TensorsToClassificationCalculatorOptions entry per classification head.
const int num_heads = options.tensors_to_classifications_options_size();
// Sanity check.
if (num_heads == 0) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ClassificationPostprocessingOptions must contain at least one "
"TensorsToClassificationCalculatorOptions.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
// If output tensors are quantized, they must be dequantized first.
// `dequantized_tensors` starts out referring to `tensors_in` directly and is
// redirected to the dequantization node's output only when needed.
TensorsSource dequantized_tensors(&tensors_in);
if (options.has_quantized_outputs()) {
GenericNode* tensors_dequantization_node =
&graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node->In(kTensorsTag);
dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
}
// If there are multiple classification heads, the output tensors need to be
// split.
std::vector<TensorsSource> split_tensors;
split_tensors.reserve(num_heads);
if (num_heads > 1) {
GenericNode* split_tensor_vector_node =
&graph.AddNode("SplitTensorVectorCalculator");
auto& split_tensor_vector_options =
split_tensor_vector_node
->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
// Head i gets the single-element range [i, i + 1), exposed as the split
// node's i-th output.
for (int i = 0; i < num_heads; ++i) {
auto* range = split_tensor_vector_options.add_ranges();
range->set_begin(i);
range->set_end(i + 1);
split_tensors.emplace_back(split_tensor_vector_node, i);
}
dequantized_tensors >> split_tensor_vector_node->In(0);
} else {
split_tensors.emplace_back(dequantized_tensors);
}
// Adds score calibration for heads that specify it, if any.
// `score_calibration_options` is keyed by head index; heads without an entry
// pass their tensors through unchanged.
std::vector<TensorsSource> calibrated_tensors;
calibrated_tensors.reserve(num_heads);
for (int i = 0; i < num_heads; ++i) {
if (options.score_calibration_options().contains(i)) {
GenericNode* score_calibration_node =
&graph.AddNode("ScoreCalibrationCalculator");
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
.CopyFrom(options.score_calibration_options().at(i));
split_tensors[i] >> score_calibration_node->In(kScoresTag);
calibrated_tensors.emplace_back(score_calibration_node,
kCalibratedScoresTag);
} else {
calibrated_tensors.emplace_back(split_tensors[i]);
}
}
// Adds a TensorsToClassificationCalculator for each head.
std::vector<GenericNode*> tensors_to_classification_nodes;
tensors_to_classification_nodes.reserve(num_heads);
for (int i = 0; i < num_heads; ++i) {
tensors_to_classification_nodes.emplace_back(
&graph.AddNode("TensorsToClassificationCalculator"));
tensors_to_classification_nodes.back()
->GetOptions<TensorsToClassificationCalculatorOptions>()
.CopyFrom(options.tensors_to_classifications_options(i));
calibrated_tensors[i] >>
tensors_to_classification_nodes.back()->In(kTensorsTag);
}
// Aggregates Classifications into a single ClassificationResult.
auto& result_aggregation =
graph.AddNode("ClassificationAggregationCalculator");
result_aggregation.GetOptions<ClassificationAggregationCalculatorOptions>()
.CopyFrom(options.classification_aggregation_options());
// Head i's classifications feed the aggregator's "CLASSIFICATIONS:i" input.
for (int i = 0; i < num_heads; ++i) {
tensors_to_classification_nodes[i]->Out(kClassificationsTag) >>
result_aggregation.In(
absl::StrFormat("%s:%d", kClassificationsTag, i));
}
timestamps_in >> result_aggregation.In(kTimestampsTag);
// Connects output.
return result_aggregation[Output<ClassificationResult>(
kClassificationResultTag)];
}
};
// Registers the subgraph with MediaPipe's graph registry so graph configs can
// instantiate it by name.
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
} // namespace components
} // namespace tasks
} // namespace mediapipe