Detection postprocessing support quantized tensor.
PiperOrigin-RevId: 572310272
This commit is contained in:
parent
df13788883
commit
d6d92354ea
|
@ -169,6 +169,7 @@ cc_library(
|
|||
deps = [
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
|
||||
|
|
|
@ -710,6 +710,57 @@ absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
|
|||
return model_output_tensors;
|
||||
}
|
||||
|
||||
// Identifies whether or not the model has quantized outputs, and performs
|
||||
// sanity checks.
|
||||
absl::StatusOr<bool> HasQuantizedOutputs(
|
||||
const core::ModelResources& model_resources) {
|
||||
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||
// Model is checked to have single subgraph before.
|
||||
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||
int num_output_tensors = primary_subgraph->outputs()->size();
|
||||
// Sanity check tensor types and check if model outputs are quantized or not.
|
||||
int num_quantized_tensors = 0;
|
||||
for (int i = 0; i < num_output_tensors; ++i) {
|
||||
const auto* tensor =
|
||||
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
|
||||
if (tensor->type() != tflite::TensorType_FLOAT32 &&
|
||||
tensor->type() != tflite::TensorType_UINT8) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Expected output tensor at index %d to have type "
|
||||
"UINT8 or FLOAT32, found %s instead.",
|
||||
i, tflite::EnumNameTensorType(tensor->type())),
|
||||
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
|
||||
}
|
||||
if (tensor->type() == tflite::TensorType_UINT8) {
|
||||
num_quantized_tensors++;
|
||||
}
|
||||
}
|
||||
if (num_quantized_tensors != num_output_tensors &&
|
||||
num_quantized_tensors != 0) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat(
|
||||
"Expected either all or none of the output tensors to be "
|
||||
"quantized, but found %d quantized outputs for %d total outputs.",
|
||||
num_quantized_tensors, num_output_tensors),
|
||||
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
|
||||
}
|
||||
// Check if metadata is consistent with model topology.
|
||||
const auto* output_tensors_metadata =
|
||||
model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
|
||||
if (output_tensors_metadata != nullptr &&
|
||||
num_output_tensors != output_tensors_metadata->size()) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
absl::StrFormat("Mismatch between number of output tensors (%d) and "
|
||||
"output tensors metadata (%d).",
|
||||
num_output_tensors, output_tensors_metadata->size()),
|
||||
MediaPipeTasksStatus::kMetadataInconsistencyError);
|
||||
}
|
||||
return num_quantized_tensors > 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::Status ConfigureDetectionPostprocessingGraph(
|
||||
|
@ -738,7 +789,9 @@ absl::Status ConfigureDetectionPostprocessingGraph(
|
|||
model.subgraphs()->Get(0)->outputs()->size()),
|
||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||
}
|
||||
|
||||
MP_ASSIGN_OR_RETURN(bool has_quantized_outputs,
|
||||
HasQuantizedOutputs(model_resources));
|
||||
options.set_has_quantized_outputs(has_quantized_outputs);
|
||||
const ModelMetadataExtractor* metadata_extractor =
|
||||
model_resources.GetMetadataExtractor();
|
||||
if (in_model_nms) {
|
||||
|
@ -820,12 +873,20 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
|
|||
BuildDetectionPostprocessing(
|
||||
proto::DetectionPostprocessingGraphOptions& graph_options,
|
||||
Source<std::vector<Tensor>> tensors_in, Graph& graph) {
|
||||
Source<std::vector<Tensor>> tensors = tensors_in;
|
||||
if (graph_options.has_quantized_outputs()) {
|
||||
auto& tensors_dequantization_node =
|
||||
graph.AddNode("TensorsDequantizationCalculator");
|
||||
tensors_in >> tensors_dequantization_node.In(kTensorsTag);
|
||||
tensors = tensors_dequantization_node.Out(kTensorsTag)
|
||||
.Cast<std::vector<Tensor>>();
|
||||
}
|
||||
std::optional<Source<std::vector<Detection>>> detections;
|
||||
if (!graph_options.has_non_max_suppression_options()) {
|
||||
// Calculators to perform score calibration, if specified in the options.
|
||||
if (graph_options.has_score_calibration_options()) {
|
||||
MP_ASSIGN_OR_RETURN(tensors_in,
|
||||
CalibrateScores(tensors_in, graph_options, graph));
|
||||
MP_ASSIGN_OR_RETURN(tensors,
|
||||
CalibrateScores(tensors, graph_options, graph));
|
||||
}
|
||||
// Calculator to convert output tensors to a detection proto vector.
|
||||
auto& tensors_to_detections =
|
||||
|
@ -833,7 +894,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
|
|||
tensors_to_detections
|
||||
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
|
||||
.Swap(graph_options.mutable_tensors_to_detections_options());
|
||||
tensors_in >> tensors_to_detections.In(kTensorsTag);
|
||||
tensors >> tensors_to_detections.In(kTensorsTag);
|
||||
detections = tensors_to_detections.Out(kDetectionsTag)
|
||||
.Cast<std::vector<Detection>>();
|
||||
} else {
|
||||
|
@ -850,7 +911,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
|
|||
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
|
||||
.Swap(graph_options.mutable_tensors_to_detections_options());
|
||||
anchors >> tensors_to_detections.SideIn(kAnchorsTag);
|
||||
tensors_in >> tensors_to_detections.In(kTensorsTag);
|
||||
tensors >> tensors_to_detections.In(kTensorsTag);
|
||||
detections = tensors_to_detections.Out(kDetectionsTag)
|
||||
.Cast<std::vector<mediapipe::Detection>>();
|
||||
// Non maximum suppression removes redundant object detections.
|
||||
|
|
|
@ -213,6 +213,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
|
|||
}
|
||||
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
||||
}
|
||||
has_quantized_outputs: false
|
||||
)pb"))));
|
||||
EXPECT_THAT(
|
||||
options_out.detection_label_ids_to_text_options().label_items_size(), 90);
|
||||
|
@ -244,6 +245,7 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
|
|||
}
|
||||
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
||||
}
|
||||
has_quantized_outputs: false
|
||||
)pb")));
|
||||
}
|
||||
|
||||
|
@ -273,6 +275,7 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
|||
}
|
||||
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
||||
}
|
||||
has_quantized_outputs: false
|
||||
)pb")));
|
||||
}
|
||||
|
||||
|
@ -311,6 +314,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
|
|||
score_transformation: IDENTITY
|
||||
default_score: 0.5
|
||||
}
|
||||
has_quantized_outputs: false
|
||||
)pb")));
|
||||
}
|
||||
|
||||
|
|
|
@ -46,4 +46,7 @@ message DetectionPostprocessingGraphOptions {
|
|||
// Optional detection label id to text calculator options.
|
||||
optional mediapipe.DetectionLabelIdToTextCalculatorOptions
|
||||
detection_label_ids_to_text_options = 5;
|
||||
|
||||
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
|
||||
optional bool has_quantized_outputs = 6;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user