DetectionPostProcessingGraph for post processing raw tensors from detection models.

PiperOrigin-RevId: 527363291
This commit is contained in:
MediaPipe Team 2023-04-26 13:42:54 -07:00 committed by Copybara-Service
parent 48aa88f39d
commit c44cc30ece
13 changed files with 1791 additions and 470 deletions

View File

@ -161,3 +161,49 @@ cc_library(
],
alwayslink = 1,
)
# Graph component converting raw detection-model output tensors into
# std::vector<Detection>, for models with or without in-model NMS
# (see detection_postprocessing_graph.h / .cc).
cc_library(
name = "detection_postprocessing_graph",
srcs = ["detection_postprocessing_graph.cc"],
hdrs = ["detection_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
"//mediapipe/calculators/util:non_max_suppression_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/formats/object_detection:anchor_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/tasks/metadata:object_detector_metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

View File

@ -0,0 +1,886 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/object_detector_metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::BoundingBoxProperties;
using ::tflite::ContentProperties;
using ::tflite::ContentProperties_BoundingBoxProperties;
using ::tflite::EnumNameContentProperties;
using ::tflite::ProcessUnit;
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
using TensorsSource =
mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;
// Default indices of the 4 output tensors of models with in-model NMS
// ([locations, categories, scores, num detections]); used as fallback when
// metadata tensor names cannot be matched (see GetOutputTensorIndices).
constexpr int kInModelNmsDefaultLocationsIndex = 0;
constexpr int kInModelNmsDefaultCategoriesIndex = 1;
constexpr int kInModelNmsDefaultScoresIndex = 2;
constexpr int kInModelNmsDefaultNumResultsIndex = 3;
// Default indices of the 2 output tensors of models without in-model NMS
// ([locations, scores]).
constexpr int kOutModelNmsDefaultLocationsIndex = 0;
constexpr int kOutModelNmsDefaultScoresIndex = 1;
// Score threshold used when neither the options nor the model metadata
// provide one: -FLT_MAX, i.e. no detection is filtered out by score.
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
// Canonical output tensor names looked up in the TFLite Model Metadata.
constexpr absl::string_view kLocationTensorName = "location";
constexpr absl::string_view kCategoryTensorName = "category";
constexpr absl::string_view kScoreTensorName = "score";
constexpr absl::string_view kNumberOfDetectionsTensorName =
"number of detections";
// Name of the custom metadata entry holding the ObjectDetector options.
constexpr absl::string_view kDetectorMetadataName = "DETECTOR_METADATA";
// Stream tags used when wiring the calculators in the graph.
constexpr absl::string_view kCalibratedScoresTag = "CALIBRATED_SCORES";
constexpr absl::string_view kDetectionsTag = "DETECTIONS";
constexpr absl::string_view kIndicesTag = "INDICES";
constexpr absl::string_view kScoresTag = "SCORES";
constexpr absl::string_view kTensorsTag = "TENSORS";
constexpr absl::string_view kAnchorsTag = "ANCHORS";
// Struct holding the different output streams produced by the graph.
struct DetectionPostprocessingOutputStreams {
// Stream of post-processed detection results.
Source<std::vector<Detection>> detections;
};
// Parameters used for configuring the post-processing calculators.
struct PostProcessingSpecs {
// The maximum number of detection results to return. A value of -1 leaves
// the limit unset (see ConfigureInModelNmsTensorsToDetectionsCalculator).
int max_results;
// Indices of the output tensors to match the output tensors to the correct
// index order of the output tensors: [location, categories, scores,
// num_detections].
std::vector<int> output_tensor_indices;
// For each pack of 4 coordinates returned by the model, this denotes the
// order in which to get the left, top, right and bottom coordinates.
std::vector<unsigned int> bounding_box_corners_order;
// This is populated by reading the label files from the TFLite Model
// Metadata: if no such files are available, this is left empty and the
// ObjectDetector will only be able to populate the `index` field of the
// detection results.
LabelItems label_items;
// Score threshold. Detections with a confidence below this value are
// discarded. If none is provided via metadata or options, -FLT_MAX is set as
// default value.
float score_threshold;
// Set of category indices to be allowed/denied.
absl::flat_hash_set<int> allow_or_deny_categories;
// Indicates `allow_or_deny_categories` is an allowlist or a denylist.
bool is_allowlist;
// Score calibration options, if any.
std::optional<ScoreCalibrationCalculatorOptions> score_calibration_options;
};
// Validates the user-provided DetectorOptions.
//
// Returns kInvalidArgument when `max_results` is 0 (callers must request
// either a positive count or -1), or when both a category allowlist and a
// category denylist are supplied, as the two filters are mutually exclusive.
absl::Status SanityCheckOptions(const proto::DetectorOptions& options) {
  const bool max_results_is_zero = options.max_results() == 0;
  if (max_results_is_zero) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid `max_results` option: value must be != 0",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  const bool has_allowlist = options.category_allowlist_size() > 0;
  const bool has_denylist = options.category_denylist_size() > 0;
  if (has_allowlist && has_denylist) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "`category_allowlist` and `category_denylist` are mutually "
        "exclusive options.",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  return absl::OkStatus();
}
// Fetches the BoundingBoxProperties attached to `tensor_metadata` and checks
// that they are compatible with Mobile SSD decoding:
// - content properties must be present and of type BoundingBoxProperties;
// - bounding box type must be BOUNDARIES;
// - coordinate type must be RATIO;
// - the optional `index` field, when set, must hold exactly 4 entries.
// Returns a non-owning pointer into the metadata flatbuffer on success.
absl::StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
    const TensorMetadata& tensor_metadata) {
  // Tensor name used in error messages; "#0" when the tensor is unnamed.
  const std::string tensor_name =
      tensor_metadata.name() ? tensor_metadata.name()->str() : "#0";
  const auto* content = tensor_metadata.content();
  if (content == nullptr || content->content_properties() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found none.",
            tensor_name),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  const ContentProperties properties_type = content->content_properties_type();
  if (properties_type != ContentProperties_BoundingBoxProperties) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found %s.",
            tensor_name, EnumNameContentProperties(properties_type)),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  const BoundingBoxProperties* box_properties =
      content->content_properties_as_BoundingBoxProperties();
  // Mobile SSD only supports "BOUNDARIES" bounding box type.
  if (box_properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
            tflite::EnumNameBoundingBoxType(box_properties->type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  // Mobile SSD only supports "RATIO" coordinates type.
  if (box_properties->coordinate_type() != tflite::CoordinateType_RATIO) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports CoordinateType RATIO, found %s",
            tflite::EnumNameCoordinateType(box_properties->coordinate_type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  // Index is optional, but must contain 4 values if present.
  const auto* index = box_properties->index();
  if (index != nullptr && index->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties index to contain 4 values, found "
            "%d",
            index->size()),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  return box_properties;
}
// Loads the label map for `tensor_metadata` from the associated file of the
// given `associated_file_type`, if any, attaching locale-specific display
// names when a matching display-names file exists. Returns an empty map (not
// an error) when the model carries no labels file at all.
absl::StatusOr<LabelItems> GetLabelItemsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata,
    tflite::AssociatedFileType associated_file_type, absl::string_view locale) {
  const std::string labels_file_name =
      ModelMetadataExtractor::FindFirstAssociatedFileName(tensor_metadata,
                                                          associated_file_type);
  // No labels file: signal "no labels available" with an empty map.
  if (labels_file_name.empty()) {
    return LabelItems();
  }
  ASSIGN_OR_RETURN(absl::string_view labels_file_content,
                   metadata_extractor.GetAssociatedFile(labels_file_name));
  // Optionally resolve localized display names for the requested locale.
  absl::string_view display_names_content;
  const std::string display_names_file_name =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, associated_file_type, locale);
  if (!display_names_file_name.empty()) {
    ASSIGN_OR_RETURN(
        display_names_content,
        metadata_extractor.GetAssociatedFile(display_names_file_name));
  }
  return mediapipe::BuildLabelMapFromFiles(labels_file_content,
                                           display_names_content);
}
// Reads the score threshold for `tensor_metadata` from its
// ScoreThresholdingOptions process unit, falling back to
// kDefaultScoreThreshold (-FLT_MAX) when the metadata defines none.
absl::StatusOr<float> GetScoreThreshold(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(
      const ProcessUnit* thresholding_unit,
      metadata_extractor.FindFirstProcessUnit(
          tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
  return thresholding_unit == nullptr
             ? kDefaultScoreThreshold
             : thresholding_unit->options_as_ScoreThresholdingOptions()
                   ->global_score_threshold();
}
// Translates the category allowlist (or denylist) from `config` into the set
// of corresponding label indices, using `label_items` for the name-to-index
// lookup. Unknown category names are silently skipped, and repeated names are
// naturally de-duplicated by the set. Returns an empty set when neither list
// is configured, and kInvalidArgument when a list is configured but the model
// metadata provides no labels to match against.
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
    const proto::DetectorOptions& config, const LabelItems& label_items) {
  absl::flat_hash_set<int> category_indices;
  const bool has_allowlist = config.category_allowlist_size() > 0;
  const bool has_denylist = config.category_denylist_size() > 0;
  // Exit early if no denylist/allowlist.
  if (!has_allowlist && !has_denylist) {
    return category_indices;
  }
  if (label_items.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Using `category_allowlist` or `category_denylist` requires "
        "labels to be present in the TFLite Model Metadata but none was found.",
        MediaPipeTasksStatus::kMetadataMissingLabelsError);
  }
  const auto& category_list =
      has_allowlist ? config.category_allowlist() : config.category_denylist();
  for (const auto& category_name : category_list) {
    // Linear scan over the label map; unknown names simply produce no insert.
    for (int index = 0; index < label_items.size(); ++index) {
      if (label_items.at(index).name() == category_name) {
        category_indices.insert(index);
        break;
      }
    }
  }
  return category_indices;
}
// Extracts score calibration parameters for `tensor_metadata`, if present.
// Returns std::nullopt when the metadata carries no ScoreCalibrationOptions
// process unit; returns kNotFound when such options exist but the mandatory
// TENSOR_AXIS_SCORE_CALIBRATION parameters file is missing.
absl::StatusOr<std::optional<ScoreCalibrationCalculatorOptions>>
GetScoreCalibrationOptionsIfAny(
const ModelMetadataExtractor& metadata_extractor,
const TensorMetadata& tensor_metadata) {
// Get ScoreCalibrationOptions, if any.
ASSIGN_OR_RETURN(
const ProcessUnit* score_calibration_process_unit,
metadata_extractor.FindFirstProcessUnit(
tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions));
if (score_calibration_process_unit == nullptr) {
return std::nullopt;
}
auto* score_calibration_options =
score_calibration_process_unit->options_as_ScoreCalibrationOptions();
// Get corresponding AssociatedFile.
auto score_calibration_filename =
metadata_extractor.FindFirstAssociatedFileName(
tensor_metadata,
tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
if (score_calibration_filename.empty()) {
return CreateStatusWithPayload(
absl::StatusCode::kNotFound,
"Found ScoreCalibrationOptions but missing required associated "
"parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
}
ASSIGN_OR_RETURN(
absl::string_view score_calibration_file,
metadata_extractor.GetAssociatedFile(score_calibration_filename));
// Translate the metadata's calibration parameters file into calculator
// options via the shared score calibration utility.
ScoreCalibrationCalculatorOptions score_calibration_calculator_options;
MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
score_calibration_options->score_transformation(),
score_calibration_options->default_score(), score_calibration_file,
&score_calibration_calculator_options));
return score_calibration_calculator_options;
}
// Maps the model's output tensors to the canonical order expected downstream,
// by matching tensor names from the metadata:
// - 4 tensors (in-model NMS): [location, category, score, num detections];
// - 2 tensors (out-of-model NMS): [location, score].
// If any metadata name fails to match, logs a warning and falls back to the
// default index order. Any other tensor count yields kInvalidArgument.
absl::StatusOr<std::vector<int>> GetOutputTensorIndices(
const Vector<Offset<TensorMetadata>>* tensor_metadatas) {
std::vector<int> output_indices;
if (tensor_metadatas->size() == 4) {
output_indices = {
core::FindTensorIndexByMetadataName(tensor_metadatas,
kLocationTensorName),
core::FindTensorIndexByMetadataName(tensor_metadatas,
kCategoryTensorName),
core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName),
core::FindTensorIndexByMetadataName(tensor_metadatas,
kNumberOfDetectionsTensorName)};
// locations, categories, scores, and number of detections
for (int i = 0; i < 4; i++) {
int output_index = output_indices[i];
// If tensor name is not found, set the default output indices.
if (output_index == -1) {
LOG(WARNING) << absl::StrFormat(
"You don't seem to be matching tensor names in metadata list. The "
"tensor name \"%s\" at index %d in the model metadata doesn't "
"match "
"the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].",
tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
kCategoryTensorName, kScoreTensorName,
kNumberOfDetectionsTensorName);
// One mismatch invalidates all name-based matching: fall back to the
// whole default ordering, not just this index.
output_indices = {
kInModelNmsDefaultLocationsIndex, kInModelNmsDefaultCategoriesIndex,
kInModelNmsDefaultScoresIndex, kInModelNmsDefaultNumResultsIndex};
return output_indices;
}
}
} else if (tensor_metadatas->size() == 2) {
output_indices = {core::FindTensorIndexByMetadataName(tensor_metadatas,
kLocationTensorName),
core::FindTensorIndexByMetadataName(tensor_metadatas,
kScoreTensorName)};
// location, score
for (int i = 0; i < 2; i++) {
int output_index = output_indices[i];
// If tensor name is not found, set the default output indices.
if (output_index == -1) {
LOG(WARNING) << absl::StrFormat(
"You don't seem to be matching tensor names in metadata list. The "
"tensor name \"%s\" at index %d in the model metadata doesn't "
"match "
"the available output names: [\"%s\", \"%s\"].",
tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
kScoreTensorName);
// Same fallback policy as the 4-tensor case above.
output_indices = {kOutModelNmsDefaultLocationsIndex,
kOutModelNmsDefaultScoresIndex};
return output_indices;
}
}
} else {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"Expected a model with 2 or 4 output tensors metadata, found %d.",
tensor_metadatas->size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
return output_indices;
}
// Builds PostProcessingSpecs from DetectorOptions and model metadata for
// configuring the post-processing calculators.
//
// `in_model_nms` selects between the 4-tensor (in-model NMS) and 2-tensor
// (out-of-model NMS) interpretations of the output metadata; callers are
// expected to have already validated the tensor/metadata count.
absl::StatusOr<PostProcessingSpecs> BuildPostProcessingSpecs(
const proto::DetectorOptions& options, bool in_model_nms,
const ModelMetadataExtractor* metadata_extractor) {
const auto* output_tensors_metadata =
metadata_extractor->GetOutputTensorMetadata();
PostProcessingSpecs specs;
specs.max_results = options.max_results();
ASSIGN_OR_RETURN(specs.output_tensor_indices,
GetOutputTensorIndices(output_tensors_metadata));
// Extracts mandatory BoundingBoxProperties and performs sanity checks on the
// fly.
ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
GetBoundingBoxProperties(*output_tensors_metadata->Get(
specs.output_tensor_indices[0])));
// When no explicit corner order is given, use the identity order.
if (bounding_box_properties->index() == nullptr) {
specs.bounding_box_corners_order = {0, 1, 2, 3};
} else {
auto bounding_box_index = bounding_box_properties->index();
specs.bounding_box_corners_order = {
bounding_box_index->Get(0),
bounding_box_index->Get(1),
bounding_box_index->Get(2),
bounding_box_index->Get(3),
};
}
// Builds label map (if available) from metadata.
// For models with in-model-nms, the label map is stored in the Category
// tensor which use TENSOR_VALUE_LABELS. For models with out-of-model-nms, the
// label map is stored in the Score tensor which use TENSOR_AXIS_LABELS.
ASSIGN_OR_RETURN(
specs.label_items,
GetLabelItemsIfAny(
*metadata_extractor,
*output_tensors_metadata->Get(specs.output_tensor_indices[1]),
in_model_nms ? tflite::AssociatedFileType_TENSOR_VALUE_LABELS
: tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
options.display_names_locale()));
// Obtains allow/deny categories.
specs.is_allowlist = !options.category_allowlist().empty();
ASSIGN_OR_RETURN(
specs.allow_or_deny_categories,
GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items));
// Sets score threshold: explicit options take precedence over metadata.
if (options.has_score_threshold()) {
specs.score_threshold = options.score_threshold();
} else {
ASSIGN_OR_RETURN(
specs.score_threshold,
GetScoreThreshold(
*metadata_extractor,
*output_tensors_metadata->Get(
specs.output_tensor_indices
[in_model_nms ? kInModelNmsDefaultScoresIndex
: kOutModelNmsDefaultScoresIndex])));
}
// Score calibration is only read for in-model NMS models here.
if (in_model_nms) {
// Builds score calibration options (if available) from metadata.
ASSIGN_OR_RETURN(
specs.score_calibration_options,
GetScoreCalibrationOptionsIfAny(
*metadata_extractor,
*output_tensors_metadata->Get(
specs.output_tensor_indices[kInModelNmsDefaultScoresIndex])));
}
return specs;
}
// Builds PostProcessingSpecs from DetectorOptions and model metadata for
// configuring the post-processing calculators for models with
// non-maximum-suppression.
absl::StatusOr<PostProcessingSpecs> BuildInModelNmsPostProcessingSpecs(
    const proto::DetectorOptions& options,
    const ModelMetadataExtractor* metadata_extractor) {
  // In-model NMS implies exactly 4 output tensors, so require 4 matching
  // tensor metadata entries before building the specs.
  const auto* output_tensors_metadata =
      metadata_extractor->GetOutputTensorMetadata();
  const int num_metadata =
      output_tensors_metadata == nullptr ? 0 : output_tensors_metadata->size();
  if (num_metadata != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (4) and "
                        "output tensors metadata (%d).",
                        num_metadata),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }
  return BuildPostProcessingSpecs(options, /*in_model_nms=*/true,
                                  metadata_extractor);
}
// Fills in the TensorsToDetectionsCalculatorOptions based on
// PostProcessingSpecs, for models with in-model NMS.
void ConfigureInModelNmsTensorsToDetectionsCalculator(
    const PostProcessingSpecs& specs,
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
  options->set_num_classes(specs.label_items.size());
  // In-model NMS outputs carry 4 coordinates per box.
  options->set_num_coords(4);
  options->set_min_score_thresh(specs.score_threshold);
  // -1 means "unlimited": leave the max_results field unset in that case.
  if (specs.max_results != -1) {
    options->set_max_results(specs.max_results);
  }
  // The same category set acts as an allowlist or a denylist depending on
  // which option was configured.
  auto* filtered_classes = specs.is_allowlist
                               ? options->mutable_allow_classes()
                               : options->mutable_ignore_classes();
  filtered_classes->Assign(specs.allow_or_deny_categories.begin(),
                           specs.allow_or_deny_categories.end());
  // Assigns indices to each the model output tensor.
  const auto& output_indices = specs.output_tensor_indices;
  auto* tensor_mapping = options->mutable_tensor_mapping();
  tensor_mapping->set_detections_tensor_index(output_indices[0]);
  tensor_mapping->set_classes_tensor_index(output_indices[1]);
  tensor_mapping->set_scores_tensor_index(output_indices[2]);
  tensor_mapping->set_num_detections_tensor_index(output_indices[3]);
  // Assigns the bounding box corner order.
  auto* box_boundaries_indices = options->mutable_box_boundaries_indices();
  box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]);
  box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]);
  box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]);
  box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]);
}
// Builds PostProcessingSpecs from DetectorOptions and model metadata for
// configuring the post-processing calculators for models without
// non-maximum-suppression.
absl::StatusOr<PostProcessingSpecs> BuildOutModelNmsPostProcessingSpecs(
    const proto::DetectorOptions& options,
    const ModelMetadataExtractor* metadata_extractor) {
  // Out-of-model NMS implies exactly 2 output tensors, so require 2 matching
  // tensor metadata entries before building the specs.
  const auto* output_tensors_metadata =
      metadata_extractor->GetOutputTensorMetadata();
  const int num_metadata =
      output_tensors_metadata == nullptr ? 0 : output_tensors_metadata->size();
  if (num_metadata != 2) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (2) and "
                        "output tensors metadata (%d).",
                        num_metadata),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }
  return BuildPostProcessingSpecs(options, /*in_model_nms=*/false,
                                  metadata_extractor);
}
// Configures the TensorsToDetectionCalculator for models without
// non-maximum-suppression in tflite model. The required config parameters are
// extracted from the ObjectDetectorMetadata
// (metadata/object_detector_metadata_schema.fbs).
//
// Returns kInvalidArgument when the custom metadata list does not contain a
// DETECTOR_METADATA entry with TensorsDecodingOptions.
absl::Status ConfigureOutModelNmsTensorsToDetectionsCalculator(
const ModelMetadataExtractor* metadata_extractor,
const PostProcessingSpecs& specs,
mediapipe::TensorsToDetectionsCalculatorOptions* options) {
bool found_detector_metadata = false;
if (metadata_extractor->GetCustomMetadataList() != nullptr &&
metadata_extractor->GetCustomMetadataList()->size() > 0) {
// Scan custom metadata entries for the detector-specific one.
for (const auto* custom_metadata :
*metadata_extractor->GetCustomMetadataList()) {
if (custom_metadata->name()->str() == kDetectorMetadataName) {
found_detector_metadata = true;
const auto* tensors_decoding_options =
GetObjectDetectorOptions(custom_metadata->data()->data())
->tensors_decoding_options();
// Here we don't set the max results for TensorsToDetectionsCalculator.
// For models without nms, the results are filtered by max_results in
// NonMaxSuppressionCalculator.
options->set_num_classes(tensors_decoding_options->num_classes());
options->set_num_boxes(tensors_decoding_options->num_boxes());
options->set_num_coords(tensors_decoding_options->num_coords());
options->set_keypoint_coord_offset(
tensors_decoding_options->keypoint_coord_offset());
options->set_num_keypoints(tensors_decoding_options->num_keypoints());
options->set_num_values_per_keypoint(
tensors_decoding_options->num_values_per_keypoint());
options->set_x_scale(tensors_decoding_options->x_scale());
options->set_y_scale(tensors_decoding_options->y_scale());
options->set_w_scale(tensors_decoding_options->w_scale());
options->set_h_scale(tensors_decoding_options->h_scale());
options->set_apply_exponential_on_box_size(
tensors_decoding_options->apply_exponential_on_box_size());
options->set_sigmoid_score(tensors_decoding_options->sigmoid_score());
break;
}
}
}
if (!found_detector_metadata) {
return absl::InvalidArgumentError(
"TensorsDecodingOptions is not found in the object detector "
"metadata.");
}
// Options not configured through metadata.
options->set_box_format(
mediapipe::TensorsToDetectionsCalculatorOptions::YXHW);
options->set_min_score_thresh(specs.score_threshold);
if (specs.is_allowlist) {
options->mutable_allow_classes()->Assign(
specs.allow_or_deny_categories.begin(),
specs.allow_or_deny_categories.end());
} else {
options->mutable_ignore_classes()->Assign(
specs.allow_or_deny_categories.begin(),
specs.allow_or_deny_categories.end());
}
const auto& output_indices = specs.output_tensor_indices;
// Assigns indices to each the model output tensor.
auto* tensor_mapping = options->mutable_tensor_mapping();
tensor_mapping->set_detections_tensor_index(output_indices[0]);
tensor_mapping->set_scores_tensor_index(output_indices[1]);
return absl::OkStatus();
}
// Configures the SsdAnchorsCalculator for models without
// non-maximum-suppression in tflite model. The required config parameters are
// extracted from the ObjectDetectorMetadata
// (metadata/object_detector_metadata_schema.fbs).
//
// Returns kInvalidArgument when the custom metadata list does not contain a
// DETECTOR_METADATA entry with SsdAnchorsOptions.
absl::Status ConfigureSsdAnchorsCalculator(
const ModelMetadataExtractor* metadata_extractor,
mediapipe::SsdAnchorsCalculatorOptions* options) {
bool found_detector_metadata = false;
if (metadata_extractor->GetCustomMetadataList() != nullptr &&
metadata_extractor->GetCustomMetadataList()->size() > 0) {
// Scan custom metadata entries for the detector-specific one.
for (const auto* custom_metadata :
*metadata_extractor->GetCustomMetadataList()) {
if (custom_metadata->name()->str() == kDetectorMetadataName) {
found_detector_metadata = true;
const auto* ssd_anchors_options =
GetObjectDetectorOptions(custom_metadata->data()->data())
->ssd_anchors_options();
// Copy the precomputed fixed anchors from the metadata schema into
// the calculator options.
for (const auto* ssd_anchor :
*ssd_anchors_options->fixed_anchors_schema()->anchors()) {
auto* fixed_anchor = options->add_fixed_anchors();
fixed_anchor->set_y_center(ssd_anchor->y_center());
fixed_anchor->set_x_center(ssd_anchor->x_center());
fixed_anchor->set_h(ssd_anchor->height());
fixed_anchor->set_w(ssd_anchor->width());
}
break;
}
}
}
if (!found_detector_metadata) {
return absl::InvalidArgumentError(
"SsdAnchorsOptions is not found in the object detector "
"metadata.");
}
return absl::OkStatus();
}
// Sets the default IoU-based non-maximum-suppression configs, and sets the
// min_suppression_threshold and max_results for detection models without
// in-model non-maximum-suppression.
void ConfigureNonMaxSuppressionCalculator(
    const proto::DetectorOptions& detector_options,
    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
  // Fixed NMS flavor: IoU overlap with the default suppression algorithm.
  options->set_overlap_type(
      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
  options->set_algorithm(
      mediapipe::NonMaxSuppressionCalculatorOptions::DEFAULT);
  // User-tunable parameters forwarded from the detector options.
  options->set_min_suppression_threshold(
      detector_options.min_suppression_threshold());
  options->set_max_num_detections(detector_options.max_results());
}
// Sets the label map on the DetectionLabelIdToTextCalculator options.
// Note: `specs.label_items` is consumed (moved from) by this call.
void ConfigureDetectionLabelIdToTextCalculator(
    PostProcessingSpecs& specs,
    mediapipe::DetectionLabelIdToTextCalculatorOptions* options) {
  auto* label_items = options->mutable_label_items();
  *label_items = std::move(specs.label_items);
}
// Splits the vector of 4 output tensors from model inference and calibrate the
// score tensors according to the metadata, if any. Then concatenate the tensors
// back to a vector of 4 tensors.
//
// Adds three nodes to `graph` (split, score calibration, concatenate) and
// returns the stream carrying the re-assembled 4-tensor vector.
absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
Source<std::vector<Tensor>> model_output_tensors,
const proto::DetectionPostprocessingGraphOptions& options, Graph& graph) {
// Split tensors: one single-element range per output tensor, so each can be
// wired independently below.
auto* split_tensor_vector_node =
&graph.AddNode("SplitTensorVectorCalculator");
auto& split_tensor_vector_options =
split_tensor_vector_node
->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
for (int i = 0; i < 4; ++i) {
auto* range = split_tensor_vector_options.add_ranges();
range->set_begin(i);
range->set_end(i + 1);
}
model_output_tensors >> split_tensor_vector_node->In(0);
// Add score calibration calculator: consumes the class-indices tensor and
// the raw scores tensor, produces calibrated scores.
auto* score_calibration_node = &graph.AddNode("ScoreCalibrationCalculator");
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
.CopyFrom(options.score_calibration_options());
const auto& tensor_mapping =
options.tensors_to_detections_options().tensor_mapping();
split_tensor_vector_node->Out(tensor_mapping.classes_tensor_index()) >>
score_calibration_node->In(kIndicesTag);
split_tensor_vector_node->Out(tensor_mapping.scores_tensor_index()) >>
score_calibration_node->In(kScoresTag);
// Re-concatenate tensors, substituting the calibrated scores for the raw
// scores tensor while preserving the original tensor order.
auto* concatenate_tensor_vector_node =
&graph.AddNode("ConcatenateTensorVectorCalculator");
for (int i = 0; i < 4; ++i) {
if (i == tensor_mapping.scores_tensor_index()) {
score_calibration_node->Out(kCalibratedScoresTag) >>
concatenate_tensor_vector_node->In(i);
} else {
split_tensor_vector_node->Out(i) >> concatenate_tensor_vector_node->In(i);
}
}
model_output_tensors =
concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
return model_output_tensors;
}
} // namespace
// Fills `options` with the calculator configurations needed by the
// DetectionPostprocessingGraph, derived from the model resources and the
// user-facing detector options. Dispatches between the in-model-NMS
// (4 output tensors) and out-of-model-NMS (2 output tensors) paths based on
// the model's output tensor count; any other count is kInvalidArgument.
absl::Status ConfigureDetectionPostprocessingGraph(
const tasks::core::ModelResources& model_resources,
const proto::DetectorOptions& detector_options,
proto::DetectionPostprocessingGraphOptions& options) {
MP_RETURN_IF_ERROR(SanityCheckOptions(detector_options));
const auto& model = *model_resources.GetTfLiteModel();
bool in_model_nms = false;
// Only single-subgraph models are supported.
if (model.subgraphs()->size() != 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Expected a model with a single subgraph, found %d.",
model.subgraphs()->size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
// 2 outputs => NMS happens in the graph; 4 outputs => NMS is in the model.
if (model.subgraphs()->Get(0)->outputs()->size() == 2) {
in_model_nms = false;
} else if (model.subgraphs()->Get(0)->outputs()->size() == 4) {
in_model_nms = true;
} else {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"Expected a model with 2 or 4 output tensors, found %d.",
model.subgraphs()->Get(0)->outputs()->size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
const ModelMetadataExtractor* metadata_extractor =
model_resources.GetMetadataExtractor();
if (in_model_nms) {
// In-model NMS: tensors-to-detections + label mapping, plus optional
// score calibration read from the metadata.
ASSIGN_OR_RETURN(auto post_processing_specs,
BuildInModelNmsPostProcessingSpecs(detector_options,
metadata_extractor));
ConfigureInModelNmsTensorsToDetectionsCalculator(
post_processing_specs, options.mutable_tensors_to_detections_options());
ConfigureDetectionLabelIdToTextCalculator(
post_processing_specs,
options.mutable_detection_label_ids_to_text_options());
if (post_processing_specs.score_calibration_options.has_value()) {
*options.mutable_score_calibration_options() =
std::move(*post_processing_specs.score_calibration_options);
}
} else {
// Out-of-model NMS: additionally needs SSD anchors and an NMS calculator.
ASSIGN_OR_RETURN(auto post_processing_specs,
BuildOutModelNmsPostProcessingSpecs(detector_options,
metadata_extractor));
MP_RETURN_IF_ERROR(ConfigureOutModelNmsTensorsToDetectionsCalculator(
metadata_extractor, post_processing_specs,
options.mutable_tensors_to_detections_options()));
MP_RETURN_IF_ERROR(ConfigureSsdAnchorsCalculator(
metadata_extractor, options.mutable_ssd_anchors_options()));
ConfigureNonMaxSuppressionCalculator(
detector_options, options.mutable_non_max_suppression_options());
ConfigureDetectionLabelIdToTextCalculator(
post_processing_specs,
options.mutable_detection_label_ids_to_text_options());
}
return absl::OkStatus();
}
// A DetectionPostprocessingGraph converts raw tensors into
// std::vector<Detection>.
//
// Inputs:
// TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator. The tensors vector could be
// size 4 or size 2. Tensors vector of size 4 expects the tensors from the
// models with DETECTION_POSTPROCESS ops in the tflite graph. Tensors vector
// of size 2 expects the tensors from the models without the ops.
// [1]:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
// Outputs:
// DETECTIONS - std::vector<Detection>
// The postprocessed detection results.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureDetectionPostprocessingGraph()' function. See header
// file for more details.
class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 public:
  absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
      mediapipe::SubgraphContext* sc) override {
    Graph graph;
    auto& graph_options =
        *sc->MutableOptions<proto::DetectionPostprocessingGraphOptions>();
    auto tensors_in = graph.In(kTensorsTag).Cast<std::vector<Tensor>>();
    ASSIGN_OR_RETURN(auto output_streams, BuildDetectionPostprocessing(
                                              graph_options, tensors_in, graph));
    output_streams.detections >>
        graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
    return graph.GetConfig();
  }

 private:
  // Adds an on-device detection postprocessing graph into the provided
  // builder::Graph instance. Takes raw tensors
  // (std::vector<mediapipe::Tensor>) as input and returns a single stream of
  // postprocessed detections (std::vector<Detection>).
  //
  // graph_options: the on-device DetectionPostprocessingGraphOptions. Its
  //   sub-options are consumed (moved out via Swap) into the added nodes.
  // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<DetectionPostprocessingOutputStreams>
  BuildDetectionPostprocessing(
      proto::DetectionPostprocessingGraphOptions& graph_options,
      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
    std::optional<Source<std::vector<Detection>>> raw_detections;
    if (graph_options.has_non_max_suppression_options()) {
      // Out-of-model NMS pipeline: anchors -> tensor conversion -> NMS.
      // Single side packet holding the vector of SSD anchors.
      auto& anchors_node = graph.AddNode("SsdAnchorsCalculator");
      anchors_node.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>().Swap(
          graph_options.mutable_ssd_anchors_options());
      auto anchors =
          anchors_node.SideOut("").Cast<std::vector<mediapipe::Anchor>>();
      // Decode raw output tensors into detection protos.
      auto& to_detections_node = graph.AddNode("TensorsToDetectionsCalculator");
      to_detections_node
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
          .Swap(graph_options.mutable_tensors_to_detections_options());
      anchors >> to_detections_node.SideIn(kAnchorsTag);
      tensors_in >> to_detections_node.In(kTensorsTag);
      raw_detections = to_detections_node.Out(kDetectionsTag)
                           .Cast<std::vector<mediapipe::Detection>>();
      // Drop redundant overlapping detections.
      auto& nms_node = graph.AddNode("NonMaxSuppressionCalculator");
      nms_node.GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>().Swap(
          graph_options.mutable_non_max_suppression_options());
      *raw_detections >> nms_node.In("");
      raw_detections = nms_node.Out("").Cast<std::vector<Detection>>();
    } else {
      // In-model NMS pipeline: optional score calibration, then conversion.
      if (graph_options.has_score_calibration_options()) {
        ASSIGN_OR_RETURN(tensors_in,
                         CalibrateScores(tensors_in, graph_options, graph));
      }
      auto& to_detections_node = graph.AddNode("TensorsToDetectionsCalculator");
      to_detections_node
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
          .Swap(graph_options.mutable_tensors_to_detections_options());
      tensors_in >> to_detections_node.In(kTensorsTag);
      raw_detections =
          to_detections_node.Out(kDetectionsTag).Cast<std::vector<Detection>>();
    }
    // Map numeric label ids to human-readable label text.
    auto& label_text_node = graph.AddNode("DetectionLabelIdToTextCalculator");
    label_text_node
        .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>()
        .Swap(graph_options.mutable_detection_label_ids_to_text_options());
    *raw_detections >> label_text_node.In("");
    return {{label_text_node.Out("").Cast<std::vector<Detection>>()}};
  }
};
// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::processors::DetectionPostprocessingGraph); // NOLINT
// clang-format on
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,62 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_

// Includes live inside the guard, with absl/status supplying absl::Status.
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {

// Configures a DetectionPostprocessingGraph using the provided model
// resources and DetectorOptions.
//
// Example usage:
//
//   auto& postprocessing =
//       graph.AddNode("mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
//   MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
//       model_resources,
//       detector_options,
//       postprocessing.GetOptions<DetectionPostprocessingGraphOptions>()));
//
// The resulting DetectionPostprocessingGraph has the following I/O:
// Inputs:
//   TENSORS - std::vector<Tensor>
//     The output tensors of an InferenceCalculator. The tensors vector could be
//     size 4 or size 2. Tensors vector of size 4 expects the tensors from the
//     models with DETECTION_POSTPROCESS ops in the tflite graph. Tensors vector
//     of size 2 expects the tensors from the models without the ops.
//     [1]:
//     https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
// Outputs:
//   DETECTIONS - std::vector<Detection>
//     The postprocessed detection results.
absl::Status ConfigureDetectionPostprocessingGraph(
    const tasks::core::ModelResources& model_resources,
    const proto::DetectorOptions& detector_options,
    proto::DetectionPostprocessingGraphOptions& options);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_

View File

@ -0,0 +1,570 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
#include <vector>
#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/graph_runner.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/test_util.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::ModelResources;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::Pointwise;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr absl::string_view kTestDataDirectory =
"/mediapipe/tasks/testdata/vision";
constexpr absl::string_view kMobileSsdWithMetadata =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
constexpr absl::string_view kMobileSsdWithDummyScoreCalibration =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration."
"tflite";
constexpr absl::string_view kEfficientDetWithoutNms =
"efficientdet_lite0_fp16_no_nms.tflite";
constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr absl::string_view kTensorsTag = "TENSORS";
constexpr absl::string_view kDetectionsTag = "DETECTIONS";
constexpr absl::string_view kTensorsName = "tensors";
constexpr absl::string_view kDetectionsName = "detections";
// Helper function to get ModelResources.
// Builds ModelResources for the given test model file name, resolving the
// path relative to the test data directory.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
    absl::string_view model_name) {
  auto model_file = std::make_unique<core::proto::ExternalFile>();
  model_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name));
  return ModelResources::Create(kTestModelResourcesTag, std::move(model_file));
}
class ConfigureTest : public tflite::testing::Test {};
// max_results == 0 is documented as invalid (DetectorOptions: "If 0, an
// invalid argument error is returned") and must be rejected.
TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.set_max_results(0);
  proto::DetectionPostprocessingGraphOptions options_out;
  auto status = ConfigureDetectionPostprocessingGraph(*model_resources,
                                                      options_in, options_out);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
}
// category_allowlist and category_denylist are mutually exclusive; setting
// both must be rejected.
TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.add_category_allowlist("foo");
  options_in.add_category_denylist("bar");
  proto::DetectionPostprocessingGraphOptions options_out;
  auto status = ConfigureDetectionPostprocessingGraph(*model_resources,
                                                      options_in, options_out);
  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
}
// In-model NMS path: max_results is forwarded into
// tensors_to_detections_options, with the 4-tensor mapping read from metadata.
TEST_F(ConfigureTest, SucceedsWithMaxResults) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.set_max_results(3);
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  EXPECT_THAT(
      options_out,
      Approximately(Partially(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 max_results: 3
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb"))));
}
// Out-of-model NMS path (2-output model): max_results lands in
// non_max_suppression_options, and label items are loaded from metadata.
TEST_F(ConfigureTest, SucceedsWithMaxResultsWithoutModelNms) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources, CreateModelResourcesForModel(
                                                    kEfficientDetWithoutNms));
  proto::DetectorOptions options_in;
  options_in.set_max_results(3);
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  EXPECT_THAT(options_out, Approximately(Partially(EqualsProto(
                               R"pb(tensors_to_detections_options {
                                      min_score_thresh: -3.4028235e+38
                                      num_classes: 90
                                      num_boxes: 19206
                                      num_coords: 4
                                      x_scale: 1
                                      y_scale: 1
                                      w_scale: 1
                                      h_scale: 1
                                      keypoint_coord_offset: 0
                                      num_keypoints: 0
                                      num_values_per_keypoint: 2
                                      apply_exponential_on_box_size: true
                                      sigmoid_score: false
                                      tensor_mapping {
                                        detections_tensor_index: 1
                                        scores_tensor_index: 0
                                      }
                                      box_format: YXHW
                                    }
                                    non_max_suppression_options {
                                      max_num_detections: 3
                                      min_suppression_threshold: 0
                                      overlap_type: INTERSECTION_OVER_UNION
                                      algorithm: DEFAULT
                                    }
                               )pb"))));
  // label_items_size() is a plain int: EXPECT_EQ is the idiomatic assertion
  // here (EXPECT_THAT with a bare value only works via implicit Eq matcher).
  EXPECT_EQ(
      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
}
// score_threshold from DetectorOptions overrides the metadata default and is
// forwarded as min_score_thresh.
TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.set_score_threshold(0.5);
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  EXPECT_THAT(
      options_out,
      Approximately(Partially(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: 0.5
                 num_classes: 90
                 num_coords: 4
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb"))));
  // label_items_size() is a plain int: EXPECT_EQ is the idiomatic assertion
  // here (EXPECT_THAT with a bare value only works via implicit Eq matcher).
  EXPECT_EQ(
      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
}
// A category allowlist entry ("bicycle") is translated into its numeric label
// id (allow_classes: 1) via the metadata label map.
TEST_F(ConfigureTest, SucceedsWithAllowlist) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.add_category_allowlist("bicycle");
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  // Clear labels ids to text and compare the rest of the options.
  options_out.clear_detection_label_ids_to_text_options();
  EXPECT_THAT(
      options_out,
      Approximately(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 allow_classes: 1
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb")));
}
// A category denylist entry ("person") is translated into its numeric label
// id (ignore_classes: 0) via the metadata label map.
TEST_F(ConfigureTest, SucceedsWithDenylist) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.add_category_denylist("person");
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  // Clear labels ids to text and compare the rest of the options.
  options_out.clear_detection_label_ids_to_text_options();
  EXPECT_THAT(
      options_out,
      Approximately(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 ignore_classes: 0
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb")));
}
// A model whose metadata carries score calibration populates
// score_calibration_options (89 sigmoids here, one per calibrated class).
TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
  MP_ASSERT_OK_AND_ASSIGN(
      auto model_resources,
      CreateModelResourcesForModel(kMobileSsdWithDummyScoreCalibration));
  proto::DetectorOptions options_in;
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  // Clear labels ids to text.
  options_out.clear_detection_label_ids_to_text_options();
  // Check sigmoids size and first element.
  ASSERT_EQ(options_out.score_calibration_options().sigmoids_size(), 89);
  EXPECT_THAT(options_out.score_calibration_options().sigmoids()[0],
              EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb"));
  options_out.mutable_score_calibration_options()->clear_sigmoids();
  // Compare the rest of the option.
  EXPECT_THAT(
      options_out,
      Approximately(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
               score_calibration_options {
                 score_transformation: IDENTITY
                 default_score: 0.5
               }
          )pb")));
}
// Fixture that runs a full DetectionPostprocessingGraph inside a
// CalculatorGraph, feeding hand-built tensors and polling the DETECTIONS
// output stream.
class PostprocessingTest : public tflite::testing::Test {
 protected:
  // Builds, configures, and starts a graph containing a single
  // DetectionPostprocessingGraph node for `model_name`, returning a poller on
  // its DETECTIONS output. The graph keeps running until GetResult() is
  // called.
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      absl::string_view model_name, const proto::DetectorOptions& options) {
    ASSIGN_OR_RETURN(auto model_resources,
                     CreateModelResourcesForModel(model_name));
    Graph graph;
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "DetectionPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
        *model_resources, options,
        postprocessing
            .GetOptions<proto::DetectionPostprocessingGraphOptions>()));
    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(
        std::string(kTensorsName)) >>
        postprocessing.In(kTensorsTag);
    postprocessing.Out(kDetectionsTag).SetName(std::string(kDetectionsName)) >>
        graph[Output<std::vector<Detection>>(kDetectionsTag)];
    // The poller must be created before StartRun.
    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                      std::string(kDetectionsName)));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  // Appends one tensor with the given values, element type, and shape to the
  // pending input vector for the next Run() call.
  template <typename T>
  void AddTensor(const std::vector<T>& tensor,
                 const Tensor::ElementType& element_type,
                 const Tensor::Shape& shape) {
    tensors_->emplace_back(element_type, shape);
    auto view = tensors_->back().GetCpuWriteView();
    T* buffer = view.buffer<T>();
    std::copy(tensor.begin(), tensor.end(), buffer);
  }

  // Sends the pending tensors into the graph at `timestamp` and resets the
  // pending vector for subsequent calls.
  absl::Status Run(int timestamp = 0) {
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        std::string(kTensorsName),
        Adopt(tensors_.release()).At(Timestamp(timestamp))));
    // Reset tensors for future calls.
    tensors_ = absl::make_unique<std::vector<Tensor>>();
    return absl::OkStatus();
  }

  // Drains the graph: waits for idle, closes inputs, returns the single
  // output packet's payload, then waits for the graph to finish.
  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    auto result = packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  CalculatorGraph calculator_graph_;
  // Pending input tensors accumulated by AddTensor() until Run() is called.
  std::unique_ptr<std::vector<Tensor>> tensors_ =
      absl::make_unique<std::vector<Tensor>>();
};
// End-to-end in-model-NMS run: feeds the four Detection_Postprocess output
// tensors (locations, categories, scores, num_detections) and expects the
// top-3 labeled detections back.
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
  // Build graph.
  proto::DetectorOptions options;
  options.set_max_results(3);
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(kMobileSsdWithMetadata, options));
  // Build input tensors.
  constexpr int kBboxesNum = 5;
  // Location tensor: identical (ymin, xmin, ymax, xmax) box for every entry.
  std::vector<float> location_tensor(kBboxesNum * 4, 0);
  for (int i = 0; i < kBboxesNum; ++i) {
    location_tensor[i * 4] = 0.1f;
    location_tensor[i * 4 + 1] = 0.1f;
    location_tensor[i * 4 + 2] = 0.4f;
    location_tensor[i * 4 + 3] = 0.5f;
  }
  // Category tensor: class ids 1..kBboxesNum (0 is "person" in this model).
  std::vector<float> category_tensor(kBboxesNum, 0);
  for (int i = 0; i < kBboxesNum; ++i) {
    category_tensor[i] = i + 1;
  }
  // Score tensor. Post processed tensor scores are in descending order.
  std::vector<float> score_tensor(kBboxesNum, 0);
  for (int i = 0; i < kBboxesNum; ++i) {
    score_tensor[i] = static_cast<float>(kBboxesNum - i) / kBboxesNum;
  }
  // Number of detections tensor.
  std::vector<float> num_detections_tensor(1, 0);
  num_detections_tensor[0] = kBboxesNum;
  // Send tensors and get results.
  AddTensor(location_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 4});
  AddTensor(category_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum});
  AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum});
  AddTensor(num_detections_tensor, Tensor::ElementType::kFloat32, {1});
  MP_ASSERT_OK(Run());
  // Validate results: top-3 by score, with label ids mapped to text.
  EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
              IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
                                           R"pb(
                                             label: "bicycle"
                                             score: 1
                                             location_data {
                                               format: RELATIVE_BOUNDING_BOX
                                               relative_bounding_box {
                                                 xmin: 0.1
                                                 ymin: 0.1
                                                 width: 0.4
                                                 height: 0.3
                                               }
                                             }
                                           )pb")),
                                       Approximately(EqualsProto(
                                           R"pb(
                                             label: "car"
                                             score: 0.8
                                             location_data {
                                               format: RELATIVE_BOUNDING_BOX
                                               relative_bounding_box {
                                                 xmin: 0.1
                                                 ymin: 0.1
                                                 width: 0.4
                                                 height: 0.3
                                               }
                                             }
                                           )pb")),
                                       Approximately(EqualsProto(
                                           R"pb(
                                             label: "motorcycle"
                                             score: 0.6
                                             location_data {
                                               format: RELATIVE_BOUNDING_BOX
                                               relative_bounding_box {
                                                 xmin: 0.1
                                                 ymin: 0.1
                                                 width: 0.4
                                                 height: 0.3
                                               }
                                             }
                                           )pb")))));
}
// End-to-end out-of-model-NMS run: feeds raw score and location tensors for
// all 19206 anchors, with three boxes boosted, and expects those three to
// survive anchor decoding + NMS.
TEST_F(PostprocessingTest, SucceedsWithOutModelNms) {
  // Build graph.
  proto::DetectorOptions options;
  options.set_max_results(3);
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(kEfficientDetWithoutNms, options));
  // Build input tensors.
  constexpr int kBboxesNum = 19206;
  constexpr int kBicycleBboxIdx = 1000;
  constexpr int kCarBboxIdx = 2000;
  constexpr int kMotoCycleBboxIdx = 4000;
  // Location tensor: tiny background boxes everywhere by default.
  std::vector<float> location_tensor(kBboxesNum * 4, 0);
  for (int i = 0; i < kBboxesNum; ++i) {
    location_tensor[i * 4] = 0.5f;
    location_tensor[i * 4 + 1] = 0.5f;
    location_tensor[i * 4 + 2] = 0.001f;
    location_tensor[i * 4 + 3] = 0.001f;
  }
  // Detected three objects.
  location_tensor[kBicycleBboxIdx * 4] = 0.7f;
  location_tensor[kBicycleBboxIdx * 4 + 1] = 0.8f;
  location_tensor[kBicycleBboxIdx * 4 + 2] = 0.2f;
  location_tensor[kBicycleBboxIdx * 4 + 3] = 0.1f;
  location_tensor[kCarBboxIdx * 4] = 0.1f;
  location_tensor[kCarBboxIdx * 4 + 1] = 0.1f;
  location_tensor[kCarBboxIdx * 4 + 2] = 0.1f;
  location_tensor[kCarBboxIdx * 4 + 3] = 0.1f;
  location_tensor[kMotoCycleBboxIdx * 4] = 0.2f;
  location_tensor[kMotoCycleBboxIdx * 4 + 1] = 0.8f;
  location_tensor[kMotoCycleBboxIdx * 4 + 2] = 0.1f;
  location_tensor[kMotoCycleBboxIdx * 4 + 3] = 0.2f;
  // Score tensor: uniform low scores everywhere by default.
  constexpr int kClassesNum = 90;
  std::vector<float> score_tensor(kBboxesNum * kClassesNum, 1.f / kClassesNum);
  // Detected three objects.
  score_tensor[kBicycleBboxIdx * kClassesNum + 1] = 1.0f;    // bicycle.
  score_tensor[kCarBboxIdx * kClassesNum + 2] = 0.9f;        // car.
  score_tensor[kMotoCycleBboxIdx * kClassesNum + 3] = 0.8f;  // motorcycle.
  // Send tensors and get results. This model outputs scores first (see
  // tensor_mapping: scores_tensor_index: 0, detections_tensor_index: 1).
  AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 90});
  AddTensor(location_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 4});
  MP_ASSERT_OK(Run());
  // Validate results: boxes are decoded relative to the SSD anchors, hence
  // the non-round coordinates.
  EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
              IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
                                           R"pb(
                                             label: "bicycle"
                                             score: 1
                                             location_data {
                                               format: RELATIVE_BOUNDING_BOX
                                               relative_bounding_box {
                                                 xmin: 0.8137423
                                                 ymin: 0.067235775
                                                 width: 0.117221
                                                 height: 0.064774655
                                               }
                                             }
                                           )pb")),
                                       Approximately(EqualsProto(
                                           R"pb(
                                             label: "car"
                                             score: 0.9
                                             location_data {
                                               format: RELATIVE_BOUNDING_BOX
                                               relative_bounding_box {
                                                 xmin: 0.53849804
                                                 ymin: 0.08949606
                                                 width: 0.05861056
                                                 height: 0.11722109
                                               }
                                             }
                                           )pb")),
                                       Approximately(EqualsProto(
                                           R"pb(
                                             label: "motorcycle"
                                             score: 0.8
                                             location_data {
                                               format: RELATIVE_BOUNDING_BOX
                                               relative_bounding_box {
                                                 xmin: 0.13779688
                                                 ymin: 0.26394117
                                                 width: 0.16322193
                                                 height: 0.07384467
                                               }
                                             }
                                           )pb")))));
}
} // namespace
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -23,6 +23,11 @@ mediapipe_proto_library(
srcs = ["classifier_options.proto"],
)
mediapipe_proto_library(
name = "detector_options_proto",
srcs = ["detector_options.proto"],
)
mediapipe_proto_library(
name = "classification_postprocessing_graph_options_proto",
srcs = ["classification_postprocessing_graph_options.proto"],
@ -35,6 +40,20 @@ mediapipe_proto_library(
],
)
mediapipe_proto_library(
name = "detection_postprocessing_graph_options_proto",
srcs = ["detection_postprocessing_graph_options.proto"],
deps = [
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_proto",
"//mediapipe/calculators/tflite:ssd_anchors_calculator_proto",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto",
"//mediapipe/calculators/util:non_max_suppression_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
],
)
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],

View File

@ -0,0 +1,49 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package mediapipe.tasks.components.processors.proto;
import "mediapipe/calculators/tensor/tensors_to_detections_calculator.proto";
import "mediapipe/calculators/tflite/ssd_anchors_calculator.proto";
import "mediapipe/calculators/util/detection_label_id_to_text_calculator.proto";
import "mediapipe/calculators/util/non_max_suppression_calculator.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
// Options for DetectionPostprocessingGraph. Intended to be populated by
// ConfigureDetectionPostprocessingGraph() from the model (metadata) and the
// user-provided DetectorOptions, rather than filled in by hand.
message DetectionPostprocessingGraphOptions {
  // Optional SsdAnchorsCalculatorOptions for models without
  // non-maximum-suppression in tflite model graph.
  optional mediapipe.SsdAnchorsCalculatorOptions ssd_anchors_options = 1;

  // Optional TensorsToDetectionsCalculatorOptions for models without
  // non-maximum-suppression in tflite model graph.
  optional mediapipe.TensorsToDetectionsCalculatorOptions
      tensors_to_detections_options = 2;

  // Optional NonMaxSuppressionCalculatorOptions for models without
  // non-maximum-suppression in tflite model graph. Its presence selects the
  // out-of-model NMS pipeline.
  optional mediapipe.NonMaxSuppressionCalculatorOptions
      non_max_suppression_options = 3;

  // Optional score calibration options for models with non-maximum-suppression
  // in tflite model graph.
  optional ScoreCalibrationCalculatorOptions score_calibration_options = 4;

  // Optional detection label id to text calculator options.
  optional mediapipe.DetectionLabelIdToTextCalculatorOptions
      detection_label_ids_to_text_options = 5;
}

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "DetectorOptionsProto";
// Shared options used by all detection tasks.
// Shared options used by all detection tasks.
message DetectorOptions {
  // The locale to use for display names specified through the TFLite Model
  // Metadata, if any. Defaults to English.
  optional string display_names_locale = 1 [default = "en"];

  // The maximum number of top-scored detection results to return. If < 0,
  // all available results will be returned. If 0, an invalid argument error is
  // returned.
  optional int32 max_results = 2 [default = -1];

  // Score threshold, overrides the ones provided in the model metadata
  // (if any). Results below this value are rejected.
  optional float score_threshold = 3;

  // Overlapping threshold for non-maximum-suppression calculator. Only used for
  // models without built-in non-maximum-suppression, i.e., models that don't
  // use the Detection_Postprocess TFLite Op.
  optional float min_suppression_threshold = 6;

  // Optional allowlist of category names. If non-empty, detections whose
  // category name is not in this set will be filtered out. Duplicate or unknown
  // category names are ignored. Mutually exclusive with category_denylist.
  repeated string category_allowlist = 4;

  // Optional denylist of category names. If non-empty, detection whose category
  // name is in this set will be filtered out. Duplicate or unknown category
  // names are ignored. Mutually exclusive with category_allowlist.
  repeated string category_denylist = 5;
}

View File

@ -54,12 +54,7 @@ cc_library(
name = "object_detector_graph",
srcs = ["object_detector_graph.cc"],
deps = [
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator",
"//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
"//mediapipe/calculators/util:detection_projection_calculator",
"//mediapipe/calculators/util:detection_transformation_calculator",
"//mediapipe/calculators/util:detections_deduplicate_calculator",
@ -71,19 +66,15 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/processors:detection_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],

View File

@ -99,7 +99,20 @@ struct ObjectDetectorOptions {
// - only RGB inputs are supported (`channels` is required to be 3).
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
// attached to the metadata for input normalization.
// Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e:
// Output tensors could be 2 output tensors or 4 output tensors.
// The 2 output tensors must represent locations and scores, respectively.
// (kTfLiteFloat32)
// - locations tensor of size `[num_results x num_coords]`. The num_coords is
// the number of coordinates a location result represent. Usually in the
// form: [4 + 2 * keypoint_num], where 4 location values encode the bounding
// box (y_center, x_center, height, width) and the additional keypoints are in
// (y, x) order.
// (kTfLiteFloat32)
// - scores tensor of size `[num_results x num_classes]`. The values of a
// result represent the classification probability belonging to the class at
// the index, which is denoted in the label file of corresponding tensor
// metadata in the model file.
// The 4 output tensors must come from `DetectionPostProcess` op, i.e:
// (kTfLiteFloat32)
// - locations tensor of size `[num_results x 4]`, the inner array
// representing bounding boxes in the form [top, left, right, bottom].

View File

@ -13,16 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <type_traits>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
@ -31,19 +25,15 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
namespace mediapipe {
namespace tasks {
@ -56,42 +46,18 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::BoundingBoxProperties;
using ::tflite::ContentProperties;
using ::tflite::ContentProperties_BoundingBoxProperties;
using ::tflite::EnumNameContentProperties;
using ::tflite::ProcessUnit;
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions;
using TensorsSource =
mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;
constexpr int kDefaultLocationsIndex = 0;
constexpr int kDefaultCategoriesIndex = 1;
constexpr int kDefaultScoresIndex = 2;
constexpr int kDefaultNumResultsIndex = 3;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
constexpr char kLocationTensorName[] = "location";
constexpr char kCategoryTensorName[] = "category";
constexpr char kScoreTensorName[] = "score";
constexpr char kNumberOfDetectionsTensorName[] = "number of detections";
constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kImageTag[] = "IMAGE";
constexpr char kIndicesTag[] = "INDICES";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorTag[] = "TENSORS";
// Struct holding the different output streams produced by the object detection
@ -101,34 +67,6 @@ struct ObjectDetectionOutputStreams {
Source<Image> image;
};
// Parameters used for configuring the post-processing calculators.
//
// Populated by BuildPostProcessingSpecs() from the ObjectDetectorOptions proto
// and the TFLite Model Metadata, then consumed when configuring the
// TensorsToDetections / ScoreCalibration / DetectionLabelIdToText calculators.
struct PostProcessingSpecs {
  // The maximum number of detection results to return. A value of -1 means
  // all available results are returned.
  int max_results;
  // Indices of the output tensors to match the output tensors to the correct
  // index order of the output tensors: [location, categories, scores,
  // num_detections].
  std::vector<int> output_tensor_indices;
  // For each pack of 4 coordinates returned by the model, this denotes the
  // order in which to get the left, top, right and bottom coordinates.
  std::vector<unsigned int> bounding_box_corners_order;
  // This is populated by reading the label files from the TFLite Model
  // Metadata: if no such files are available, this is left empty and the
  // ObjectDetector will only be able to populate the `index` field of the
  // detection results.
  LabelItems label_items;
  // Score threshold. Detections with a confidence below this value are
  // discarded. If none is provided via metadata or options, -FLT_MAX is set as
  // default value.
  float score_threshold;
  // Set of category indices to be allowed/denied.
  absl::flat_hash_set<int> allow_or_deny_categories;
  // Indicates `allow_or_deny_categories` is an allowlist or a denylist.
  bool is_allowlist;
  // Score calibration options, if any.
  std::optional<ScoreCalibrationCalculatorOptions> score_calibration_options;
};
absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) {
if (options.max_results() == 0) {
return CreateStatusWithPayload(
@ -147,310 +85,6 @@ absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) {
return absl::OkStatus();
}
// Extracts and validates the BoundingBoxProperties attached to the given
// tensor metadata.
//
// Enforces the constraints required by Mobile SSD models: the content
// properties must be of type BoundingBoxProperties with BOUNDARIES bounding
// box type and RATIO coordinates, and the optional index array, when present,
// must contain exactly 4 values. On success, returns a pointer owned by
// `tensor_metadata`; on failure, returns an InvalidArgument status with a
// kMetadataInvalidContentPropertiesError payload.
absl::StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
    const TensorMetadata& tensor_metadata) {
  // The metadata name is optional; fall back to "#0" for error messages.
  const std::string tensor_display_name =
      tensor_metadata.name() ? tensor_metadata.name()->str() : "#0";
  const auto* content = tensor_metadata.content();
  if (content == nullptr || content->content_properties() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found none.",
            tensor_display_name),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  const ContentProperties content_type = content->content_properties_type();
  if (content_type != ContentProperties_BoundingBoxProperties) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found %s.",
            tensor_display_name, EnumNameContentProperties(content_type)),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  const BoundingBoxProperties* properties =
      content->content_properties_as_BoundingBoxProperties();
  // Mobile SSD only supports "BOUNDARIES" bounding box type.
  if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
            tflite::EnumNameBoundingBoxType(properties->type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  // Mobile SSD only supports "RATIO" coordinates type.
  if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports CoordinateType RATIO, found %s",
            tflite::EnumNameCoordinateType(properties->coordinate_type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  // Index is optional, but must contain 4 values if present.
  if (properties->index() != nullptr && properties->index()->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties index to contain 4 values, found "
            "%d",
            properties->index()->size()),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }
  return properties;
}
// Builds the label map from the label files referenced by the tensor
// metadata, if any.
//
// Uses the TENSOR_VALUE_LABELS associated file as the canonical label list
// and, when one exists for `locale`, a second associated file for localized
// display names. Returns an empty map when the model carries no label file,
// in which case only category indices can be reported downstream.
absl::StatusOr<LabelItems> GetLabelItemsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata, absl::string_view locale) {
  const std::string labels_file_name =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
  // No label file attached: return an empty map.
  if (labels_file_name.empty()) {
    return LabelItems();
  }
  ASSIGN_OR_RETURN(absl::string_view labels_content,
                   metadata_extractor.GetAssociatedFile(labels_file_name));
  // Display names are optional and locale-dependent.
  absl::string_view display_names_content;
  const std::string display_names_file_name =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS,
          locale);
  if (!display_names_file_name.empty()) {
    ASSIGN_OR_RETURN(
        display_names_content,
        metadata_extractor.GetAssociatedFile(display_names_file_name));
  }
  return mediapipe::BuildLabelMapFromFiles(labels_content,
                                           display_names_content);
}
// Retrieves the score threshold from the tensor metadata, if specified.
//
// Returns kDefaultScoreThreshold (float lowest) when the metadata carries no
// ScoreThresholdingOptions process unit, so that no detection is filtered out
// by default.
absl::StatusOr<float> GetScoreThreshold(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(const ProcessUnit* process_unit,
                   metadata_extractor.FindFirstProcessUnit(
                       tensor_metadata,
                       ProcessUnitOptions_ScoreThresholdingOptions));
  return process_unit == nullptr
             ? kDefaultScoreThreshold
             : process_unit->options_as_ScoreThresholdingOptions()
                   ->global_score_threshold();
}
// Converts the category allowlist/denylist names from the options into a set
// of label-map indices.
//
// Returns an empty set when neither list is configured. Fails with
// kMetadataMissingLabelsError if a list is configured but the model metadata
// provides no labels to match against. Names absent from the label map are
// silently skipped.
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
    const ObjectDetectorOptionsProto& config, const LabelItems& label_items) {
  absl::flat_hash_set<int> category_indices;
  const bool use_allowlist = config.category_allowlist_size() > 0;
  // Exit early if no denylist/allowlist.
  if (!use_allowlist && config.category_denylist_size() == 0) {
    return category_indices;
  }
  if (label_items.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Using `category_allowlist` or `category_denylist` requires "
        "labels to be present in the TFLite Model Metadata but none was found.",
        MediaPipeTasksStatus::kMetadataMissingLabelsError);
  }
  // Linear scan over the label map; returns -1 when the name is unknown.
  auto find_category_index = [&label_items](const std::string& name) -> int {
    for (int i = 0; i < label_items.size(); ++i) {
      if (label_items.at(i).name() == name) {
        return i;
      }
    }
    return -1;
  };
  const auto& category_list =
      use_allowlist ? config.category_allowlist() : config.category_denylist();
  for (const auto& category_name : category_list) {
    const int index = find_category_index(category_name);
    // Ignores duplicate or unknown categories.
    if (index >= 0) {
      category_indices.insert(index);
    }
  }
  return category_indices;
}
// Retrieves score calibration options from the tensor metadata, if present.
//
// Returns std::nullopt when the metadata carries no ScoreCalibrationOptions
// process unit. Fails with kNotFound if the options are present but the
// mandatory TENSOR_AXIS_SCORE_CALIBRATION associated parameters file is
// missing, since the sigmoid parameters live in that file.
absl::StatusOr<std::optional<ScoreCalibrationCalculatorOptions>>
GetScoreCalibrationOptionsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  // Get ScoreCalibrationOptions, if any.
  ASSIGN_OR_RETURN(
      const ProcessUnit* process_unit,
      metadata_extractor.FindFirstProcessUnit(
          tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions));
  if (process_unit == nullptr) {
    return std::nullopt;
  }
  auto* metadata_options = process_unit->options_as_ScoreCalibrationOptions();
  // Get corresponding AssociatedFile.
  auto calibration_file_name = metadata_extractor.FindFirstAssociatedFileName(
      tensor_metadata,
      tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
  if (calibration_file_name.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        "Found ScoreCalibrationOptions but missing required associated "
        "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
        MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
  }
  ASSIGN_OR_RETURN(
      absl::string_view calibration_file_content,
      metadata_extractor.GetAssociatedFile(calibration_file_name));
  ScoreCalibrationCalculatorOptions calculator_options;
  MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
      metadata_options->score_transformation(),
      metadata_options->default_score(), calibration_file_content,
      &calculator_options));
  return calculator_options;
}
// Resolves the indices of the 4 model output tensors ([location, category,
// score, number of detections]) by matching the tensor names declared in the
// model metadata against the canonical Detection_Postprocess output names.
//
// If any of the expected names is not found, logs a warning and falls back to
// the default tensor order {0, 1, 2, 3}.
std::vector<int> GetOutputTensorIndices(
    const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
        tensor_metadatas) {
  std::vector<int> output_indices = {
      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                          kLocationTensorName),
      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                          kCategoryTensorName),
      core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName),
      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                          kNumberOfDetectionsTensorName)};
  // locations, categories, scores, and number of detections
  for (int i = 0; i < 4; i++) {
    int output_index = output_indices[i];
    // If tensor name is not found, set the default output indices.
    if (output_index == -1) {
      // The tensor name in metadata is optional: guard against a null name
      // before dereferencing it for the warning message, consistently with
      // the other metadata-name accesses in this file.
      const auto* metadata_name = tensor_metadatas->Get(i)->name();
      LOG(WARNING) << absl::StrFormat(
          "You don't seem to be matching tensor names in metadata list. The "
          "tensor name \"%s\" at index %d in the model metadata doesn't "
          "match "
          "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].",
          metadata_name != nullptr ? metadata_name->c_str() : "#0", i,
          kLocationTensorName, kCategoryTensorName, kScoreTensorName,
          kNumberOfDetectionsTensorName);
      output_indices = {kDefaultLocationsIndex, kDefaultCategoriesIndex,
                        kDefaultScoresIndex, kDefaultNumResultsIndex};
      return output_indices;
    }
  }
  return output_indices;
}
// Builds PostProcessingSpecs from ObjectDetectorOptionsProto and model metadata
// for configuring the post-processing calculators.
//
// Expects exactly 4 output tensor metadata entries (location, category, score,
// num detections); fails with kMetadataInconsistencyError otherwise. Options
// values (e.g. score_threshold) take precedence over metadata-provided ones.
absl::StatusOr<PostProcessingSpecs> BuildPostProcessingSpecs(
    const ObjectDetectorOptionsProto& options,
    const ModelMetadataExtractor* metadata_extractor) {
  // Checks output tensor metadata is present and consistent with model.
  auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr ||
      output_tensors_metadata->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (4) and "
                        "output tensors metadata (%d).",
                        output_tensors_metadata == nullptr
                            ? 0
                            : output_tensors_metadata->size()),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }
  PostProcessingSpecs specs;
  specs.max_results = options.max_results();
  specs.output_tensor_indices = GetOutputTensorIndices(output_tensors_metadata);
  // Extracts mandatory BoundingBoxProperties and performs sanity checks on the
  // fly.
  ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
                   GetBoundingBoxProperties(*output_tensors_metadata->Get(
                       specs.output_tensor_indices[0])));
  // When no index array is provided, assume the natural coordinate order.
  if (bounding_box_properties->index() == nullptr) {
    specs.bounding_box_corners_order = {0, 1, 2, 3};
  } else {
    auto bounding_box_index = bounding_box_properties->index();
    specs.bounding_box_corners_order = {
        bounding_box_index->Get(0),
        bounding_box_index->Get(1),
        bounding_box_index->Get(2),
        bounding_box_index->Get(3),
    };
  }
  // Builds label map (if available) from metadata.
  ASSIGN_OR_RETURN(specs.label_items,
                   GetLabelItemsIfAny(*metadata_extractor,
                                      *output_tensors_metadata->Get(
                                          specs.output_tensor_indices[1]),
                                      options.display_names_locale()));
  // Obtains allow/deny categories.
  specs.is_allowlist = !options.category_allowlist().empty();
  ASSIGN_OR_RETURN(
      specs.allow_or_deny_categories,
      GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items));
  // Sets score threshold: explicit options win over metadata.
  if (options.has_score_threshold()) {
    specs.score_threshold = options.score_threshold();
  } else {
    ASSIGN_OR_RETURN(specs.score_threshold,
                     GetScoreThreshold(*metadata_extractor,
                                       *output_tensors_metadata->Get(
                                           specs.output_tensor_indices[2])));
  }
  // Builds score calibration options (if available) from metadata.
  ASSIGN_OR_RETURN(
      specs.score_calibration_options,
      GetScoreCalibrationOptionsIfAny(
          *metadata_extractor,
          *output_tensors_metadata->Get(specs.output_tensor_indices[2])));
  return specs;
}
// Fills in the TensorsToDetectionsCalculatorOptions based on
// PostProcessingSpecs extracted from the task options and model metadata.
void ConfigureTensorsToDetectionsCalculator(
    const PostProcessingSpecs& specs,
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
  // Basic decoding parameters.
  options->set_num_classes(specs.label_items.size());
  options->set_num_coords(4);
  options->set_min_score_thresh(specs.score_threshold);
  // -1 means "return all results": leave max_results unset in that case.
  if (specs.max_results != -1) {
    options->set_max_results(specs.max_results);
  }
  // The same category index set acts either as an allowlist or a denylist.
  auto* class_filter = specs.is_allowlist ? options->mutable_allow_classes()
                                          : options->mutable_ignore_classes();
  class_filter->Assign(specs.allow_or_deny_categories.begin(),
                       specs.allow_or_deny_categories.end());
  // Maps each model output tensor to its role: [detections, classes, scores,
  // num detections].
  const auto& indices = specs.output_tensor_indices;
  auto* mapping = options->mutable_tensor_mapping();
  mapping->set_detections_tensor_index(indices[0]);
  mapping->set_classes_tensor_index(indices[1]);
  mapping->set_scores_tensor_index(indices[2]);
  mapping->set_num_detections_tensor_index(indices[3]);
  // Order in which the 4 bounding box coordinates are emitted by the model.
  const auto& corners = specs.bounding_box_corners_order;
  auto* box_indices = options->mutable_box_boundaries_indices();
  box_indices->set_xmin(corners[0]);
  box_indices->set_ymin(corners[1]);
  box_indices->set_xmax(corners[2]);
  box_indices->set_ymax(corners[3]);
}
} // namespace
// A "mediapipe.tasks.vision.ObjectDetectorGraph" performs object detection.
@ -530,7 +164,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Checks that the model has 4 outputs.
auto& model = *model_resources.GetTfLiteModel();
if (model.subgraphs()->size() != 1) {
return CreateStatusWithPayload(
@ -539,13 +172,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
model.subgraphs()->size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
if (model.subgraphs()->Get(0)->outputs()->size() != 4) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Expected a model with 4 output tensors, found %d.",
model.subgraphs()->Get(0)->outputs()->size()),
MediaPipeTasksStatus::kInvalidArgumentError);
}
// Checks that metadata is available.
auto* metadata_extractor = model_resources.GetMetadataExtractor();
if (metadata_extractor->GetModelMetadata() == nullptr ||
@ -577,70 +203,36 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
preprocessing.Out(kTensorTag) >> inference.In(kTensorTag);
// Adds post processing calculators.
ASSIGN_OR_RETURN(
auto post_processing_specs,
BuildPostProcessingSpecs(task_options, metadata_extractor));
// Calculators to perform score calibration, if specified in the metadata.
TensorsSource calibrated_tensors =
TensorsSource model_output_tensors =
inference.Out(kTensorTag).Cast<std::vector<Tensor>>();
if (post_processing_specs.score_calibration_options.has_value()) {
// Split tensors.
auto* split_tensor_vector_node =
&graph.AddNode("SplitTensorVectorCalculator");
auto& split_tensor_vector_options =
split_tensor_vector_node
->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
for (int i = 0; i < 4; ++i) {
auto* range = split_tensor_vector_options.add_ranges();
range->set_begin(i);
range->set_end(i + 1);
}
calibrated_tensors >> split_tensor_vector_node->In(0);
// Add score calibration calculator.
auto* score_calibration_node =
&graph.AddNode("ScoreCalibrationCalculator");
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
.CopyFrom(*post_processing_specs.score_calibration_options);
split_tensor_vector_node->Out(
post_processing_specs.output_tensor_indices[1]) >>
score_calibration_node->In(kIndicesTag);
split_tensor_vector_node->Out(
post_processing_specs.output_tensor_indices[2]) >>
score_calibration_node->In(kScoresTag);
// Re-concatenate tensors.
auto* concatenate_tensor_vector_node =
&graph.AddNode("ConcatenateTensorVectorCalculator");
for (int i = 0; i < 4; ++i) {
if (i == post_processing_specs.output_tensor_indices[2]) {
score_calibration_node->Out(kCalibratedScoresTag) >>
concatenate_tensor_vector_node->In(i);
} else {
split_tensor_vector_node->Out(i) >>
concatenate_tensor_vector_node->In(i);
}
}
calibrated_tensors =
concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
}
// Calculator to convert output tensors to a detection proto vector.
// Connects TensorsToDetectionsCalculator's input stream to the output
// tensors produced by the inference subgraph.
auto& tensors_to_detections =
graph.AddNode("TensorsToDetectionsCalculator");
ConfigureTensorsToDetectionsCalculator(
post_processing_specs,
&tensors_to_detections
.GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
calibrated_tensors >> tensors_to_detections.In(kTensorTag);
// Add Detection postprocessing graph to convert tensors to detections.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
components::processors::proto::DetectorOptions detector_options;
detector_options.set_max_results(task_options.max_results());
detector_options.set_score_threshold(task_options.score_threshold());
detector_options.set_display_names_locale(
task_options.display_names_locale());
detector_options.mutable_category_allowlist()->CopyFrom(
task_options.category_allowlist());
detector_options.mutable_category_denylist()->CopyFrom(
task_options.category_denylist());
// TODO: expose min suppression threshold in
// ObjectDetectorOptions.
detector_options.set_min_suppression_threshold(0.3);
MP_RETURN_IF_ERROR(
components::processors::ConfigureDetectionPostprocessingGraph(
model_resources, detector_options,
postprocessing
.GetOptions<components::processors::proto::
DetectionPostprocessingGraphOptions>()));
model_output_tensors >> postprocessing.In(kTensorTag);
auto detections = postprocessing.Out(kDetectionsTag);
// Calculator to projects detections back to the original coordinate system.
auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
tensors_to_detections.Out(kDetectionsTag) >>
detection_projection.In(kDetectionsTag);
detections >> detection_projection.In(kDetectionsTag);
preprocessing.Out(kMatrixTag) >>
detection_projection.In(kProjectionMatrixTag);
@ -652,22 +244,13 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
detection_transformation.In(kDetectionsTag);
preprocessing.Out(kImageSizeTag) >>
detection_transformation.In(kImageSizeTag);
// Calculator to assign detection labels.
auto& detection_label_id_to_text =
graph.AddNode("DetectionLabelIdToTextCalculator");
auto& detection_label_id_to_text_opts =
detection_label_id_to_text
.GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>();
*detection_label_id_to_text_opts.mutable_label_items() =
std::move(post_processing_specs.label_items);
detection_transformation.Out(kPixelDetectionsTag) >>
detection_label_id_to_text.In("");
auto detections_in_pixel =
detection_transformation.Out(kPixelDetectionsTag);
// Deduplicate Detections with same bounding box coordinates.
auto& detections_deduplicate =
graph.AddNode("DetectionsDeduplicateCalculator");
detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
detections_in_pixel >> detections_deduplicate.In("");
// Outputs the labeled detections and the processed image as the subgraph
// output streams.

View File

@ -76,15 +76,18 @@ using ::testing::HasSubstr;
using ::testing::Optional;
using DetectionProto = mediapipe::Detection;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kMobileSsdWithMetadata[] =
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
constexpr char kMobileSsdWithDummyScoreCalibration[] =
constexpr absl::string_view kTestDataDirectory{
"/mediapipe/tasks/testdata/vision/"};
constexpr absl::string_view kMobileSsdWithMetadata{
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"};
constexpr absl::string_view kMobileSsdWithDummyScoreCalibration{
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration."
"tflite";
"tflite"};
// The model has different output tensor order.
constexpr char kEfficientDetWithMetadata[] =
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
constexpr absl::string_view kEfficientDetWithMetadata{
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"};
constexpr absl::string_view kEfficientDetWithoutNms{
"efficientdet_lite0_fp16_no_nms.tflite"};
// Checks that the two provided `Detection` proto vectors are equal, with a
// tolerancy on floating-point scores to account for numerical instabilities.
@ -451,6 +454,51 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
})pb")}));
}
// Verifies end-to-end detection with a model that has no built-in
// non-maximum-suppression (raw-tensor outputs), exercising the
// DetectionPostprocessingGraph NMS path. Expected values are golden results
// for cats_and_dogs.jpg with max_results = 4.
TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) {
  MP_ASSERT_OK_AND_ASSIGN(Image image,
                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                       "cats_and_dogs.jpg")));
  auto options = std::make_unique<ObjectDetectorOptions>();
  options->max_results = 4;
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                          ObjectDetector::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
  MP_ASSERT_OK(object_detector->Close());
  // Scores are compared with a tolerance; see ExpectApproximatelyEqual.
  ExpectApproximatelyEqual(
      results,
      ConvertToDetectionResult(
          {ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "dog"
             score: 0.733542
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 636 ymin: 160 width: 282 height: 451 }
             })pb"),
           ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "cat"
             score: 0.699751
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 870 ymin: 411 width: 208 height: 187 }
             })pb"),
           ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "dog"
             score: 0.682425
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 386 ymin: 216 width: 256 height: 376 }
             })pb"),
           ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "cat"
             score: 0.646585
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 83 ymin: 399 width: 347 height: 198 }
             })pb")}));
}
TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
"./", kTestDataDirectory,

View File

@ -41,6 +41,7 @@ mediapipe_files(srcs = [
"conv2d_input_channel_1.tflite",
"deeplabv3.tflite",
"dense.tflite",
"efficientdet_lite0_fp16_no_nms.tflite",
"face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite",
"face_detection_short_range.tflite",
@ -167,6 +168,7 @@ filegroup(
"conv2d_input_channel_1.tflite",
"deeplabv3.tflite",
"dense.tflite",
"efficientdet_lite0_fp16_no_nms.tflite",
"face_detection_full_range.tflite",
"face_detection_full_range_sparse.tflite",
"face_detection_short_range.tflite",

View File

@ -276,8 +276,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_tflite",
sha256 = "bcda125c96d3767bca894c8cbe7bc458379c9974c9fd8bdc6204e7124a74082a",
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682456096034465"],
sha256 = "237a58389081333e5cf4154e42b593ce7dd357445536fcaf4ca5bc51c2c50f1c",
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682476299542472"],
)
http_file(