Detection postprocessing support quantized tensor.
PiperOrigin-RevId: 572310272
This commit is contained in:
		
							parent
							
								
									df13788883
								
							
						
					
					
						commit
						d6d92354ea
					
				| 
						 | 
				
			
			@ -169,6 +169,7 @@ cc_library(
 | 
			
		|||
    deps = [
 | 
			
		||||
        "//mediapipe/calculators/core:split_vector_calculator",
 | 
			
		||||
        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
 | 
			
		||||
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -710,6 +710,57 @@ absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
 | 
			
		|||
  return model_output_tensors;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Identifies whether or not the model has quantized outputs, and performs
 | 
			
		||||
// sanity checks.
 | 
			
		||||
absl::StatusOr<bool> HasQuantizedOutputs(
 | 
			
		||||
    const core::ModelResources& model_resources) {
 | 
			
		||||
  const tflite::Model& model = *model_resources.GetTfLiteModel();
 | 
			
		||||
  // Model is checked to have single subgraph before.
 | 
			
		||||
  const auto* primary_subgraph = (*model.subgraphs())[0];
 | 
			
		||||
  int num_output_tensors = primary_subgraph->outputs()->size();
 | 
			
		||||
  // Sanity check tensor types and check if model outputs are quantized or not.
 | 
			
		||||
  int num_quantized_tensors = 0;
 | 
			
		||||
  for (int i = 0; i < num_output_tensors; ++i) {
 | 
			
		||||
    const auto* tensor =
 | 
			
		||||
        primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
 | 
			
		||||
    if (tensor->type() != tflite::TensorType_FLOAT32 &&
 | 
			
		||||
        tensor->type() != tflite::TensorType_UINT8) {
 | 
			
		||||
      return CreateStatusWithPayload(
 | 
			
		||||
          absl::StatusCode::kInvalidArgument,
 | 
			
		||||
          absl::StrFormat("Expected output tensor at index %d to have type "
 | 
			
		||||
                          "UINT8 or FLOAT32, found %s instead.",
 | 
			
		||||
                          i, tflite::EnumNameTensorType(tensor->type())),
 | 
			
		||||
          MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
 | 
			
		||||
    }
 | 
			
		||||
    if (tensor->type() == tflite::TensorType_UINT8) {
 | 
			
		||||
      num_quantized_tensors++;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  if (num_quantized_tensors != num_output_tensors &&
 | 
			
		||||
      num_quantized_tensors != 0) {
 | 
			
		||||
    return CreateStatusWithPayload(
 | 
			
		||||
        absl::StatusCode::kInvalidArgument,
 | 
			
		||||
        absl::StrFormat(
 | 
			
		||||
            "Expected either all or none of the output tensors to be "
 | 
			
		||||
            "quantized, but found %d quantized outputs for %d total outputs.",
 | 
			
		||||
            num_quantized_tensors, num_output_tensors),
 | 
			
		||||
        MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
 | 
			
		||||
  }
 | 
			
		||||
  // Check if metadata is consistent with model topology.
 | 
			
		||||
  const auto* output_tensors_metadata =
 | 
			
		||||
      model_resources.GetMetadataExtractor()->GetOutputTensorMetadata();
 | 
			
		||||
  if (output_tensors_metadata != nullptr &&
 | 
			
		||||
      num_output_tensors != output_tensors_metadata->size()) {
 | 
			
		||||
    return CreateStatusWithPayload(
 | 
			
		||||
        absl::StatusCode::kInvalidArgument,
 | 
			
		||||
        absl::StrFormat("Mismatch between number of output tensors (%d) and "
 | 
			
		||||
                        "output tensors metadata (%d).",
 | 
			
		||||
                        num_output_tensors, output_tensors_metadata->size()),
 | 
			
		||||
        MediaPipeTasksStatus::kMetadataInconsistencyError);
 | 
			
		||||
  }
 | 
			
		||||
  return num_quantized_tensors > 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
absl::Status ConfigureDetectionPostprocessingGraph(
 | 
			
		||||
| 
						 | 
				
			
			@ -738,7 +789,9 @@ absl::Status ConfigureDetectionPostprocessingGraph(
 | 
			
		|||
            model.subgraphs()->Get(0)->outputs()->size()),
 | 
			
		||||
        MediaPipeTasksStatus::kInvalidArgumentError);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  MP_ASSIGN_OR_RETURN(bool has_quantized_outputs,
 | 
			
		||||
                      HasQuantizedOutputs(model_resources));
 | 
			
		||||
  options.set_has_quantized_outputs(has_quantized_outputs);
 | 
			
		||||
  const ModelMetadataExtractor* metadata_extractor =
 | 
			
		||||
      model_resources.GetMetadataExtractor();
 | 
			
		||||
  if (in_model_nms) {
 | 
			
		||||
| 
						 | 
				
			
			@ -820,12 +873,20 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 | 
			
		|||
  BuildDetectionPostprocessing(
 | 
			
		||||
      proto::DetectionPostprocessingGraphOptions& graph_options,
 | 
			
		||||
      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
 | 
			
		||||
    Source<std::vector<Tensor>> tensors = tensors_in;
 | 
			
		||||
    if (graph_options.has_quantized_outputs()) {
 | 
			
		||||
      auto& tensors_dequantization_node =
 | 
			
		||||
          graph.AddNode("TensorsDequantizationCalculator");
 | 
			
		||||
      tensors_in >> tensors_dequantization_node.In(kTensorsTag);
 | 
			
		||||
      tensors = tensors_dequantization_node.Out(kTensorsTag)
 | 
			
		||||
                    .Cast<std::vector<Tensor>>();
 | 
			
		||||
    }
 | 
			
		||||
    std::optional<Source<std::vector<Detection>>> detections;
 | 
			
		||||
    if (!graph_options.has_non_max_suppression_options()) {
 | 
			
		||||
      // Calculators to perform score calibration, if specified in the options.
 | 
			
		||||
      if (graph_options.has_score_calibration_options()) {
 | 
			
		||||
        MP_ASSIGN_OR_RETURN(tensors_in,
 | 
			
		||||
                            CalibrateScores(tensors_in, graph_options, graph));
 | 
			
		||||
        MP_ASSIGN_OR_RETURN(tensors,
 | 
			
		||||
                            CalibrateScores(tensors, graph_options, graph));
 | 
			
		||||
      }
 | 
			
		||||
      // Calculator to convert output tensors to a detection proto vector.
 | 
			
		||||
      auto& tensors_to_detections =
 | 
			
		||||
| 
						 | 
				
			
			@ -833,7 +894,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 | 
			
		|||
      tensors_to_detections
 | 
			
		||||
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
 | 
			
		||||
          .Swap(graph_options.mutable_tensors_to_detections_options());
 | 
			
		||||
      tensors_in >> tensors_to_detections.In(kTensorsTag);
 | 
			
		||||
      tensors >> tensors_to_detections.In(kTensorsTag);
 | 
			
		||||
      detections = tensors_to_detections.Out(kDetectionsTag)
 | 
			
		||||
                       .Cast<std::vector<Detection>>();
 | 
			
		||||
    } else {
 | 
			
		||||
| 
						 | 
				
			
			@ -850,7 +911,7 @@ class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 | 
			
		|||
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
 | 
			
		||||
          .Swap(graph_options.mutable_tensors_to_detections_options());
 | 
			
		||||
      anchors >> tensors_to_detections.SideIn(kAnchorsTag);
 | 
			
		||||
      tensors_in >> tensors_to_detections.In(kTensorsTag);
 | 
			
		||||
      tensors >> tensors_to_detections.In(kTensorsTag);
 | 
			
		||||
      detections = tensors_to_detections.Out(kDetectionsTag)
 | 
			
		||||
                       .Cast<std::vector<mediapipe::Detection>>();
 | 
			
		||||
      // Non maximum suppression removes redundant object detections.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -213,6 +213,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
 | 
			
		|||
                 }
 | 
			
		||||
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
			
		||||
               }
 | 
			
		||||
               has_quantized_outputs: false
 | 
			
		||||
          )pb"))));
 | 
			
		||||
  EXPECT_THAT(
 | 
			
		||||
      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
 | 
			
		||||
| 
						 | 
				
			
			@ -244,6 +245,7 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
 | 
			
		|||
                 }
 | 
			
		||||
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
			
		||||
               }
 | 
			
		||||
               has_quantized_outputs: false
 | 
			
		||||
          )pb")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -273,6 +275,7 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
 | 
			
		|||
                 }
 | 
			
		||||
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
 | 
			
		||||
               }
 | 
			
		||||
               has_quantized_outputs: false
 | 
			
		||||
          )pb")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -311,6 +314,7 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
 | 
			
		|||
                 score_transformation: IDENTITY
 | 
			
		||||
                 default_score: 0.5
 | 
			
		||||
               }
 | 
			
		||||
               has_quantized_outputs: false
 | 
			
		||||
          )pb")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -46,4 +46,7 @@ message DetectionPostprocessingGraphOptions {
 | 
			
		|||
  // Optional detection label id to text calculator options.
 | 
			
		||||
  optional mediapipe.DetectionLabelIdToTextCalculatorOptions
 | 
			
		||||
      detection_label_ids_to_text_options = 5;
 | 
			
		||||
 | 
			
		||||
  // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
 | 
			
		||||
  optional bool has_quantized_outputs = 6;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user