Detection postprocessing support quantized tensor.
PiperOrigin-RevId: 572310272
This commit is contained in:
		
							parent
							
								
									df13788883
								
							
						
					
					
						commit
						d6d92354ea
					
				| 
						 | 
					@ -169,6 +169,7 @@ cc_library(
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        "//mediapipe/calculators/core:split_vector_calculator",
 | 
					        "//mediapipe/calculators/core:split_vector_calculator",
 | 
				
			||||||
        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
 | 
					        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
 | 
				
			||||||
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
 | 
					        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
 | 
				
			||||||
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
 | 
					        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
 | 
				
			||||||
        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
 | 
					        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -710,6 +710,57 @@ absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
 | 
				
			||||||
  return model_output_tensors;
 | 
					  return model_output_tensors;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Identifies whether or not the model has quantized outputs, and performs
 | 
				
			||||||
 | 
					// sanity checks.
 | 
				
			||||||
 | 
					absl::StatusOr<bool> HasQuantizedOutputs(
 | 
				
			||||||
 | 
					    const core::ModelResources& model_resources) {
 | 
				
			||||||
 | 
					  const tflite::Model& model = *model_resources.GetTfLiteModel();
 | 
				
			||||||
 | 
					  // The model has already been checked to have a single subgraph.
 | 
				
			||||||
 | 
					  const auto* primary_subgraph = (*model.subgraphs())[0];
 | 
				
			||||||
 | 
					  int num_output_tensors = primary_subgraph->outputs()->size();
 | 
				
			||||||
 | 
					  // Sanity check tensor types and check if model outputs are quantized or not.
 | 
				
			||||||
 | 
					  int num_quantized_tensors = 0;
 | 
				
			||||||
 | 
					  for (int i = 0; i < num_output_tensors; ++i) {
 | 
				
			||||||
 | 
					    const auto* tensor =
 | 
				
			||||||
 | 
					        primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
 | 
				
			||||||
 | 
					    if (tensor->type() != tflite::TensorType_FLOAT32 &&
 | 
				
			||||||
 | 
					        tensor->type() != tflite::TensorType_UINT8) {
 | 
				
			||||||
 | 
					      return CreateStatusWithPayload(
 | 
				
			||||||
 | 
					          absl::StatusCode::kInvalidArgument,
 | 
				
			||||||
 | 
					          absl::StrFormat("Expected output tensor at index %d to have type "
 | 
				
			||||||
 | 
					                          "UINT8 or FLOAT32, found %s instead.",
 | 
				
			||||||
 | 
					                          i, tflite::EnumNameTensorType(tensor->type())),
 | 
				
			||||||
 | 
					          MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (tensor->type() == tflite::TensorType_UINT8) {
 | 
				
			||||||
 | 
					      num_quantized_tensors++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  if (num_quantized_tensors != num_output_tensors &&
 | 
				
			||||||
 | 
					      num_quantized_tensors != 0) {
 | 
				
			||||||
 | 
					    return CreateStatusWithPayload(
 | 
				
			||||||
 | 
					        absl::StatusCode::kInvalidArgument,
 | 
				
			||||||
 | 
					        absl::StrFormat(
 | 
				
			||||||
 | 
					            "Expected either all or none of the output tensors to be "
 | 
				
			||||||
 | 
					            "quantized, but found %d quantized outputs for %d total outputs.",
 | 
				
			||||||
 | 
					            num_quantized_tensors, num_output_tensors),
 | 
				
			||||||
 | 
					        MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Check if metadata is consistent with model topology.
 | 
				
			||||||
 | 
					  const auto* output_tensors_metadata =
 | 
				
			||||||
 | 
					      model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
 | 
				
			||||||
 | 
					  if (output_tensors_metadata != nullptr &&
 | 
				
			||||||
 | 
					      num_output_tensors != output_tensors_metadata->size()) {
 | 
				
			||||||
 | 
					    return CreateStatusWithPayload(
 | 
				
			||||||
 | 
					        absl::StatusCode::kInvalidArgument,
 | 
				
			||||||
 | 
					        absl::StrFormat("Mismatch between number of output tensors (%d) and "
 | 
				
			||||||
 | 
					                        "output tensors metadata (%d).",
 | 
				
			||||||
 | 
					                        num_output_tensors, output_tensors_metadata->size()),
 | 
				
			||||||
 | 
					        MediaPipeTasksStatus::kMetadataInconsistencyError);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return num_quantized_tensors > 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
absl::Status ConfigureDetectionPostprocessingGraph(
 | 
					absl::Status ConfigureDetectionPostprocessingGraph(
 | 
				
			||||||
| 
						 | 
					@ -738,7 +789,9 @@ absl::Status ConfigureDetectionPostprocessingGraph(
 | 
				
			||||||
            model.subgraphs()->Get(0)->outputs()->size()),
 | 
					            model.subgraphs()->Get(0)->outputs()->size()),
 | 
				
			||||||
        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
					        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  MP_ASSIGN_OR_RETURN(bool has_quantized_outputs,
 | 
				
			||||||
 | 
					                      HasQuantizedOutputs(model_resources));
 | 
				
			||||||
 | 
					  options.set_has_quantized_outputs(has_quantized_outputs);
 | 
				
			||||||
  const ModelMetadataExtractor* metadata_extractor =
 | 
					  const ModelMetadataExtractor* metadata_extractor =
 | 
				
			||||||
      model_resources.GetMetadataExtractor();
 | 
					      model_resources.GetMetadataExtractor();
 | 
				
			||||||
  if (in_model_nms) {
 | 
					  if (in_model_nms) {
 | 
				
			||||||
| 
						 | 
					@ -820,12 +873,20 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 | 
				
			||||||
  BuildDetectionPostprocessing(
 | 
					  BuildDetectionPostprocessing(
 | 
				
			||||||
      proto::DetectionPostprocessingGraphOptions& graph_options,
 | 
					      proto::DetectionPostprocessingGraphOptions& graph_options,
 | 
				
			||||||
      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
 | 
					      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
 | 
				
			||||||
 | 
					    Source<std::vector<Tensor>> tensors = tensors_in;
 | 
				
			||||||
 | 
					    if (graph_options.has_quantized_outputs()) {
 | 
				
			||||||
 | 
					      auto& tensors_dequantization_node =
 | 
				
			||||||
 | 
					          graph.AddNode("TensorsDequantizationCalculator");
 | 
				
			||||||
 | 
					      tensors_in >> tensors_dequantization_node.In(kTensorsTag);
 | 
				
			||||||
 | 
					      tensors = tensors_dequantization_node.Out(kTensorsTag)
 | 
				
			||||||
 | 
					                    .Cast<std::vector<Tensor>>();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    std::optional<Source<std::vector<Detection>>> detections;
 | 
					    std::optional<Source<std::vector<Detection>>> detections;
 | 
				
			||||||
    if (!graph_options.has_non_max_suppression_options()) {
 | 
					    if (!graph_options.has_non_max_suppression_options()) {
 | 
				
			||||||
      // Calculators to perform score calibration, if specified in the options.
 | 
					      // Calculators to perform score calibration, if specified in the options.
 | 
				
			||||||
      if (graph_options.has_score_calibration_options()) {
 | 
					      if (graph_options.has_score_calibration_options()) {
 | 
				
			||||||
        MP_ASSIGN_OR_RETURN(tensors_in,
 | 
					        MP_ASSIGN_OR_RETURN(tensors,
 | 
				
			||||||
                            CalibrateScores(tensors_in, graph_options, graph));
 | 
					                            CalibrateScores(tensors, graph_options, graph));
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      // Calculator to convert output tensors to a detection proto vector.
 | 
					      // Calculator to convert output tensors to a detection proto vector.
 | 
				
			||||||
      auto& tensors_to_detections =
 | 
					      auto& tensors_to_detections =
 | 
				
			||||||
| 
						 | 
					@ -833,7 +894,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 | 
				
			||||||
      tensors_to_detections
 | 
					      tensors_to_detections
 | 
				
			||||||
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
 | 
					          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
 | 
				
			||||||
          .Swap(graph_options.mutable_tensors_to_detections_options());
 | 
					          .Swap(graph_options.mutable_tensors_to_detections_options());
 | 
				
			||||||
      tensors_in >> tensors_to_detections.In(kTensorsTag);
 | 
					      tensors >> tensors_to_detections.In(kTensorsTag);
 | 
				
			||||||
      detections = tensors_to_detections.Out(kDetectionsTag)
 | 
					      detections = tensors_to_detections.Out(kDetectionsTag)
 | 
				
			||||||
                       .Cast<std::vector<Detection>>();
 | 
					                       .Cast<std::vector<Detection>>();
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
| 
						 | 
					@ -850,7 +911,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 | 
				
			||||||
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
 | 
					          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
 | 
				
			||||||
          .Swap(graph_options.mutable_tensors_to_detections_options());
 | 
					          .Swap(graph_options.mutable_tensors_to_detections_options());
 | 
				
			||||||
      anchors >> tensors_to_detections.SideIn(kAnchorsTag);
 | 
					      anchors >> tensors_to_detections.SideIn(kAnchorsTag);
 | 
				
			||||||
      tensors_in >> tensors_to_detections.In(kTensorsTag);
 | 
					      tensors >> tensors_to_detections.In(kTensorsTag);
 | 
				
			||||||
      detections = tensors_to_detections.Out(kDetectionsTag)
 | 
					      detections = tensors_to_detections.Out(kDetectionsTag)
 | 
				
			||||||
                       .Cast<std::vector<mediapipe::Detection>>();
 | 
					                       .Cast<std::vector<mediapipe::Detection>>();
 | 
				
			||||||
      // Non maximum suppression removes redundant object detections.
 | 
					      // Non maximum suppression removes redundant object detections.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -213,6 +213,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
 | 
				
			||||||
                 }
 | 
					                 }
 | 
				
			||||||
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
					                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
				
			||||||
               }
 | 
					               }
 | 
				
			||||||
 | 
					               has_quantized_outputs: false
 | 
				
			||||||
          )pb"))));
 | 
					          )pb"))));
 | 
				
			||||||
  EXPECT_THAT(
 | 
					  EXPECT_THAT(
 | 
				
			||||||
      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
 | 
					      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
 | 
				
			||||||
| 
						 | 
					@ -244,6 +245,7 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
 | 
				
			||||||
                 }
 | 
					                 }
 | 
				
			||||||
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
					                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
				
			||||||
               }
 | 
					               }
 | 
				
			||||||
 | 
					               has_quantized_outputs: false
 | 
				
			||||||
          )pb")));
 | 
					          )pb")));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -273,6 +275,7 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
 | 
				
			||||||
                 }
 | 
					                 }
 | 
				
			||||||
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
					                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
				
			||||||
               }
 | 
					               }
 | 
				
			||||||
 | 
					               has_quantized_outputs: false
 | 
				
			||||||
          )pb")));
 | 
					          )pb")));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -311,6 +314,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
 | 
				
			||||||
                 score_transformation: IDENTITY
 | 
					                 score_transformation: IDENTITY
 | 
				
			||||||
                 default_score: 0.5
 | 
					                 default_score: 0.5
 | 
				
			||||||
               }
 | 
					               }
 | 
				
			||||||
 | 
					               has_quantized_outputs: false
 | 
				
			||||||
          )pb")));
 | 
					          )pb")));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -46,4 +46,7 @@ message DetectionPostprocessingGraphOptions {
 | 
				
			||||||
  // Optional detection label id to text calculator options.
 | 
					  // Optional detection label id to text calculator options.
 | 
				
			||||||
  optional mediapipe.DetectionLabelIdToTextCalculatorOptions
 | 
					  optional mediapipe.DetectionLabelIdToTextCalculatorOptions
 | 
				
			||||||
      detection_label_ids_to_text_options = 5;
 | 
					      detection_label_ids_to_text_options = 5;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Whether output tensors are quantized (kTfLiteUInt8) or not (kTfLiteFloat32).
 | 
				
			||||||
 | 
					  optional bool has_quantized_outputs = 6;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user