diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD
index 532bc9a3b..e8f9f57ff 100644
--- a/mediapipe/tasks/cc/components/processors/BUILD
+++ b/mediapipe/tasks/cc/components/processors/BUILD
@@ -161,3 +161,49 @@ cc_library(
     ],
     alwayslink = 1,
 )
+
+cc_library(
+    name = "detection_postprocessing_graph",
+    srcs = ["detection_postprocessing_graph.cc"],
+    hdrs = ["detection_postprocessing_graph.h"],
+    deps = [
+        "//mediapipe/calculators/core:split_vector_calculator",
+        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
+        "//mediapipe/calculators/util:non_max_suppression_calculator",
+        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
+        "//mediapipe/framework:calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:detection_cc_proto",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
+        "//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
+        "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "//mediapipe/tasks/metadata:object_detector_metadata_schema_cc",
+        "//mediapipe/util:label_map_cc_proto",
+        "//mediapipe/util:label_map_util",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+    alwayslink = 1,
+)
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
new file mode 100644
index 000000000..d7fc1892c
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.cc
@@ -0,0 +1,886 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/core/split_vector_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" +#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" +#include "mediapipe/tasks/metadata/object_detector_metadata_schema_generated.h" +#include "mediapipe/util/label_map.pb.h" +#include "mediapipe/util/label_map_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +namespace { + +using ::flatbuffers::Offset; +using ::flatbuffers::Vector; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::tflite::BoundingBoxProperties; +using ::tflite::ContentProperties; +using ::tflite::ContentProperties_BoundingBoxProperties; +using ::tflite::EnumNameContentProperties; +using ::tflite::ProcessUnit; +using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions; +using ::tflite::TensorMetadata; +using LabelItems = mediapipe::proto_ns::Map; +using TensorsSource = + mediapipe::api2::builder::Source>; + +constexpr int kInModelNmsDefaultLocationsIndex = 0; +constexpr int kInModelNmsDefaultCategoriesIndex = 1; +constexpr int kInModelNmsDefaultScoresIndex = 2; +constexpr int kInModelNmsDefaultNumResultsIndex = 3; + +constexpr int kOutModelNmsDefaultLocationsIndex = 0; +constexpr int kOutModelNmsDefaultScoresIndex = 1; + +constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); + +constexpr absl::string_view kLocationTensorName = "location"; +constexpr absl::string_view kCategoryTensorName = "category"; +constexpr absl::string_view kScoreTensorName = "score"; +constexpr absl::string_view kNumberOfDetectionsTensorName = + "number of detections"; +constexpr 
absl::string_view kDetectorMetadataName = "DETECTOR_METADATA"; +constexpr absl::string_view kCalibratedScoresTag = "CALIBRATED_SCORES"; +constexpr absl::string_view kDetectionsTag = "DETECTIONS"; +constexpr absl::string_view kIndicesTag = "INDICES"; +constexpr absl::string_view kScoresTag = "SCORES"; +constexpr absl::string_view kTensorsTag = "TENSORS"; +constexpr absl::string_view kAnchorsTag = "ANCHORS"; + +// Struct holding the different output streams produced by the graph. +struct DetectionPostprocessingOutputStreams { + Source> detections; +}; + +// Parameters used for configuring the post-processing calculators. +struct PostProcessingSpecs { + // The maximum number of detection results to return. + int max_results; + // Indices of the output tensors to match the output tensors to the correct + // index order of the output tensors: [location, categories, scores, + // num_detections]. + std::vector output_tensor_indices; + // For each pack of 4 coordinates returned by the model, this denotes the + // order in which to get the left, top, right and bottom coordinates. + std::vector bounding_box_corners_order; + // This is populated by reading the label files from the TFLite Model + // Metadata: if no such files are available, this is left empty and the + // ObjectDetector will only be able to populate the `index` field of the + // detection results. + LabelItems label_items; + // Score threshold. Detections with a confidence below this value are + // discarded. If none is provided via metadata or options, -FLT_MAX is set as + // default value. + float score_threshold; + // Set of category indices to be allowed/denied. + absl::flat_hash_set allow_or_deny_categories; + // Indicates `allow_or_deny_categories` is an allowlist or a denylist. + bool is_allowlist; + // Score calibration options, if any. + std::optional score_calibration_options; +}; + +absl::Status SanityCheckOptions(const proto::DetectorOptions& options) { + if (options.max_results() == 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid `max_results` option: value must be != 0", + MediaPipeTasksStatus::kInvalidArgumentError); + } + if (options.category_allowlist_size() > 0 && + options.category_denylist_size() > 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "`category_allowlist` and `category_denylist` are mutually " + "exclusive options.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +absl::StatusOr GetBoundingBoxProperties( + const TensorMetadata& tensor_metadata) { + if (tensor_metadata.content() == nullptr || + tensor_metadata.content()->content_properties() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected BoundingBoxProperties for tensor %s, found none.", + tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"), + MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); + } + + ContentProperties type = tensor_metadata.content()->content_properties_type(); + if (type != ContentProperties_BoundingBoxProperties) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected BoundingBoxProperties for tensor %s, found %s.", + tensor_metadata.name() ? 
tensor_metadata.name()->str() : "#0", + EnumNameContentProperties(type)), + MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); + } + + const BoundingBoxProperties* properties = + tensor_metadata.content()->content_properties_as_BoundingBoxProperties(); + + // Mobile SSD only supports "BOUNDARIES" bounding box type. + if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s", + tflite::EnumNameBoundingBoxType(properties->type())), + MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); + } + + // Mobile SSD only supports "RATIO" coordinates type. + if (properties->coordinate_type() != tflite::CoordinateType_RATIO) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Mobile SSD only supports CoordinateType RATIO, found %s", + tflite::EnumNameCoordinateType(properties->coordinate_type())), + MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); + } + + // Index is optional, but must contain 4 values if present. + if (properties->index() != nullptr && properties->index()->size() != 4) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected BoundingBoxProperties index to contain 4 values, found " + "%d", + properties->index()->size()), + MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); + } + + return properties; +} + +absl::StatusOr GetLabelItemsIfAny( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata, + tflite::AssociatedFileType associated_file_type, absl::string_view locale) { + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName(tensor_metadata, + associated_file_type); + if (labels_filename.empty()) { + LabelItems empty_label_items; + return empty_label_items; + } + ASSIGN_OR_RETURN(absl::string_view labels_file, + metadata_extractor.GetAssociatedFile(labels_filename)); + const std::string display_names_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, associated_file_type, locale); + absl::string_view display_names_file; + if (!display_names_filename.empty()) { + ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( + display_names_filename)); + } + return mediapipe::BuildLabelMapFromFiles(labels_file, display_names_file); +} + +absl::StatusOr GetScoreThreshold( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata) { + ASSIGN_OR_RETURN( + const ProcessUnit* score_thresholding_process_unit, + metadata_extractor.FindFirstProcessUnit( + tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); + if (score_thresholding_process_unit == nullptr) { + return kDefaultScoreThreshold; + } + return score_thresholding_process_unit->options_as_ScoreThresholdingOptions() + ->global_score_threshold(); +} + +absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( + const proto::DetectorOptions& config, const LabelItems& label_items) { + absl::flat_hash_set category_indices; + // Exit early if no denylist/allowlist. 
+ if (config.category_denylist_size() == 0 && + config.category_allowlist_size() == 0) { + return category_indices; + } + if (label_items.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Using `category_allowlist` or `category_denylist` requires " + "labels to be present in the TFLite Model Metadata but none was found.", + MediaPipeTasksStatus::kMetadataMissingLabelsError); + } + const auto& category_list = config.category_allowlist_size() > 0 + ? config.category_allowlist() + : config.category_denylist(); + for (const auto& category_name : category_list) { + int index = -1; + for (int i = 0; i < label_items.size(); ++i) { + if (label_items.at(i).name() == category_name) { + index = i; + break; + } + } + // Ignores duplicate or unknown categories. + if (index < 0) { + continue; + } + category_indices.insert(index); + } + return category_indices; +} + +absl::StatusOr> +GetScoreCalibrationOptionsIfAny( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata) { + // Get ScoreCalibrationOptions, if any. + ASSIGN_OR_RETURN( + const ProcessUnit* score_calibration_process_unit, + metadata_extractor.FindFirstProcessUnit( + tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions)); + if (score_calibration_process_unit == nullptr) { + return std::nullopt; + } + auto* score_calibration_options = + score_calibration_process_unit->options_as_ScoreCalibrationOptions(); + // Get corresponding AssociatedFile. + auto score_calibration_filename = + metadata_extractor.FindFirstAssociatedFileName( + tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); + if (score_calibration_filename.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + "Found ScoreCalibrationOptions but missing required associated " + "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", + MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError); + } + ASSIGN_OR_RETURN( + absl::string_view score_calibration_file, + metadata_extractor.GetAssociatedFile(score_calibration_filename)); + ScoreCalibrationCalculatorOptions score_calibration_calculator_options; + MP_RETURN_IF_ERROR(ConfigureScoreCalibration( + score_calibration_options->score_transformation(), + score_calibration_options->default_score(), score_calibration_file, + &score_calibration_calculator_options)); + return score_calibration_calculator_options; +} + +absl::StatusOr> GetOutputTensorIndices( + const Vector>* tensor_metadatas) { + std::vector output_indices; + if (tensor_metadatas->size() == 4) { + output_indices = { + core::FindTensorIndexByMetadataName(tensor_metadatas, + kLocationTensorName), + core::FindTensorIndexByMetadataName(tensor_metadatas, + kCategoryTensorName), + core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName), + core::FindTensorIndexByMetadataName(tensor_metadatas, + kNumberOfDetectionsTensorName)}; + // locations, categories, scores, and number of detections + for (int i = 0; i < 4; i++) { + int output_index = output_indices[i]; + // If tensor name is not found, set the default output indices. + if (output_index == -1) { + LOG(WARNING) << absl::StrFormat( + "You don't seem to be matching tensor names in metadata list. 
The " + "tensor name \"%s\" at index %d in the model metadata doesn't " + "match " + "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].", + tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName, + kCategoryTensorName, kScoreTensorName, + kNumberOfDetectionsTensorName); + output_indices = { + kInModelNmsDefaultLocationsIndex, kInModelNmsDefaultCategoriesIndex, + kInModelNmsDefaultScoresIndex, kInModelNmsDefaultNumResultsIndex}; + return output_indices; + } + } + } else if (tensor_metadatas->size() == 2) { + output_indices = {core::FindTensorIndexByMetadataName(tensor_metadatas, + kLocationTensorName), + core::FindTensorIndexByMetadataName(tensor_metadatas, + kScoreTensorName)}; + // location, score + for (int i = 0; i < 2; i++) { + int output_index = output_indices[i]; + // If tensor name is not found, set the default output indices. + if (output_index == -1) { + LOG(WARNING) << absl::StrFormat( + "You don't seem to be matching tensor names in metadata list. The " + "tensor name \"%s\" at index %d in the model metadata doesn't " + "match " + "the available output names: [\"%s\", \"%s\"].", + tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName, + kScoreTensorName); + output_indices = {kOutModelNmsDefaultLocationsIndex, + kOutModelNmsDefaultScoresIndex}; + return output_indices; + } + } + } else { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected a model with 2 or 4 output tensors metadata, found %d.", + tensor_metadatas->size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + return output_indices; +} + +// Builds PostProcessingSpecs from DetectorOptions and model metadata for +// configuring the post-processing calculators. +absl::StatusOr BuildPostProcessingSpecs( + const proto::DetectorOptions& options, bool in_model_nms, + const ModelMetadataExtractor* metadata_extractor) { + const auto* output_tensors_metadata = + metadata_extractor->GetOutputTensorMetadata(); + PostProcessingSpecs specs; + specs.max_results = options.max_results(); + ASSIGN_OR_RETURN(specs.output_tensor_indices, + GetOutputTensorIndices(output_tensors_metadata)); + // Extracts mandatory BoundingBoxProperties and performs sanity checks on the + // fly. + ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties, + GetBoundingBoxProperties(*output_tensors_metadata->Get( + specs.output_tensor_indices[0]))); + if (bounding_box_properties->index() == nullptr) { + specs.bounding_box_corners_order = {0, 1, 2, 3}; + } else { + auto bounding_box_index = bounding_box_properties->index(); + specs.bounding_box_corners_order = { + bounding_box_index->Get(0), + bounding_box_index->Get(1), + bounding_box_index->Get(2), + bounding_box_index->Get(3), + }; + } + // Builds label map (if available) from metadata. + // For models with in-model-nms, the label map is stored in the Category + // tensor which use TENSOR_VALUE_LABELS. For models with out-of-model-nms, the + // label map is stored in the Score tensor which use TENSOR_AXIS_LABELS. + ASSIGN_OR_RETURN( + specs.label_items, + GetLabelItemsIfAny( + *metadata_extractor, + *output_tensors_metadata->Get(specs.output_tensor_indices[1]), + in_model_nms ? tflite::AssociatedFileType_TENSOR_VALUE_LABELS + : tflite::AssociatedFileType_TENSOR_AXIS_LABELS, + options.display_names_locale())); + // Obtains allow/deny categories. 
+ specs.is_allowlist = !options.category_allowlist().empty(); + ASSIGN_OR_RETURN( + specs.allow_or_deny_categories, + GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items)); + + // Sets score threshold. + if (options.has_score_threshold()) { + specs.score_threshold = options.score_threshold(); + } else { + ASSIGN_OR_RETURN( + specs.score_threshold, + GetScoreThreshold( + *metadata_extractor, + *output_tensors_metadata->Get( + specs.output_tensor_indices + [in_model_nms ? kInModelNmsDefaultScoresIndex + : kOutModelNmsDefaultScoresIndex]))); + } + if (in_model_nms) { + // Builds score calibration options (if available) from metadata. + ASSIGN_OR_RETURN( + specs.score_calibration_options, + GetScoreCalibrationOptionsIfAny( + *metadata_extractor, + *output_tensors_metadata->Get( + specs.output_tensor_indices[kInModelNmsDefaultScoresIndex]))); + } + return specs; +} + +// Builds PostProcessingSpecs from DetectorOptions and model metadata for +// configuring the post-processing calculators for models with +// non-maximum-suppression. +absl::StatusOr BuildInModelNmsPostProcessingSpecs( + const proto::DetectorOptions& options, + const ModelMetadataExtractor* metadata_extractor) { + // Checks output tensor metadata is present and consistent with model. + auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata(); + if (output_tensors_metadata == nullptr || + output_tensors_metadata->size() != 4) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of output tensors (4) and " + "output tensors metadata (%d).", + output_tensors_metadata == nullptr + ? 0 + : output_tensors_metadata->size()), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + return BuildPostProcessingSpecs(options, /*in_model_nms=*/true, + metadata_extractor); +} + +// Fills in the TensorsToDetectionsCalculatorOptions based on +// PostProcessingSpecs. +void ConfigureInModelNmsTensorsToDetectionsCalculator( + const PostProcessingSpecs& specs, + mediapipe::TensorsToDetectionsCalculatorOptions* options) { + options->set_num_classes(specs.label_items.size()); + options->set_num_coords(4); + options->set_min_score_thresh(specs.score_threshold); + if (specs.max_results != -1) { + options->set_max_results(specs.max_results); + } + if (specs.is_allowlist) { + options->mutable_allow_classes()->Assign( + specs.allow_or_deny_categories.begin(), + specs.allow_or_deny_categories.end()); + } else { + options->mutable_ignore_classes()->Assign( + specs.allow_or_deny_categories.begin(), + specs.allow_or_deny_categories.end()); + } + + const auto& output_indices = specs.output_tensor_indices; + // Assigns indices to each the model output tensor. + auto* tensor_mapping = options->mutable_tensor_mapping(); + tensor_mapping->set_detections_tensor_index(output_indices[0]); + tensor_mapping->set_classes_tensor_index(output_indices[1]); + tensor_mapping->set_scores_tensor_index(output_indices[2]); + tensor_mapping->set_num_detections_tensor_index(output_indices[3]); + + // Assigns the bounding box corner order. 
+ auto box_boundaries_indices = options->mutable_box_boundaries_indices(); + box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]); + box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]); + box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]); + box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]); +} + +// Builds PostProcessingSpecs from DetectorOptions and model metadata for +// configuring the post-processing calculators for models without +// non-maximum-suppression. +absl::StatusOr BuildOutModelNmsPostProcessingSpecs( + const proto::DetectorOptions& options, + const ModelMetadataExtractor* metadata_extractor) { + // Checks output tensor metadata is present and consistent with model. + auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata(); + if (output_tensors_metadata == nullptr || + output_tensors_metadata->size() != 2) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of output tensors (2) and " + "output tensors metadata (%d).", + output_tensors_metadata == nullptr + ? 0 + : output_tensors_metadata->size()), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + return BuildPostProcessingSpecs(options, /*in_model_nms=*/false, + metadata_extractor); +} + +// Configures the TensorsToDetectionCalculator for models without +// non-maximum-suppression in tflite model. The required config parameters are +// extracted from the ObjectDetectorMetadata +// (metadata/object_detector_metadata_schema.fbs). +absl::Status ConfigureOutModelNmsTensorsToDetectionsCalculator( + const ModelMetadataExtractor* metadata_extractor, + const PostProcessingSpecs& specs, + mediapipe::TensorsToDetectionsCalculatorOptions* options) { + bool found_detector_metadata = false; + if (metadata_extractor->GetCustomMetadataList() != nullptr && + metadata_extractor->GetCustomMetadataList()->size() > 0) { + for (const auto* custom_metadata : + *metadata_extractor->GetCustomMetadataList()) { + if (custom_metadata->name()->str() == kDetectorMetadataName) { + found_detector_metadata = true; + const auto* tensors_decoding_options = + GetObjectDetectorOptions(custom_metadata->data()->data()) + ->tensors_decoding_options(); + // Here we don't set the max results for TensorsToDetectionsCalculator. + // For models without nms, the results are filtered by max_results in + // NonMaxSuppressionCalculator. 
+ options->set_num_classes(tensors_decoding_options->num_classes()); + options->set_num_boxes(tensors_decoding_options->num_boxes()); + options->set_num_coords(tensors_decoding_options->num_coords()); + options->set_keypoint_coord_offset( + tensors_decoding_options->keypoint_coord_offset()); + options->set_num_keypoints(tensors_decoding_options->num_keypoints()); + options->set_num_values_per_keypoint( + tensors_decoding_options->num_values_per_keypoint()); + options->set_x_scale(tensors_decoding_options->x_scale()); + options->set_y_scale(tensors_decoding_options->y_scale()); + options->set_w_scale(tensors_decoding_options->w_scale()); + options->set_h_scale(tensors_decoding_options->h_scale()); + options->set_apply_exponential_on_box_size( + tensors_decoding_options->apply_exponential_on_box_size()); + options->set_sigmoid_score(tensors_decoding_options->sigmoid_score()); + break; + } + } + } + if (!found_detector_metadata) { + return absl::InvalidArgumentError( + "TensorsDecodingOptions is not found in the object detector " + "metadata."); + } + // Options not configured through metadata. + options->set_box_format( + mediapipe::TensorsToDetectionsCalculatorOptions::YXHW); + options->set_min_score_thresh(specs.score_threshold); + if (specs.is_allowlist) { + options->mutable_allow_classes()->Assign( + specs.allow_or_deny_categories.begin(), + specs.allow_or_deny_categories.end()); + } else { + options->mutable_ignore_classes()->Assign( + specs.allow_or_deny_categories.begin(), + specs.allow_or_deny_categories.end()); + } + + const auto& output_indices = specs.output_tensor_indices; + // Assigns indices to each the model output tensor. + auto* tensor_mapping = options->mutable_tensor_mapping(); + tensor_mapping->set_detections_tensor_index(output_indices[0]); + tensor_mapping->set_scores_tensor_index(output_indices[1]); + return absl::OkStatus(); +} + +// Configures the SsdAnchorsCalculator for models without +// non-maximum-suppression in tflite model. The required config parameters are +// extracted from the ObjectDetectorMetadata +// (metadata/object_detector_metadata_schema.fbs). +absl::Status ConfigureSsdAnchorsCalculator( + const ModelMetadataExtractor* metadata_extractor, + mediapipe::SsdAnchorsCalculatorOptions* options) { + bool found_detector_metadata = false; + if (metadata_extractor->GetCustomMetadataList() != nullptr && + metadata_extractor->GetCustomMetadataList()->size() > 0) { + for (const auto* custom_metadata : + *metadata_extractor->GetCustomMetadataList()) { + if (custom_metadata->name()->str() == kDetectorMetadataName) { + found_detector_metadata = true; + const auto* ssd_anchors_options = + GetObjectDetectorOptions(custom_metadata->data()->data()) + ->ssd_anchors_options(); + for (const auto* ssd_anchor : + *ssd_anchors_options->fixed_anchors_schema()->anchors()) { + auto* fixed_anchor = options->add_fixed_anchors(); + fixed_anchor->set_y_center(ssd_anchor->y_center()); + fixed_anchor->set_x_center(ssd_anchor->x_center()); + fixed_anchor->set_h(ssd_anchor->height()); + fixed_anchor->set_w(ssd_anchor->width()); + } + break; + } + } + } + if (!found_detector_metadata) { + return absl::InvalidArgumentError( + "SsdAnchorsOptions is not found in the object detector " + "metadata."); + } + return absl::OkStatus(); +} + +// Sets the default IoU-based non-maximum-suppression configs, and set the +// min_suppression_threshold and max_results for detection models without +// non-maximum-suppression. 
+void ConfigureNonMaxSuppressionCalculator( + const proto::DetectorOptions& detector_options, + mediapipe::NonMaxSuppressionCalculatorOptions* options) { + options->set_min_suppression_threshold( + detector_options.min_suppression_threshold()); + options->set_overlap_type( + mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION); + options->set_algorithm( + mediapipe::NonMaxSuppressionCalculatorOptions::DEFAULT); + options->set_max_num_detections(detector_options.max_results()); +} + +// Sets the labels from post PostProcessingSpecs. +void ConfigureDetectionLabelIdToTextCalculator( + PostProcessingSpecs& specs, + mediapipe::DetectionLabelIdToTextCalculatorOptions* options) { + *options->mutable_label_items() = std::move(specs.label_items); +} + +// Splits the vector of 4 output tensors from model inference and calibrate the +// score tensors according to the metadata, if any. Then concatenate the tensors +// back to a vector of 4 tensors. +absl::StatusOr>> CalibrateScores( + Source> model_output_tensors, + const proto::DetectionPostprocessingGraphOptions& options, Graph& graph) { + // Split tensors. + auto* split_tensor_vector_node = + &graph.AddNode("SplitTensorVectorCalculator"); + auto& split_tensor_vector_options = + split_tensor_vector_node + ->GetOptions(); + for (int i = 0; i < 4; ++i) { + auto* range = split_tensor_vector_options.add_ranges(); + range->set_begin(i); + range->set_end(i + 1); + } + model_output_tensors >> split_tensor_vector_node->In(0); + + // Add score calibration calculator. + auto* score_calibration_node = &graph.AddNode("ScoreCalibrationCalculator"); + score_calibration_node->GetOptions() + .CopyFrom(options.score_calibration_options()); + const auto& tensor_mapping = + options.tensors_to_detections_options().tensor_mapping(); + split_tensor_vector_node->Out(tensor_mapping.classes_tensor_index()) >> + score_calibration_node->In(kIndicesTag); + split_tensor_vector_node->Out(tensor_mapping.scores_tensor_index()) >> + score_calibration_node->In(kScoresTag); + + // Re-concatenate tensors. 
+ auto* concatenate_tensor_vector_node = + &graph.AddNode("ConcatenateTensorVectorCalculator"); + for (int i = 0; i < 4; ++i) { + if (i == tensor_mapping.scores_tensor_index()) { + score_calibration_node->Out(kCalibratedScoresTag) >> + concatenate_tensor_vector_node->In(i); + } else { + split_tensor_vector_node->Out(i) >> concatenate_tensor_vector_node->In(i); + } + } + model_output_tensors = + concatenate_tensor_vector_node->Out(0).Cast>(); + return model_output_tensors; +} + +} // namespace + +absl::Status ConfigureDetectionPostprocessingGraph( + const tasks::core::ModelResources& model_resources, + const proto::DetectorOptions& detector_options, + proto::DetectionPostprocessingGraphOptions& options) { + MP_RETURN_IF_ERROR(SanityCheckOptions(detector_options)); + const auto& model = *model_resources.GetTfLiteModel(); + bool in_model_nms = false; + if (model.subgraphs()->size() != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Expected a model with a single subgraph, found %d.", + model.subgraphs()->size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + if (model.subgraphs()->Get(0)->outputs()->size() == 2) { + in_model_nms = false; + } else if (model.subgraphs()->Get(0)->outputs()->size() == 4) { + in_model_nms = true; + } else { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected a model with 2 or 4 output tensors, found %d.", + model.subgraphs()->Get(0)->outputs()->size()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + + const ModelMetadataExtractor* metadata_extractor = + model_resources.GetMetadataExtractor(); + if (in_model_nms) { + ASSIGN_OR_RETURN(auto post_processing_specs, + BuildInModelNmsPostProcessingSpecs(detector_options, + metadata_extractor)); + ConfigureInModelNmsTensorsToDetectionsCalculator( + post_processing_specs, options.mutable_tensors_to_detections_options()); + ConfigureDetectionLabelIdToTextCalculator( + post_processing_specs, + options.mutable_detection_label_ids_to_text_options()); + if (post_processing_specs.score_calibration_options.has_value()) { + *options.mutable_score_calibration_options() = + std::move(*post_processing_specs.score_calibration_options); + } + } else { + ASSIGN_OR_RETURN(auto post_processing_specs, + BuildOutModelNmsPostProcessingSpecs(detector_options, + metadata_extractor)); + MP_RETURN_IF_ERROR(ConfigureOutModelNmsTensorsToDetectionsCalculator( + metadata_extractor, post_processing_specs, + options.mutable_tensors_to_detections_options())); + MP_RETURN_IF_ERROR(ConfigureSsdAnchorsCalculator( + metadata_extractor, options.mutable_ssd_anchors_options())); + ConfigureNonMaxSuppressionCalculator( + detector_options, options.mutable_non_max_suppression_options()); + ConfigureDetectionLabelIdToTextCalculator( + post_processing_specs, + options.mutable_detection_label_ids_to_text_options()); + } + + return absl::OkStatus(); +} + +// A DetectionPostprocessingGraph converts raw tensors into +// std::vector. +// +// Inputs: +// TENSORS - std::vector +// The output tensors of an InferenceCalculator. The tensors vector could be +// size 4 or size 2. Tensors vector of size 4 expects the tensors from the +// models with DETECTION_POSTPROCESS ops in the tflite graph. Tensors vector +// of size 2 expects the tensors from the models without the ops. 
+// [1]: +// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc +// Outputs: +// DETECTIONS - std::vector +// The postprocessed detection results. +// +// The recommended way of using this graph is through the GraphBuilder API +// using the 'ConfigureDetectionPostprocessingGraph()' function. See header +// file for more details. +class DetectionPostprocessingGraph : public mediapipe::Subgraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto output_streams, + BuildDetectionPostprocessing( + *sc->MutableOptions(), + graph.In(kTensorsTag).Cast>(), graph)); + output_streams.detections >> + graph.Out(kDetectionsTag).Cast>(); + return graph.GetConfig(); + } + + private: + // Adds an on-device detection postprocessing graph into the provided + // builder::Graph instance. The detection postprocessing graph takes + // tensors (std::vector) as input and returns one output + // stream: + // - Detection results as a std::vector. + // + // graph_options: the on-device DetectionPostprocessingGraphOptions. + // tensors_in: (std::vector>) tensors to postprocess. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr + BuildDetectionPostprocessing( + proto::DetectionPostprocessingGraphOptions& graph_options, + Source> tensors_in, Graph& graph) { + std::optional>> detections; + if (!graph_options.has_non_max_suppression_options()) { + // Calculators to perform score calibration, if specified in the options. + if (graph_options.has_score_calibration_options()) { + ASSIGN_OR_RETURN(tensors_in, + CalibrateScores(tensors_in, graph_options, graph)); + } + // Calculator to convert output tensors to a detection proto vector. + auto& tensors_to_detections = + graph.AddNode("TensorsToDetectionsCalculator"); + tensors_to_detections + .GetOptions() + .Swap(graph_options.mutable_tensors_to_detections_options()); + tensors_in >> tensors_to_detections.In(kTensorsTag); + detections = tensors_to_detections.Out(kDetectionsTag) + .Cast>(); + } else { + // Generates a single side packet containing a vector of SSD anchors. + auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); + ssd_anchor.GetOptions().Swap( + graph_options.mutable_ssd_anchors_options()); + auto anchors = + ssd_anchor.SideOut("").Cast>(); + // Convert raw output tensors to detections. + auto& tensors_to_detections = + graph.AddNode("TensorsToDetectionsCalculator"); + tensors_to_detections + .GetOptions() + .Swap(graph_options.mutable_tensors_to_detections_options()); + anchors >> tensors_to_detections.SideIn(kAnchorsTag); + tensors_in >> tensors_to_detections.In(kTensorsTag); + detections = tensors_to_detections.Out(kDetectionsTag) + .Cast>(); + // Non maximum suppression removes redundant object detections. + auto& non_maximum_suppression = + graph.AddNode("NonMaxSuppressionCalculator"); + non_maximum_suppression + .GetOptions() + .Swap(graph_options.mutable_non_max_suppression_options()); + *detections >> non_maximum_suppression.In(""); + detections = + non_maximum_suppression.Out("").Cast>(); + } + + // Calculator to assign detection labels. 
+  auto& detection_label_id_to_text =
+      graph.AddNode("DetectionLabelIdToTextCalculator");
+  detection_label_id_to_text
+      .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>()
+      .Swap(graph_options.mutable_detection_label_ids_to_text_options());
+  *detections >> detection_label_id_to_text.In("");
+  return {
+      {detection_label_id_to_text.Out("").Cast<std::vector<Detection>>()}};
+  }
+};
+
+// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
+// clang-format off
+REGISTER_MEDIAPIPE_GRAPH(
+  ::mediapipe::tasks::components::processors::DetectionPostprocessingGraph);  // NOLINT
+// clang-format on
+
+}  // namespace processors
+}  // namespace components
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h
new file mode 100644
index 000000000..1696b844f
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h
@@ -0,0 +1,62 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
+
+namespace mediapipe {
+namespace tasks {
+namespace components {
+namespace processors {
+
+// Configures a DetectionPostprocessingGraph using the provided model
+// resources and DetectorOptions.
+//
+// Example usage:
+//
+//   auto& postprocessing =
+//       graph.AddNode("mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
+//   MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
+//       model_resources,
+//       detector_options,
+//       postprocessing.GetOptions<proto::DetectionPostprocessingGraphOptions>()));
+//
+// The resulting DetectionPostprocessingGraph has the following I/O:
+// Inputs:
+//   TENSORS - std::vector<Tensor>
+//     The output tensors of an InferenceCalculator. The tensors vector may be
+//     of size 4 or size 2. A vector of size 4 is expected from models with the
+//     DETECTION_POSTPROCESS op [1] in the tflite graph; a vector of size 2 is
+//     expected from models without that op.
+//     [1]:
+//     https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
+// Outputs:
+//   DETECTIONS - std::vector<Detection>
+//     The postprocessed detection results.
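[Editor's note: to make the I/O contract above concrete, here is a minimal wiring sketch. It is not part of the patch; it mirrors the BuildGraph() helper in the test file further down and assumes the api2::builder aliases (Graph, Input, Output) plus `model_resources` (a std::unique_ptr<ModelResources>) and `detector_options` are in scope, and that the "tensors"/"detections" stream names are caller-chosen.]

    Graph graph;
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
        *model_resources, detector_options,
        postprocessing.GetOptions<proto::DetectionPostprocessingGraphOptions>()));
    // Feed the raw model output tensors in, and expose the postprocessed
    // detections as a graph output stream.
    graph[Input<std::vector<Tensor>>("TENSORS")].SetName("tensors") >>
        postprocessing.In("TENSORS");
    postprocessing.Out("DETECTIONS").SetName("detections") >>
        graph[Output<std::vector<Detection>>("DETECTIONS")];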
+absl::Status ConfigureDetectionPostprocessingGraph( + const tasks::core::ModelResources& model_resources, + const proto::DetectorOptions& detector_options, + proto::DetectionPostprocessingGraphOptions& options); + +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc new file mode 100644 index 000000000..36aead0c1 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/detection_postprocessing_graph_test.cc @@ -0,0 +1,570 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h" + +#include + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/graph_runner.h" +#include "mediapipe/framework/output_stream_poller.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "tensorflow/lite/test_util.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::ModelResources; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Pointwise; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr absl::string_view kTestDataDirectory = + "/mediapipe/tasks/testdata/vision"; +constexpr absl::string_view kMobileSsdWithMetadata = + 
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; +constexpr absl::string_view kMobileSsdWithDummyScoreCalibration = + "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration." + "tflite"; +constexpr absl::string_view kEfficientDetWithoutNms = + "efficientdet_lite0_fp16_no_nms.tflite"; + +constexpr char kTestModelResourcesTag[] = "test_model_resources"; + +constexpr absl::string_view kTensorsTag = "TENSORS"; +constexpr absl::string_view kDetectionsTag = "DETECTIONS"; +constexpr absl::string_view kTensorsName = "tensors"; +constexpr absl::string_view kDetectionsName = "detections"; + +// Helper function to get ModelResources. +absl::StatusOr> CreateModelResourcesForModel( + absl::string_view model_name) { + auto external_file = std::make_unique(); + external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name)); + return ModelResources::Create(kTestModelResourcesTag, + std::move(external_file)); +} + +class ConfigureTest : public tflite::testing::Test {}; + +TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.set_max_results(0); + + proto::DetectionPostprocessingGraphOptions options_out; + auto status = ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); +} + +TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.add_category_allowlist("foo"); + options_in.add_category_denylist("bar"); + + proto::DetectionPostprocessingGraphOptions options_out; + auto status = ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out); + + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); +} + +TEST_F(ConfigureTest, SucceedsWithMaxResults) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.set_max_results(3); + + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + + EXPECT_THAT( + options_out, + Approximately(Partially(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_coords: 4 + max_results: 3 + tensor_mapping { + detections_tensor_index: 0 + classes_tensor_index: 1 + scores_tensor_index: 2 + num_detections_tensor_index: 3 + } + box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } + } + )pb")))); +} + +TEST_F(ConfigureTest, SucceedsWithMaxResultsWithoutModelNms) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, CreateModelResourcesForModel( + kEfficientDetWithoutNms)); + proto::DetectorOptions options_in; + options_in.set_max_results(3); + + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + EXPECT_THAT(options_out, Approximately(Partially(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_boxes: 19206 + num_coords: 4 + x_scale: 1 + y_scale: 1 + 
w_scale: 1 + h_scale: 1 + keypoint_coord_offset: 0 + num_keypoints: 0 + num_values_per_keypoint: 2 + apply_exponential_on_box_size: true + sigmoid_score: false + tensor_mapping { + detections_tensor_index: 1 + scores_tensor_index: 0 + } + box_format: YXHW + } + non_max_suppression_options { + max_num_detections: 3 + min_suppression_threshold: 0 + overlap_type: INTERSECTION_OVER_UNION + algorithm: DEFAULT + } + )pb")))); + EXPECT_THAT( + options_out.detection_label_ids_to_text_options().label_items_size(), 90); +} + +TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.set_score_threshold(0.5); + + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + EXPECT_THAT( + options_out, + Approximately(Partially(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: 0.5 + num_classes: 90 + num_coords: 4 + tensor_mapping { + detections_tensor_index: 0 + classes_tensor_index: 1 + scores_tensor_index: 2 + num_detections_tensor_index: 3 + } + box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } + } + )pb")))); + EXPECT_THAT( + options_out.detection_label_ids_to_text_options().label_items_size(), 90); +} + +TEST_F(ConfigureTest, SucceedsWithAllowlist) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.add_category_allowlist("bicycle"); + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + // Clear labels ids to text and compare the rest of the options. + options_out.clear_detection_label_ids_to_text_options(); + EXPECT_THAT( + options_out, + Approximately(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_coords: 4 + allow_classes: 1 + tensor_mapping { + detections_tensor_index: 0 + classes_tensor_index: 1 + scores_tensor_index: 2 + num_detections_tensor_index: 3 + } + box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } + } + )pb"))); +} + +TEST_F(ConfigureTest, SucceedsWithDenylist) { + MP_ASSERT_OK_AND_ASSIGN(auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithMetadata)); + proto::DetectorOptions options_in; + options_in.add_category_denylist("person"); + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + // Clear labels ids to text and compare the rest of the options. 
+ options_out.clear_detection_label_ids_to_text_options(); + EXPECT_THAT( + options_out, + Approximately(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_coords: 4 + ignore_classes: 0 + tensor_mapping { + detections_tensor_index: 0 + classes_tensor_index: 1 + scores_tensor_index: 2 + num_detections_tensor_index: 3 + } + box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } + } + )pb"))); +} + +TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { + MP_ASSERT_OK_AND_ASSIGN( + auto model_resources, + CreateModelResourcesForModel(kMobileSsdWithDummyScoreCalibration)); + proto::DetectorOptions options_in; + proto::DetectionPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources, + options_in, options_out)); + // Clear labels ids to text. + options_out.clear_detection_label_ids_to_text_options(); + // Check sigmoids size and first element. + ASSERT_EQ(options_out.score_calibration_options().sigmoids_size(), 89); + EXPECT_THAT(options_out.score_calibration_options().sigmoids()[0], + EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb")); + options_out.mutable_score_calibration_options()->clear_sigmoids(); + // Compare the rest of the option. + EXPECT_THAT( + options_out, + Approximately(EqualsProto( + R"pb(tensors_to_detections_options { + min_score_thresh: -3.4028235e+38 + num_classes: 90 + num_coords: 4 + tensor_mapping { + detections_tensor_index: 0 + classes_tensor_index: 1 + scores_tensor_index: 2 + num_detections_tensor_index: 3 + } + box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 } + } + score_calibration_options { + score_transformation: IDENTITY + default_score: 0.5 + } + )pb"))); +} + +class PostprocessingTest : public tflite::testing::Test { + protected: + absl::StatusOr BuildGraph( + absl::string_view model_name, const proto::DetectorOptions& options) { + ASSIGN_OR_RETURN(auto model_resources, + CreateModelResourcesForModel(model_name)); + + Graph graph; + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors." + "DetectionPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph( + *model_resources, options, + postprocessing + .GetOptions())); + graph[Input>(kTensorsTag)].SetName( + std::string(kTensorsName)) >> + postprocessing.In(kTensorsTag); + postprocessing.Out(kDetectionsTag).SetName(std::string(kDetectionsName)) >> + graph[Output>(kDetectionsTag)]; + MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); + ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( + std::string(kDetectionsName))); + MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); + return poller; + } + + template + void AddTensor(const std::vector& tensor, + const Tensor::ElementType& element_type, + const Tensor::Shape& shape) { + tensors_->emplace_back(element_type, shape); + auto view = tensors_->back().GetCpuWriteView(); + T* buffer = view.buffer(); + std::copy(tensor.begin(), tensor.end(), buffer); + } + + absl::Status Run(int timestamp = 0) { + MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( + std::string(kTensorsName), + Adopt(tensors_.release()).At(Timestamp(timestamp)))); + // Reset tensors for future calls. 
+ tensors_ = absl::make_unique>(); + return absl::OkStatus(); + } + + template + absl::StatusOr GetResult(OutputStreamPoller& poller) { + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); + MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); + + Packet packet; + if (!poller.Next(&packet)) { + return absl::InternalError("Unable to get output packet"); + } + auto result = packet.Get(); + MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); + return result; + } + + private: + CalculatorGraph calculator_graph_; + std::unique_ptr> tensors_ = + absl::make_unique>(); +}; + +TEST_F(PostprocessingTest, SucceedsWithMetadata) { + // Build graph. + proto::DetectorOptions options; + options.set_max_results(3); + MP_ASSERT_OK_AND_ASSIGN(auto poller, + BuildGraph(kMobileSsdWithMetadata, options)); + + // Build input tensors. + constexpr int kBboxesNum = 5; + // Location tensor. + std::vector location_tensor(kBboxesNum * 4, 0); + for (int i = 0; i < kBboxesNum; ++i) { + location_tensor[i * 4] = 0.1f; + location_tensor[i * 4 + 1] = 0.1f; + location_tensor[i * 4 + 2] = 0.4f; + location_tensor[i * 4 + 3] = 0.5f; + } + // Category tensor. + std::vector category_tensor(kBboxesNum, 0); + for (int i = 0; i < kBboxesNum; ++i) { + category_tensor[i] = i + 1; + } + + // Score tensor. Post processed tensor scores are in descending order. + std::vector score_tensor(kBboxesNum, 0); + for (int i = 0; i < kBboxesNum; ++i) { + score_tensor[i] = static_cast(kBboxesNum - i) / kBboxesNum; + } + + // Number of detections tensor. + std::vector num_detections_tensor(1, 0); + num_detections_tensor[0] = kBboxesNum; + + // Send tensors and get results. + AddTensor(location_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 4}); + AddTensor(category_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum}); + AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum}); + AddTensor(num_detections_tensor, Tensor::ElementType::kFloat32, {1}); + MP_ASSERT_OK(Run()); + + // Validate results. + EXPECT_THAT(GetResult>(poller), + IsOkAndHolds(ElementsAre(Approximately(EqualsProto( + R"pb( + label: "bicycle" + score: 1 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.1 + ymin: 0.1 + width: 0.4 + height: 0.3 + } + } + )pb")), + Approximately(EqualsProto( + R"pb( + label: "car" + score: 0.8 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.1 + ymin: 0.1 + width: 0.4 + height: 0.3 + } + } + )pb")), + Approximately(EqualsProto( + R"pb( + label: "motorcycle" + score: 0.6 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.1 + ymin: 0.1 + width: 0.4 + height: 0.3 + } + } + )pb"))))); +} + +TEST_F(PostprocessingTest, SucceedsWithOutModelNms) { + // Build graph. + proto::DetectorOptions options; + options.set_max_results(3); + MP_ASSERT_OK_AND_ASSIGN(auto poller, + BuildGraph(kEfficientDetWithoutNms, options)); + + // Build input tensors. + constexpr int kBboxesNum = 19206; + constexpr int kBicycleBboxIdx = 1000; + constexpr int kCarBboxIdx = 2000; + constexpr int kMotoCycleBboxIdx = 4000; + // Location tensor. + std::vector location_tensor(kBboxesNum * 4, 0); + for (int i = 0; i < kBboxesNum; ++i) { + location_tensor[i * 4] = 0.5f; + location_tensor[i * 4 + 1] = 0.5f; + location_tensor[i * 4 + 2] = 0.001f; + location_tensor[i * 4 + 3] = 0.001f; + } + + // Detected three objects. 
+  location_tensor[kBicycleBboxIdx * 4] = 0.7f;
+  location_tensor[kBicycleBboxIdx * 4 + 1] = 0.8f;
+  location_tensor[kBicycleBboxIdx * 4 + 2] = 0.2f;
+  location_tensor[kBicycleBboxIdx * 4 + 3] = 0.1f;
+
+  location_tensor[kCarBboxIdx * 4] = 0.1f;
+  location_tensor[kCarBboxIdx * 4 + 1] = 0.1f;
+  location_tensor[kCarBboxIdx * 4 + 2] = 0.1f;
+  location_tensor[kCarBboxIdx * 4 + 3] = 0.1f;
+
+  location_tensor[kMotoCycleBboxIdx * 4] = 0.2f;
+  location_tensor[kMotoCycleBboxIdx * 4 + 1] = 0.8f;
+  location_tensor[kMotoCycleBboxIdx * 4 + 2] = 0.1f;
+  location_tensor[kMotoCycleBboxIdx * 4 + 3] = 0.2f;
+
+  // Score tensor.
+  constexpr int kClassesNum = 90;
+  std::vector<float> score_tensor(kBboxesNum * kClassesNum, 1.f / kClassesNum);
+
+  // Detected three objects.
+  score_tensor[kBicycleBboxIdx * kClassesNum + 1] = 1.0f;    // bicycle.
+  score_tensor[kCarBboxIdx * kClassesNum + 2] = 0.9f;        // car.
+  score_tensor[kMotoCycleBboxIdx * kClassesNum + 3] = 0.8f;  // motorcycle.
+
+  // Send tensors and get results.
+  AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 90});
+  AddTensor(location_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 4});
+  MP_ASSERT_OK(Run());
+
+  // Validate results.
+  EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
+              IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
+                               R"pb(
+                                 label: "bicycle"
+                                 score: 1
+                                 location_data {
+                                   format: RELATIVE_BOUNDING_BOX
+                                   relative_bounding_box {
+                                     xmin: 0.8137423
+                                     ymin: 0.067235775
+                                     width: 0.117221
+                                     height: 0.064774655
+                                   }
+                                 }
+                               )pb")),
+                           Approximately(EqualsProto(
+                               R"pb(
+                                 label: "car"
+                                 score: 0.9
+                                 location_data {
+                                   format: RELATIVE_BOUNDING_BOX
+                                   relative_bounding_box {
+                                     xmin: 0.53849804
+                                     ymin: 0.08949606
+                                     width: 0.05861056
+                                     height: 0.11722109
+                                   }
+                                 }
+                               )pb")),
+                           Approximately(EqualsProto(
+                               R"pb(
+                                 label: "motorcycle"
+                                 score: 0.8
+                                 location_data {
+                                   format: RELATIVE_BOUNDING_BOX
+                                   relative_bounding_box {
+                                     xmin: 0.13779688
+                                     ymin: 0.26394117
+                                     width: 0.16322193
+                                     height: 0.07384467
+                                   }
+                                 }
+                               )pb")))));
+}
+
+}  // namespace
+}  // namespace processors
+}  // namespace components
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD
index 1877bc7e2..82d4ea21b 100644
--- a/mediapipe/tasks/cc/components/processors/proto/BUILD
+++ b/mediapipe/tasks/cc/components/processors/proto/BUILD
@@ -23,6 +23,11 @@ mediapipe_proto_library(
     srcs = ["classifier_options.proto"],
 )
 
+mediapipe_proto_library(
+    name = "detector_options_proto",
+    srcs = ["detector_options.proto"],
+)
+
 mediapipe_proto_library(
     name = "classification_postprocessing_graph_options_proto",
     srcs = ["classification_postprocessing_graph_options.proto"],
@@ -35,6 +40,20 @@ mediapipe_proto_library(
     ],
 )
 
+mediapipe_proto_library(
+    name = "detection_postprocessing_graph_options_proto",
+    srcs = ["detection_postprocessing_graph_options.proto"],
+    deps = [
+        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_proto",
+        "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto",
+        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto",
+        "//mediapipe/calculators/util:non_max_suppression_calculator_proto",
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
+    ],
+)
+
 mediapipe_proto_library(
     name = "embedder_options_proto",
     srcs = ["embedder_options.proto"],
diff --git
a/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto new file mode 100644 index 000000000..ec11df2b4 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto @@ -0,0 +1,49 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package mediapipe.tasks.components.processors.proto; + +import "mediapipe/calculators/tensor/tensors_to_detections_calculator.proto"; +import "mediapipe/calculators/tflite/ssd_anchors_calculator.proto"; +import "mediapipe/calculators/util/detection_label_id_to_text_calculator.proto"; +import "mediapipe/calculators/util/non_max_suppression_calculator.proto"; +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; + +message DetectionPostprocessingGraphOptions { + // Optional SsdAnchorsCalculatorOptions for models without + // non-maximum-suppression in tflite model graph. + optional mediapipe.SsdAnchorsCalculatorOptions ssd_anchors_options = 1; + + // Optional TensorsToDetectionsCalculatorOptions for models without + // non-maximum-suppression in tflite model graph. + optional mediapipe.TensorsToDetectionsCalculatorOptions + tensors_to_detections_options = 2; + + // Optional NonMaxSuppressionCalculatorOptions for models without + // non-maximum-suppression in tflite model graph. + optional mediapipe.NonMaxSuppressionCalculatorOptions + non_max_suppression_options = 3; + + // Optional score calibration options for models with non-maximum-suppression + // in tflite model graph. + optional ScoreCalibrationCalculatorOptions score_calibration_options = 4; + + // Optional detection label id to text calculator options. + optional mediapipe.DetectionLabelIdToTextCalculatorOptions + detection_label_ids_to_text_options = 5; +} diff --git a/mediapipe/tasks/cc/components/processors/proto/detector_options.proto b/mediapipe/tasks/cc/components/processors/proto/detector_options.proto new file mode 100644 index 000000000..c70b1f7a6 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/detector_options.proto @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.components.processors.proto; + +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "DetectorOptionsProto"; + +// Shared options used by all detection tasks. +message DetectorOptions { + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 1 [default = "en"]; + + // The maximum number of top-scored detection results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + optional int32 max_results = 2 [default = -1]; + + // Score threshold, overrides the ones provided in the model metadata + // (if any). Results below this value are rejected. + optional float score_threshold = 3; + + // Overlapping threshold for non-maximum-suppression calculator. Only used for + // models without built-in non-maximum-suppression, i.e., models that don't + // use the Detection_Postprocess TFLite Op + optional float min_suppression_threshold = 6; + + // Optional allowlist of category names. If non-empty, detections whose + // category name is not in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_denylist. + repeated string category_allowlist = 4; + + // Optional denylist of category names. If non-empty, detection whose category + // name is in this set will be filtered out. Duplicate or unknown category + // names are ignored. Mutually exclusive with category_allowlist. + repeated string category_denylist = 5; +} diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 855fc29f5..40d7ab50b 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -54,12 +54,7 @@ cc_library( name = "object_detector_graph", srcs = ["object_detector_graph.cc"], deps = [ - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator", - "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:detection_transformation_calculator", "//mediapipe/calculators/util:detections_deduplicate_calculator", @@ -71,19 +66,15 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:detection_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core:utils", - 
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index bddbef4bb..de2c0dbaf 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -99,7 +99,20 @@ struct ObjectDetectorOptions { // - only RGB inputs are supported (`channels` is required to be 3). // - if type is kTfLiteFloat32, NormalizationOptions are required to be // attached to the metadata for input normalization. -// Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e: +// Output tensors could be 2 output tensors or 4 output tensors. +// The 2 output tensors must represent locations and scores, respectively. +// (kTfLiteFloat32) +// - locations tensor of size `[num_results x num_coords]`. The num_coords is +// the number of coordinates a location result represent. Usually in the +// form: [4 + 2 * keypoint_num], where 4 location values encode the bounding +// box (y_center, x_center, height, width) and the additional keypoints are in +// (y, x) order. +// (kTfLiteFloat32) +// - scores tensor of size `[num_results x num_classes]`. The values of a +// result represent the classification probability belonging to the class at +// the index, which is denoted in the label file of corresponding tensor +// metadata in the model file. +// The 4 output tensors must come from `DetectionPostProcess` op, i.e: // (kTfLiteFloat32) // - locations tensor of size `[num_results x 4]`, the inner array // representing bounding boxes in the form [top, left, right, bottom]. diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index 783ed742a..e2b374970 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -13,16 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "mediapipe/calculators/core/split_vector_calculator.pb.h" -#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" -#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator.pb.h" @@ -31,19 +25,15 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" -#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" +#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" -#include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" -#include "mediapipe/util/label_map.pb.h" -#include "mediapipe/util/label_map_util.h" namespace mediapipe { namespace tasks { @@ -56,42 +46,18 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::metadata::ModelMetadataExtractor; -using ::tflite::BoundingBoxProperties; -using ::tflite::ContentProperties; -using ::tflite::ContentProperties_BoundingBoxProperties; -using ::tflite::EnumNameContentProperties; -using ::tflite::ProcessUnit; -using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions; -using ::tflite::TensorMetadata; -using LabelItems = mediapipe::proto_ns::Map; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; using TensorsSource = mediapipe::api2::builder::Source>; -constexpr int kDefaultLocationsIndex = 0; -constexpr int kDefaultCategoriesIndex = 1; -constexpr int kDefaultScoresIndex = 2; -constexpr int kDefaultNumResultsIndex = 3; - -constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); - -constexpr char kLocationTensorName[] = "location"; -constexpr char kCategoryTensorName[] = "category"; -constexpr char kScoreTensorName[] = "score"; -constexpr char kNumberOfDetectionsTensorName[] = "number of detections"; - -constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageTag[] = "IMAGE"; -constexpr char kIndicesTag[] = "INDICES"; constexpr char kMatrixTag[] = "MATRIX"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; -constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorTag[] = "TENSORS"; // Struct holding the different output streams produced by the object detection @@ -101,34 +67,6 @@ struct ObjectDetectionOutputStreams { Source image; }; -// Parameters used for configuring the post-processing calculators. -struct PostProcessingSpecs { - // The maximum number of detection results to return. - int max_results; - // Indices of the output tensors to match the output tensors to the correct - // index order of the output tensors: [location, categories, scores, - // num_detections]. 
- std::vector output_tensor_indices; - // For each pack of 4 coordinates returned by the model, this denotes the - // order in which to get the left, top, right and bottom coordinates. - std::vector bounding_box_corners_order; - // This is populated by reading the label files from the TFLite Model - // Metadata: if no such files are available, this is left empty and the - // ObjectDetector will only be able to populate the `index` field of the - // detection results. - LabelItems label_items; - // Score threshold. Detections with a confidence below this value are - // discarded. If none is provided via metadata or options, -FLT_MAX is set as - // default value. - float score_threshold; - // Set of category indices to be allowed/denied. - absl::flat_hash_set allow_or_deny_categories; - // Indicates `allow_or_deny_categories` is an allowlist or a denylist. - bool is_allowlist; - // Score calibration options, if any. - std::optional score_calibration_options; -}; - absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) { if (options.max_results() == 0) { return CreateStatusWithPayload( @@ -147,310 +85,6 @@ absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) { return absl::OkStatus(); } -absl::StatusOr GetBoundingBoxProperties( - const TensorMetadata& tensor_metadata) { - if (tensor_metadata.content() == nullptr || - tensor_metadata.content()->content_properties() == nullptr) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Expected BoundingBoxProperties for tensor %s, found none.", - tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - ContentProperties type = tensor_metadata.content()->content_properties_type(); - if (type != ContentProperties_BoundingBoxProperties) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Expected BoundingBoxProperties for tensor %s, found %s.", - tensor_metadata.name() ? tensor_metadata.name()->str() : "#0", - EnumNameContentProperties(type)), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - const BoundingBoxProperties* properties = - tensor_metadata.content()->content_properties_as_BoundingBoxProperties(); - - // Mobile SSD only supports "BOUNDARIES" bounding box type. - if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s", - tflite::EnumNameBoundingBoxType(properties->type())), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - // Mobile SSD only supports "RATIO" coordinates type. - if (properties->coordinate_type() != tflite::CoordinateType_RATIO) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Mobile SSD only supports CoordinateType RATIO, found %s", - tflite::EnumNameCoordinateType(properties->coordinate_type())), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - // Index is optional, but must contain 4 values if present. 
- if (properties->index() != nullptr && properties->index()->size() != 4) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat( - "Expected BoundingBoxProperties index to contain 4 values, found " - "%d", - properties->index()->size()), - MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError); - } - - return properties; -} - -absl::StatusOr GetLabelItemsIfAny( - const ModelMetadataExtractor& metadata_extractor, - const TensorMetadata& tensor_metadata, absl::string_view locale) { - const std::string labels_filename = - ModelMetadataExtractor::FindFirstAssociatedFileName( - tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS); - if (labels_filename.empty()) { - LabelItems empty_label_items; - return empty_label_items; - } - ASSIGN_OR_RETURN(absl::string_view labels_file, - metadata_extractor.GetAssociatedFile(labels_filename)); - const std::string display_names_filename = - ModelMetadataExtractor::FindFirstAssociatedFileName( - tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS, - locale); - absl::string_view display_names_file; - if (!display_names_filename.empty()) { - ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( - display_names_filename)); - } - return mediapipe::BuildLabelMapFromFiles(labels_file, display_names_file); -} - -absl::StatusOr GetScoreThreshold( - const ModelMetadataExtractor& metadata_extractor, - const TensorMetadata& tensor_metadata) { - ASSIGN_OR_RETURN( - const ProcessUnit* score_thresholding_process_unit, - metadata_extractor.FindFirstProcessUnit( - tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); - if (score_thresholding_process_unit == nullptr) { - return kDefaultScoreThreshold; - } - return score_thresholding_process_unit->options_as_ScoreThresholdingOptions() - ->global_score_threshold(); -} - -absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( - const ObjectDetectorOptionsProto& config, const LabelItems& label_items) { - absl::flat_hash_set category_indices; - // Exit early if no denylist/allowlist. - if (config.category_denylist_size() == 0 && - config.category_allowlist_size() == 0) { - return category_indices; - } - if (label_items.empty()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "Using `category_allowlist` or `category_denylist` requires " - "labels to be present in the TFLite Model Metadata but none was found.", - MediaPipeTasksStatus::kMetadataMissingLabelsError); - } - const auto& category_list = config.category_allowlist_size() > 0 - ? config.category_allowlist() - : config.category_denylist(); - for (const auto& category_name : category_list) { - int index = -1; - for (int i = 0; i < label_items.size(); ++i) { - if (label_items.at(i).name() == category_name) { - index = i; - break; - } - } - // Ignores duplicate or unknown categories. - if (index < 0) { - continue; - } - category_indices.insert(index); - } - return category_indices; -} - -absl::StatusOr> -GetScoreCalibrationOptionsIfAny( - const ModelMetadataExtractor& metadata_extractor, - const TensorMetadata& tensor_metadata) { - // Get ScoreCalibrationOptions, if any. 
- ASSIGN_OR_RETURN( - const ProcessUnit* score_calibration_process_unit, - metadata_extractor.FindFirstProcessUnit( - tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions)); - if (score_calibration_process_unit == nullptr) { - return std::nullopt; - } - auto* score_calibration_options = - score_calibration_process_unit->options_as_ScoreCalibrationOptions(); - // Get corresponding AssociatedFile. - auto score_calibration_filename = - metadata_extractor.FindFirstAssociatedFileName( - tensor_metadata, - tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); - if (score_calibration_filename.empty()) { - return CreateStatusWithPayload( - absl::StatusCode::kNotFound, - "Found ScoreCalibrationOptions but missing required associated " - "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", - MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError); - } - ASSIGN_OR_RETURN( - absl::string_view score_calibration_file, - metadata_extractor.GetAssociatedFile(score_calibration_filename)); - ScoreCalibrationCalculatorOptions score_calibration_calculator_options; - MP_RETURN_IF_ERROR(ConfigureScoreCalibration( - score_calibration_options->score_transformation(), - score_calibration_options->default_score(), score_calibration_file, - &score_calibration_calculator_options)); - return score_calibration_calculator_options; -} - -std::vector GetOutputTensorIndices( - const flatbuffers::Vector>* - tensor_metadatas) { - std::vector output_indices = { - core::FindTensorIndexByMetadataName(tensor_metadatas, - kLocationTensorName), - core::FindTensorIndexByMetadataName(tensor_metadatas, - kCategoryTensorName), - core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName), - core::FindTensorIndexByMetadataName(tensor_metadatas, - kNumberOfDetectionsTensorName)}; - // locations, categories, scores, and number of detections - for (int i = 0; i < 4; i++) { - int output_index = output_indices[i]; - // If tensor name is not found, set the default output indices. - if (output_index == -1) { - LOG(WARNING) << absl::StrFormat( - "You don't seem to be matching tensor names in metadata list. The " - "tensor name \"%s\" at index %d in the model metadata doesn't " - "match " - "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].", - tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName, - kCategoryTensorName, kScoreTensorName, kNumberOfDetectionsTensorName); - output_indices = {kDefaultLocationsIndex, kDefaultCategoriesIndex, - kDefaultScoresIndex, kDefaultNumResultsIndex}; - return output_indices; - } - } - return output_indices; -} - -// Builds PostProcessingSpecs from ObjectDetectorOptionsProto and model metadata -// for configuring the post-processing calculators. -absl::StatusOr BuildPostProcessingSpecs( - const ObjectDetectorOptionsProto& options, - const ModelMetadataExtractor* metadata_extractor) { - // Checks output tensor metadata is present and consistent with model. - auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata(); - if (output_tensors_metadata == nullptr || - output_tensors_metadata->size() != 4) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat("Mismatch between number of output tensors (4) and " - "output tensors metadata (%d).", - output_tensors_metadata == nullptr - ? 
0 - : output_tensors_metadata->size()), - MediaPipeTasksStatus::kMetadataInconsistencyError); - } - PostProcessingSpecs specs; - specs.max_results = options.max_results(); - specs.output_tensor_indices = GetOutputTensorIndices(output_tensors_metadata); - // Extracts mandatory BoundingBoxProperties and performs sanity checks on the - // fly. - ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties, - GetBoundingBoxProperties(*output_tensors_metadata->Get( - specs.output_tensor_indices[0]))); - if (bounding_box_properties->index() == nullptr) { - specs.bounding_box_corners_order = {0, 1, 2, 3}; - } else { - auto bounding_box_index = bounding_box_properties->index(); - specs.bounding_box_corners_order = { - bounding_box_index->Get(0), - bounding_box_index->Get(1), - bounding_box_index->Get(2), - bounding_box_index->Get(3), - }; - } - // Builds label map (if available) from metadata. - ASSIGN_OR_RETURN(specs.label_items, - GetLabelItemsIfAny(*metadata_extractor, - *output_tensors_metadata->Get( - specs.output_tensor_indices[1]), - options.display_names_locale())); - // Obtains allow/deny categories. - specs.is_allowlist = !options.category_allowlist().empty(); - ASSIGN_OR_RETURN( - specs.allow_or_deny_categories, - GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items)); - // Sets score threshold. - if (options.has_score_threshold()) { - specs.score_threshold = options.score_threshold(); - } else { - ASSIGN_OR_RETURN(specs.score_threshold, - GetScoreThreshold(*metadata_extractor, - *output_tensors_metadata->Get( - specs.output_tensor_indices[2]))); - } - // Builds score calibration options (if available) from metadata. - ASSIGN_OR_RETURN( - specs.score_calibration_options, - GetScoreCalibrationOptionsIfAny( - *metadata_extractor, - *output_tensors_metadata->Get(specs.output_tensor_indices[2]))); - return specs; -} - -// Fills in the TensorsToDetectionsCalculatorOptions based on -// PostProcessingSpecs. -void ConfigureTensorsToDetectionsCalculator( - const PostProcessingSpecs& specs, - mediapipe::TensorsToDetectionsCalculatorOptions* options) { - options->set_num_classes(specs.label_items.size()); - options->set_num_coords(4); - options->set_min_score_thresh(specs.score_threshold); - if (specs.max_results != -1) { - options->set_max_results(specs.max_results); - } - if (specs.is_allowlist) { - options->mutable_allow_classes()->Assign( - specs.allow_or_deny_categories.begin(), - specs.allow_or_deny_categories.end()); - } else { - options->mutable_ignore_classes()->Assign( - specs.allow_or_deny_categories.begin(), - specs.allow_or_deny_categories.end()); - } - - const auto& output_indices = specs.output_tensor_indices; - // Assigns indices to each the model output tensor. - auto* tensor_mapping = options->mutable_tensor_mapping(); - tensor_mapping->set_detections_tensor_index(output_indices[0]); - tensor_mapping->set_classes_tensor_index(output_indices[1]); - tensor_mapping->set_scores_tensor_index(output_indices[2]); - tensor_mapping->set_num_detections_tensor_index(output_indices[3]); - - // Assigns the bounding box corner order. 
- auto box_boundaries_indices = options->mutable_box_boundaries_indices(); - box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]); - box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]); - box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]); - box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]); -} - } // namespace // A "mediapipe.tasks.vision.ObjectDetectorGraph" performs object detection. @@ -530,7 +164,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { const core::ModelResources& model_resources, Source image_in, Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); - // Checks that the model has 4 outputs. auto& model = *model_resources.GetTfLiteModel(); if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( @@ -539,13 +172,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { model.subgraphs()->size()), MediaPipeTasksStatus::kInvalidArgumentError); } - if (model.subgraphs()->Get(0)->outputs()->size() != 4) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrFormat("Expected a model with 4 output tensors, found %d.", - model.subgraphs()->Get(0)->outputs()->size()), - MediaPipeTasksStatus::kInvalidArgumentError); - } // Checks that metadata is available. auto* metadata_extractor = model_resources.GetMetadataExtractor(); if (metadata_extractor->GetModelMetadata() == nullptr || @@ -577,70 +203,36 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { auto& inference = AddInference( model_resources, task_options.base_options().acceleration(), graph); preprocessing.Out(kTensorTag) >> inference.In(kTensorTag); - - // Adds post processing calculators. - ASSIGN_OR_RETURN( - auto post_processing_specs, - BuildPostProcessingSpecs(task_options, metadata_extractor)); - // Calculators to perform score calibration, if specified in the metadata. - TensorsSource calibrated_tensors = + TensorsSource model_output_tensors = inference.Out(kTensorTag).Cast>(); - if (post_processing_specs.score_calibration_options.has_value()) { - // Split tensors. - auto* split_tensor_vector_node = - &graph.AddNode("SplitTensorVectorCalculator"); - auto& split_tensor_vector_options = - split_tensor_vector_node - ->GetOptions(); - for (int i = 0; i < 4; ++i) { - auto* range = split_tensor_vector_options.add_ranges(); - range->set_begin(i); - range->set_end(i + 1); - } - calibrated_tensors >> split_tensor_vector_node->In(0); - // Add score calibration calculator. - auto* score_calibration_node = - &graph.AddNode("ScoreCalibrationCalculator"); - score_calibration_node->GetOptions() - .CopyFrom(*post_processing_specs.score_calibration_options); - split_tensor_vector_node->Out( - post_processing_specs.output_tensor_indices[1]) >> - score_calibration_node->In(kIndicesTag); - split_tensor_vector_node->Out( - post_processing_specs.output_tensor_indices[2]) >> - score_calibration_node->In(kScoresTag); - - // Re-concatenate tensors. - auto* concatenate_tensor_vector_node = - &graph.AddNode("ConcatenateTensorVectorCalculator"); - for (int i = 0; i < 4; ++i) { - if (i == post_processing_specs.output_tensor_indices[2]) { - score_calibration_node->Out(kCalibratedScoresTag) >> - concatenate_tensor_vector_node->In(i); - } else { - split_tensor_vector_node->Out(i) >> - concatenate_tensor_vector_node->In(i); - } - } - calibrated_tensors = - concatenate_tensor_vector_node->Out(0).Cast>(); - } - // Calculator to convert output tensors to a detection proto vector. 
-    // Connects TensorsToDetectionsCalculator's input stream to the output
-    // tensors produced by the inference subgraph.
-    auto& tensors_to_detections =
-        graph.AddNode("TensorsToDetectionsCalculator");
-    ConfigureTensorsToDetectionsCalculator(
-        post_processing_specs,
-        &tensors_to_detections
-             .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
-    calibrated_tensors >> tensors_to_detections.In(kTensorTag);
+    // Add Detection postprocessing graph to convert tensors to detections.
+    auto& postprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
+    components::processors::proto::DetectorOptions detector_options;
+    detector_options.set_max_results(task_options.max_results());
+    detector_options.set_score_threshold(task_options.score_threshold());
+    detector_options.set_display_names_locale(
+        task_options.display_names_locale());
+    detector_options.mutable_category_allowlist()->CopyFrom(
+        task_options.category_allowlist());
+    detector_options.mutable_category_denylist()->CopyFrom(
+        task_options.category_denylist());
+    // TODO: expose min suppression threshold in
+    // ObjectDetectorOptions.
+    detector_options.set_min_suppression_threshold(0.3);
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureDetectionPostprocessingGraph(
+            model_resources, detector_options,
+            postprocessing
+                .GetOptions<components::processors::proto::
+                                DetectionPostprocessingGraphOptions>()));
+    model_output_tensors >> postprocessing.In(kTensorTag);
+    auto detections = postprocessing.Out(kDetectionsTag);
 
     // Calculator to projects detections back to the original coordinate system.
     auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
-    tensors_to_detections.Out(kDetectionsTag) >>
-        detection_projection.In(kDetectionsTag);
+    detections >> detection_projection.In(kDetectionsTag);
     preprocessing.Out(kMatrixTag) >>
         detection_projection.In(kProjectionMatrixTag);
@@ -652,22 +244,13 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
         detection_transformation.In(kDetectionsTag);
     preprocessing.Out(kImageSizeTag) >>
         detection_transformation.In(kImageSizeTag);
-
-    // Calculator to assign detection labels.
-    auto& detection_label_id_to_text =
-        graph.AddNode("DetectionLabelIdToTextCalculator");
-    auto& detection_label_id_to_text_opts =
-        detection_label_id_to_text
-            .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>();
-    *detection_label_id_to_text_opts.mutable_label_items() =
-        std::move(post_processing_specs.label_items);
-    detection_transformation.Out(kPixelDetectionsTag) >>
-        detection_label_id_to_text.In("");
+    auto detections_in_pixel =
+        detection_transformation.Out(kPixelDetectionsTag);
 
     // Deduplicate Detections with same bounding box coordinates.
     auto& detections_deduplicate =
         graph.AddNode("DetectionsDeduplicateCalculator");
-    detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
+    detections_in_pixel >> detections_deduplicate.In("");
 
     // Outputs the labeled detections and the processed image as the subgraph
     // output streams.
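
[Editor's note] Which fields of DetectionPostprocessingGraphOptions end up populated by the new wiring above depends on whether the TFLite model already runs non-maximum suppression through the built-in Detection_PostProcess op (4 output tensors) or exposes raw location/score tensors (2 output tensors), as described in the object_detector.h comment earlier in this patch. The sketch below only illustrates that split; it is not the real ConfigureDetectionPostprocessingGraph implementation. The namespace, the helper HasInModelNms, and the hard-coded values are assumptions made for the example; the actual code derives anchors, score calibration, and the label map from the model's TFLite metadata.

// Illustrative sketch only: shows which DetectionPostprocessingGraphOptions
// fields one would expect to see filled for each model flavor. HasInModelNms()
// is a hypothetical helper; the real configuration inspects model metadata.
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"

namespace example {

using ::mediapipe::tasks::components::processors::proto::
    DetectionPostprocessingGraphOptions;
using ::mediapipe::tasks::components::processors::proto::DetectorOptions;

bool HasInModelNms(int num_output_tensors) {
  // Models exporting the Detection_PostProcess op have 4 output tensors
  // (locations, categories, scores, num_detections); raw models have 2.
  return num_output_tensors == 4;
}

void SketchConfigure(const DetectorOptions& detector_options,
                     int num_output_tensors,
                     DetectionPostprocessingGraphOptions& out) {
  auto* t2d = out.mutable_tensors_to_detections_options();
  t2d->set_num_coords(4);
  t2d->set_min_score_thresh(detector_options.score_threshold());
  if (detector_options.max_results() > 0) {
    t2d->set_max_results(detector_options.max_results());
  }
  if (!HasInModelNms(num_output_tensors)) {
    // No in-model NMS: anchors must be generated and NMS applied in-graph.
    out.mutable_ssd_anchors_options();  // Filled from model metadata.
    out.mutable_non_max_suppression_options()->set_min_suppression_threshold(
        detector_options.min_suppression_threshold());
  }
  // Label map (when present in the metadata) is attached in both cases.
  out.mutable_detection_label_ids_to_text_options();
}

}  // namespace example

In the real graph, ConfigureDetectionPostprocessingGraph takes ModelResources precisely because most of this configuration (anchors, score calibration, labels, tensor order) comes from the model metadata rather than from caller-provided values.
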
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
index a4fed0f9e..8642af7c4 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
@@ -76,15 +76,18 @@ using ::testing::HasSubstr;
 using ::testing::Optional;
 using DetectionProto = mediapipe::Detection;
 
-constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
-constexpr char kMobileSsdWithMetadata[] =
-    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
-constexpr char kMobileSsdWithDummyScoreCalibration[] =
+constexpr absl::string_view kTestDataDirectory{
+    "/mediapipe/tasks/testdata/vision/"};
+constexpr absl::string_view kMobileSsdWithMetadata{
+    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"};
+constexpr absl::string_view kMobileSsdWithDummyScoreCalibration{
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration."
-    "tflite";
+    "tflite"};
 // The model has different output tensor order.
-constexpr char kEfficientDetWithMetadata[] =
-    "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
+constexpr absl::string_view kEfficientDetWithMetadata{
+    "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"};
+constexpr absl::string_view kEfficientDetWithoutNms{
+    "efficientdet_lite0_fp16_no_nms.tflite"};
 
 // Checks that the two provided `Detection` proto vectors are equal, with a
 // tolerancy on floating-point scores to account for numerical instabilities.
@@ -451,6 +454,51 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
               })pb")}));
 }
 
+TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) {
+  MP_ASSERT_OK_AND_ASSIGN(Image image,
+                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                       "cats_and_dogs.jpg")));
+  auto options = std::make_unique<ObjectDetectorOptions>();
+  options->max_results = 4;
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
+                          ObjectDetector::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
+  MP_ASSERT_OK(object_detector->Close());
+  ExpectApproximatelyEqual(
+      results,
+      ConvertToDetectionResult(
+          {ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "dog"
+             score: 0.733542
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 636 ymin: 160 width: 282 height: 451 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.699751
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 870 ymin: 411 width: 208 height: 187 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "dog"
+             score: 0.682425
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 386 ymin: 216 width: 256 height: 376 }
+             })pb"),
+           ParseTextProtoOrDie<DetectionProto>(R"pb(
+             label: "cat"
+             score: 0.646585
+             location_data {
+               format: BOUNDING_BOX
+               bounding_box { xmin: 83 ymin: 399 width: 347 height: 198 }
+             })pb")}));
+}
+
 TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
   MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
                                            "./", kTestDataDirectory,
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 649d8d452..a8704123c 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -41,6 +41,7 @@ mediapipe_files(srcs = [
     "conv2d_input_channel_1.tflite",
     "deeplabv3.tflite",
     "dense.tflite",
+    "efficientdet_lite0_fp16_no_nms.tflite",
     "face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite", "face_detection_short_range.tflite", @@ -167,6 +168,7 @@ filegroup( "conv2d_input_channel_1.tflite", "deeplabv3.tflite", "dense.tflite", + "efficientdet_lite0_fp16_no_nms.tflite", "face_detection_full_range.tflite", "face_detection_full_range_sparse.tflite", "face_detection_short_range.tflite", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index c47f0fbb6..6a1582cc7 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -276,8 +276,8 @@ def external_files(): http_file( name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_tflite", - sha256 = "bcda125c96d3767bca894c8cbe7bc458379c9974c9fd8bdc6204e7124a74082a", - urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682456096034465"], + sha256 = "237a58389081333e5cf4154e42b593ce7dd357445536fcaf4ca5bc51c2c50f1c", + urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682476299542472"], ) http_file(