Detection postprocessing: support quantized tensors.
PiperOrigin-RevId: 572310272
This commit is contained in:
parent
df13788883
commit
d6d92354ea
|
@ -169,6 +169,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:split_vector_calculator",
|
"//mediapipe/calculators/core:split_vector_calculator",
|
||||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
|
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
|
||||||
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
|
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
|
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
|
||||||
|
|
|
@ -710,6 +710,57 @@ absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
|
||||||
return model_output_tensors;
|
return model_output_tensors;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Identifies whether or not the model has quantized outputs, and performs
|
||||||
|
// sanity checks.
|
||||||
|
absl::StatusOr<bool> HasQuantizedOutputs(
|
||||||
|
const core::ModelResources& model_resources) {
|
||||||
|
const tflite::Model& model = *model_resources.GetTfLiteModel();
|
||||||
|
// Model is checked to have single subgraph before.
|
||||||
|
const auto* primary_subgraph = (*model.subgraphs())[0];
|
||||||
|
int num_output_tensors = primary_subgraph->outputs()->size();
|
||||||
|
// Sanity check tensor types and check if model outputs are quantized or not.
|
||||||
|
int num_quantized_tensors = 0;
|
||||||
|
for (int i = 0; i < num_output_tensors; ++i) {
|
||||||
|
const auto* tensor =
|
||||||
|
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
|
||||||
|
if (tensor->type() != tflite::TensorType_FLOAT32 &&
|
||||||
|
tensor->type() != tflite::TensorType_UINT8) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Expected output tensor at index %d to have type "
|
||||||
|
"UINT8 or FLOAT32, found %s instead.",
|
||||||
|
i, tflite::EnumNameTensorType(tensor->type())),
|
||||||
|
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
|
||||||
|
}
|
||||||
|
if (tensor->type() == tflite::TensorType_UINT8) {
|
||||||
|
num_quantized_tensors++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (num_quantized_tensors != num_output_tensors &&
|
||||||
|
num_quantized_tensors != 0) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat(
|
||||||
|
"Expected either all or none of the output tensors to be "
|
||||||
|
"quantized, but found %d quantized outputs for %d total outputs.",
|
||||||
|
num_quantized_tensors, num_output_tensors),
|
||||||
|
MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
|
||||||
|
}
|
||||||
|
// Check if metadata is consistent with model topology.
|
||||||
|
const auto* output_tensors_metadata =
|
||||||
|
model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
|
||||||
|
if (output_tensors_metadata != nullptr &&
|
||||||
|
num_output_tensors != output_tensors_metadata->size()) {
|
||||||
|
return CreateStatusWithPayload(
|
||||||
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
absl::StrFormat("Mismatch between number of output tensors (%d) and "
|
||||||
|
"output tensors metadata (%d).",
|
||||||
|
num_output_tensors, output_tensors_metadata->size()),
|
||||||
|
MediaPipeTasksStatus::kMetadataInconsistencyError);
|
||||||
|
}
|
||||||
|
return num_quantized_tensors > 0;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::Status ConfigureDetectionPostprocessingGraph(
|
absl::Status ConfigureDetectionPostprocessingGraph(
|
||||||
|
@ -738,7 +789,9 @@ absl::Status ConfigureDetectionPostprocessingGraph(
|
||||||
model.subgraphs()->Get(0)->outputs()->size()),
|
model.subgraphs()->Get(0)->outputs()->size()),
|
||||||
MediaPipeTasksStatus::kInvalidArgumentError);
|
MediaPipeTasksStatus::kInvalidArgumentError);
|
||||||
}
|
}
|
||||||
|
MP_ASSIGN_OR_RETURN(bool has_quantized_outputs,
|
||||||
|
HasQuantizedOutputs(model_resources));
|
||||||
|
options.set_has_quantized_outputs(has_quantized_outputs);
|
||||||
const ModelMetadataExtractor* metadata_extractor =
|
const ModelMetadataExtractor* metadata_extractor =
|
||||||
model_resources.GetMetadataExtractor();
|
model_resources.GetMetadataExtractor();
|
||||||
if (in_model_nms) {
|
if (in_model_nms) {
|
||||||
|
@ -820,12 +873,20 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
BuildDetectionPostprocessing(
|
BuildDetectionPostprocessing(
|
||||||
proto::DetectionPostprocessingGraphOptions& graph_options,
|
proto::DetectionPostprocessingGraphOptions& graph_options,
|
||||||
Source<std::vector<Tensor>> tensors_in, Graph& graph) {
|
Source<std::vector<Tensor>> tensors_in, Graph& graph) {
|
||||||
|
Source<std::vector<Tensor>> tensors = tensors_in;
|
||||||
|
if (graph_options.has_quantized_outputs()) {
|
||||||
|
auto& tensors_dequantization_node =
|
||||||
|
graph.AddNode("TensorsDequantizationCalculator");
|
||||||
|
tensors_in >> tensors_dequantization_node.In(kTensorsTag);
|
||||||
|
tensors = tensors_dequantization_node.Out(kTensorsTag)
|
||||||
|
.Cast<std::vector<Tensor>>();
|
||||||
|
}
|
||||||
std::optional<Source<std::vector<Detection>>> detections;
|
std::optional<Source<std::vector<Detection>>> detections;
|
||||||
if (!graph_options.has_non_max_suppression_options()) {
|
if (!graph_options.has_non_max_suppression_options()) {
|
||||||
// Calculators to perform score calibration, if specified in the options.
|
// Calculators to perform score calibration, if specified in the options.
|
||||||
if (graph_options.has_score_calibration_options()) {
|
if (graph_options.has_score_calibration_options()) {
|
||||||
MP_ASSIGN_OR_RETURN(tensors_in,
|
MP_ASSIGN_OR_RETURN(tensors,
|
||||||
CalibrateScores(tensors_in, graph_options, graph));
|
CalibrateScores(tensors, graph_options, graph));
|
||||||
}
|
}
|
||||||
// Calculator to convert output tensors to a detection proto vector.
|
// Calculator to convert output tensors to a detection proto vector.
|
||||||
auto& tensors_to_detections =
|
auto& tensors_to_detections =
|
||||||
|
@ -833,7 +894,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
tensors_to_detections
|
tensors_to_detections
|
||||||
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
|
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
|
||||||
.Swap(graph_options.mutable_tensors_to_detections_options());
|
.Swap(graph_options.mutable_tensors_to_detections_options());
|
||||||
tensors_in >> tensors_to_detections.In(kTensorsTag);
|
tensors >> tensors_to_detections.In(kTensorsTag);
|
||||||
detections = tensors_to_detections.Out(kDetectionsTag)
|
detections = tensors_to_detections.Out(kDetectionsTag)
|
||||||
.Cast<std::vector<Detection>>();
|
.Cast<std::vector<Detection>>();
|
||||||
} else {
|
} else {
|
||||||
|
@ -850,7 +911,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
|
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
|
||||||
.Swap(graph_options.mutable_tensors_to_detections_options());
|
.Swap(graph_options.mutable_tensors_to_detections_options());
|
||||||
anchors >> tensors_to_detections.SideIn(kAnchorsTag);
|
anchors >> tensors_to_detections.SideIn(kAnchorsTag);
|
||||||
tensors_in >> tensors_to_detections.In(kTensorsTag);
|
tensors >> tensors_to_detections.In(kTensorsTag);
|
||||||
detections = tensors_to_detections.Out(kDetectionsTag)
|
detections = tensors_to_detections.Out(kDetectionsTag)
|
||||||
.Cast<std::vector<mediapipe::Detection>>();
|
.Cast<std::vector<mediapipe::Detection>>();
|
||||||
// Non maximum suppression removes redundant object detections.
|
// Non maximum suppression removes redundant object detections.
|
||||||
|
|
|
@ -213,6 +213,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
|
||||||
}
|
}
|
||||||
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
||||||
}
|
}
|
||||||
|
has_quantized_outputs: false
|
||||||
)pb"))));
|
)pb"))));
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
options_out.detection_label_ids_to_text_options().label_items_size(), 90);
|
options_out.detection_label_ids_to_text_options().label_items_size(), 90);
|
||||||
|
@ -244,6 +245,7 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
|
||||||
}
|
}
|
||||||
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
||||||
}
|
}
|
||||||
|
has_quantized_outputs: false
|
||||||
)pb")));
|
)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,6 +275,7 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
||||||
}
|
}
|
||||||
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
|
||||||
}
|
}
|
||||||
|
has_quantized_outputs: false
|
||||||
)pb")));
|
)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,6 +314,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
|
||||||
score_transformation: IDENTITY
|
score_transformation: IDENTITY
|
||||||
default_score: 0.5
|
default_score: 0.5
|
||||||
}
|
}
|
||||||
|
has_quantized_outputs: false
|
||||||
)pb")));
|
)pb")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -46,4 +46,7 @@ message DetectionPostprocessingGraphOptions {
|
||||||
// Optional detection label id to text calculator options.
|
// Optional detection label id to text calculator options.
|
||||||
optional mediapipe.DetectionLabelIdToTextCalculatorOptions
|
optional mediapipe.DetectionLabelIdToTextCalculatorOptions
|
||||||
detection_label_ids_to_text_options = 5;
|
detection_label_ids_to_text_options = 5;
|
||||||
|
|
||||||
|
// Whether output tensors are quantized (kTfLiteUInt8) or not (kTfLiteFloat32).
|
||||||
|
optional bool has_quantized_outputs = 6;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user