Detection postprocessing support quantized tensor.

PiperOrigin-RevId: 572310272
This commit is contained in:
MediaPipe Team 2023-10-10 11:10:16 -07:00 committed by Copybara-Service
parent df13788883
commit d6d92354ea
4 changed files with 74 additions and 5 deletions

View File

@ -169,6 +169,7 @@ cc_library(
deps = [ deps = [
"//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator", "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
"//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator",

View File

@ -710,6 +710,57 @@ absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
return model_output_tensors; return model_output_tensors;
} }
// Identifies whether or not the model has quantized outputs, and performs
// sanity checks.
absl::StatusOr<bool> HasQuantizedOutputs(
const core::ModelResources& model_resources) {
const tflite::Model& model = *model_resources.GetTfLiteModel();
// Model is checked to have single subgraph before.
const auto* primary_subgraph = (*model.subgraphs())[0];
int num_output_tensors = primary_subgraph->outputs()->size();
// Sanity check tensor types and check if model outputs are quantized or not.
int num_quantized_tensors = 0;
for (int i = 0; i < num_output_tensors; ++i) {
const auto* tensor =
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
if (tensor->type() != tflite::TensorType_FLOAT32 &&
tensor->type() != tflite::TensorType_UINT8) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Expected output tensor at index %d to have type "
"UINT8 or FLOAT32, found %s instead.",
i, tflite::EnumNameTensorType(tensor->type())),
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
}
if (tensor->type() == tflite::TensorType_UINT8) {
num_quantized_tensors++;
}
}
if (num_quantized_tensors != num_output_tensors &&
num_quantized_tensors != 0) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"Expected either all or none of the output tensors to be "
"quantized, but found %d quantized outputs for %d total outputs.",
num_quantized_tensors, num_output_tensors),
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
}
// Check if metadata is consistent with model topology.
const auto* output_tensors_metadata =
model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
if (output_tensors_metadata != nullptr &&
num_output_tensors != output_tensors_metadata->size()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Mismatch between number of output tensors (%d) and "
"output tensors metadata (%d).",
num_output_tensors, output_tensors_metadata->size()),
MediaPipeTasksStatus::kMetadataInconsistencyError);
}
return num_quantized_tensors > 0;
}
} // namespace } // namespace
absl::Status ConfigureDetectionPostprocessingGraph( absl::Status ConfigureDetectionPostprocessingGraph(
@ -738,7 +789,9 @@ absl::Status ConfigureDetectionPostprocessingGraph(
model.subgraphs()->Get(0)->outputs()->size()), model.subgraphs()->Get(0)->outputs()->size()),
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
MP_ASSIGN_OR_RETURN(bool has_quantized_outputs,
HasQuantizedOutputs(model_resources));
options.set_has_quantized_outputs(has_quantized_outputs);
const ModelMetadataExtractor* metadata_extractor = const ModelMetadataExtractor* metadata_extractor =
model_resources.GetMetadataExtractor(); model_resources.GetMetadataExtractor();
if (in_model_nms) { if (in_model_nms) {
@ -820,12 +873,20 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
BuildDetectionPostprocessing( BuildDetectionPostprocessing(
proto::DetectionPostprocessingGraphOptions& graph_options, proto::DetectionPostprocessingGraphOptions& graph_options,
Source<std::vector<Tensor>> tensors_in, Graph& graph) { Source<std::vector<Tensor>> tensors_in, Graph& graph) {
Source<std::vector<Tensor>> tensors = tensors_in;
if (graph_options.has_quantized_outputs()) {
auto& tensors_dequantization_node =
graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node.In(kTensorsTag);
tensors = tensors_dequantization_node.Out(kTensorsTag)
.Cast<std::vector<Tensor>>();
}
std::optional<Source<std::vector<Detection>>> detections; std::optional<Source<std::vector<Detection>>> detections;
if (!graph_options.has_non_max_suppression_options()) { if (!graph_options.has_non_max_suppression_options()) {
// Calculators to perform score calibration, if specified in the options. // Calculators to perform score calibration, if specified in the options.
if (graph_options.has_score_calibration_options()) { if (graph_options.has_score_calibration_options()) {
MP_ASSIGN_OR_RETURN(tensors_in, MP_ASSIGN_OR_RETURN(tensors,
CalibrateScores(tensors_in, graph_options, graph)); CalibrateScores(tensors, graph_options, graph));
} }
// Calculator to convert output tensors to a detection proto vector. // Calculator to convert output tensors to a detection proto vector.
auto& tensors_to_detections = auto& tensors_to_detections =
@ -833,7 +894,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
tensors_to_detections tensors_to_detections
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>() .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
.Swap(graph_options.mutable_tensors_to_detections_options()); .Swap(graph_options.mutable_tensors_to_detections_options());
tensors_in >> tensors_to_detections.In(kTensorsTag); tensors >> tensors_to_detections.In(kTensorsTag);
detections = tensors_to_detections.Out(kDetectionsTag) detections = tensors_to_detections.Out(kDetectionsTag)
.Cast<std::vector<Detection>>(); .Cast<std::vector<Detection>>();
} else { } else {
@ -850,7 +911,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>() .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
.Swap(graph_options.mutable_tensors_to_detections_options()); .Swap(graph_options.mutable_tensors_to_detections_options());
anchors >> tensors_to_detections.SideIn(kAnchorsTag); anchors >> tensors_to_detections.SideIn(kAnchorsTag);
tensors_in >> tensors_to_detections.In(kTensorsTag); tensors >> tensors_to_detections.In(kTensorsTag);
detections = tensors_to_detections.Out(kDetectionsTag) detections = tensors_to_detections.Out(kDetectionsTag)
.Cast<std::vector<mediapipe::Detection>>(); .Cast<std::vector<mediapipe::Detection>>();
// Non maximum suppression removes redundant object detections. // Non maximum suppression removes redundant object detections.

View File

@ -213,6 +213,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
} }
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
} }
has_quantized_outputs: false
)pb")))); )pb"))));
EXPECT_THAT( EXPECT_THAT(
options_out.detection_label_ids_to_text_options().label_items_size(), 90); options_out.detection_label_ids_to_text_options().label_items_size(), 90);
@ -244,6 +245,7 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
} }
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
} }
has_quantized_outputs: false
)pb"))); )pb")));
} }
@ -273,6 +275,7 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
} }
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
} }
has_quantized_outputs: false
)pb"))); )pb")));
} }
@ -311,6 +314,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
score_transformation: IDENTITY score_transformation: IDENTITY
default_score: 0.5 default_score: 0.5
} }
has_quantized_outputs: false
)pb"))); )pb")));
} }

View File

@ -46,4 +46,7 @@ message DetectionPostprocessingGraphOptions {
// Optional detection label id to text calculator options. // Optional detection label id to text calculator options.
optional mediapipe.DetectionLabelIdToTextCalculatorOptions optional mediapipe.DetectionLabelIdToTextCalculatorOptions
detection_label_ids_to_text_options = 5; detection_label_ids_to_text_options = 5;
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
optional bool has_quantized_outputs = 6;
} }