DetectionPostprocessingGraph for post-processing raw tensors from detection models.
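For context, a minimal usage sketch of the new subgraph, mirroring the example documented in detection_postprocessing_graph.h. The surrounding `graph`, `inference`, `model_resources`, and `detector_options` objects are assumed to exist and are illustrative only, not part of this change.

    // Sketch only: assumes an api2::builder::Graph `graph`, an existing
    // InferenceCalculator node `inference`, loaded `model_resources`, and a
    // configured `detector_options`.
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
        model_resources, detector_options,
        postprocessing.GetOptions<proto::DetectionPostprocessingGraphOptions>()));
    // Feed the raw model output tensors in and read labeled detections out.
    inference.Out("TENSORS") >> postprocessing.In("TENSORS");
    auto detections = postprocessing.Out("DETECTIONS")
                          .Cast<std::vector<mediapipe::Detection>>();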
PiperOrigin-RevId: 527363291
parent 48aa88f39d
commit c44cc30ece
@@ -161,3 +161,49 @@ cc_library(
    ],
    alwayslink = 1,
)

cc_library(
    name = "detection_postprocessing_graph",
    srcs = ["detection_postprocessing_graph.cc"],
    hdrs = ["detection_postprocessing_graph.h"],
    deps = [
        "//mediapipe/calculators/core:split_vector_calculator",
        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
        "//mediapipe/calculators/tflite:ssd_anchors_calculator",
        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
        "//mediapipe/calculators/util:non_max_suppression_calculator",
        "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:detection_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/formats/object_detection:anchor_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
        "//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
        "//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "//mediapipe/tasks/metadata:object_detector_metadata_schema_cc",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:label_map_util",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
    ],
    alwayslink = 1,
)
@@ -0,0 +1,886 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"

#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/object_detection/anchor.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/tasks/metadata/object_detector_metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {

namespace {

using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::BoundingBoxProperties;
using ::tflite::ContentProperties;
using ::tflite::ContentProperties_BoundingBoxProperties;
using ::tflite::EnumNameContentProperties;
using ::tflite::ProcessUnit;
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
using TensorsSource =
    mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;

constexpr int kInModelNmsDefaultLocationsIndex = 0;
constexpr int kInModelNmsDefaultCategoriesIndex = 1;
constexpr int kInModelNmsDefaultScoresIndex = 2;
constexpr int kInModelNmsDefaultNumResultsIndex = 3;

constexpr int kOutModelNmsDefaultLocationsIndex = 0;
constexpr int kOutModelNmsDefaultScoresIndex = 1;

constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();

constexpr absl::string_view kLocationTensorName = "location";
constexpr absl::string_view kCategoryTensorName = "category";
constexpr absl::string_view kScoreTensorName = "score";
constexpr absl::string_view kNumberOfDetectionsTensorName =
    "number of detections";
constexpr absl::string_view kDetectorMetadataName = "DETECTOR_METADATA";
constexpr absl::string_view kCalibratedScoresTag = "CALIBRATED_SCORES";
constexpr absl::string_view kDetectionsTag = "DETECTIONS";
constexpr absl::string_view kIndicesTag = "INDICES";
constexpr absl::string_view kScoresTag = "SCORES";
constexpr absl::string_view kTensorsTag = "TENSORS";
constexpr absl::string_view kAnchorsTag = "ANCHORS";

// Struct holding the different output streams produced by the graph.
struct DetectionPostprocessingOutputStreams {
  Source<std::vector<Detection>> detections;
};

// Parameters used for configuring the post-processing calculators.
struct PostProcessingSpecs {
  // The maximum number of detection results to return.
  int max_results;
  // Indices of the model output tensors, used to map them to the expected
  // order: [location, categories, scores, num_detections].
  std::vector<int> output_tensor_indices;
  // For each pack of 4 coordinates returned by the model, this denotes the
  // order in which to get the left, top, right and bottom coordinates.
  std::vector<unsigned int> bounding_box_corners_order;
  // This is populated by reading the label files from the TFLite Model
  // Metadata: if no such files are available, this is left empty and the
  // ObjectDetector will only be able to populate the `index` field of the
  // detection results.
  LabelItems label_items;
  // Score threshold. Detections with a confidence below this value are
  // discarded. If none is provided via metadata or options, -FLT_MAX is set as
  // default value.
  float score_threshold;
  // Set of category indices to be allowed/denied.
  absl::flat_hash_set<int> allow_or_deny_categories;
  // Indicates whether `allow_or_deny_categories` is an allowlist or a
  // denylist.
  bool is_allowlist;
  // Score calibration options, if any.
  std::optional<ScoreCalibrationCalculatorOptions> score_calibration_options;
};

absl::Status SanityCheckOptions(const proto::DetectorOptions& options) {
  if (options.max_results() == 0) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid `max_results` option: value must be != 0",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  if (options.category_allowlist_size() > 0 &&
      options.category_denylist_size() > 0) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "`category_allowlist` and `category_denylist` are mutually "
        "exclusive options.",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  return absl::OkStatus();
}

absl::StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
    const TensorMetadata& tensor_metadata) {
  if (tensor_metadata.content() == nullptr ||
      tensor_metadata.content()->content_properties() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found none.",
            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  ContentProperties type = tensor_metadata.content()->content_properties_type();
  if (type != ContentProperties_BoundingBoxProperties) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found %s.",
            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
            EnumNameContentProperties(type)),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  const BoundingBoxProperties* properties =
      tensor_metadata.content()->content_properties_as_BoundingBoxProperties();

  // Mobile SSD only supports "BOUNDARIES" bounding box type.
  if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
            tflite::EnumNameBoundingBoxType(properties->type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  // Mobile SSD only supports "RATIO" coordinates type.
  if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports CoordinateType RATIO, found %s",
            tflite::EnumNameCoordinateType(properties->coordinate_type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  // Index is optional, but must contain 4 values if present.
  if (properties->index() != nullptr && properties->index()->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties index to contain 4 values, found "
            "%d",
            properties->index()->size()),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  return properties;
}

absl::StatusOr<LabelItems> GetLabelItemsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata,
    tflite::AssociatedFileType associated_file_type, absl::string_view locale) {
  const std::string labels_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(tensor_metadata,
                                                          associated_file_type);
  if (labels_filename.empty()) {
    LabelItems empty_label_items;
    return empty_label_items;
  }
  ASSIGN_OR_RETURN(absl::string_view labels_file,
                   metadata_extractor.GetAssociatedFile(labels_filename));
  const std::string display_names_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, associated_file_type, locale);
  absl::string_view display_names_file;
  if (!display_names_filename.empty()) {
    ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
                                             display_names_filename));
  }
  return mediapipe::BuildLabelMapFromFiles(labels_file, display_names_file);
}

absl::StatusOr<float> GetScoreThreshold(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(
      const ProcessUnit* score_thresholding_process_unit,
      metadata_extractor.FindFirstProcessUnit(
          tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
  if (score_thresholding_process_unit == nullptr) {
    return kDefaultScoreThreshold;
  }
  return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
      ->global_score_threshold();
}

absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
    const proto::DetectorOptions& config, const LabelItems& label_items) {
  absl::flat_hash_set<int> category_indices;
  // Exit early if no denylist/allowlist.
  if (config.category_denylist_size() == 0 &&
      config.category_allowlist_size() == 0) {
    return category_indices;
  }
  if (label_items.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Using `category_allowlist` or `category_denylist` requires "
        "labels to be present in the TFLite Model Metadata but none was found.",
        MediaPipeTasksStatus::kMetadataMissingLabelsError);
  }
  const auto& category_list = config.category_allowlist_size() > 0
                                  ? config.category_allowlist()
                                  : config.category_denylist();
  for (const auto& category_name : category_list) {
    int index = -1;
    for (int i = 0; i < label_items.size(); ++i) {
      if (label_items.at(i).name() == category_name) {
        index = i;
        break;
      }
    }
    // Ignores duplicate or unknown categories.
    if (index < 0) {
      continue;
    }
    category_indices.insert(index);
  }
  return category_indices;
}

absl::StatusOr<std::optional<ScoreCalibrationCalculatorOptions>>
GetScoreCalibrationOptionsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  // Get ScoreCalibrationOptions, if any.
  ASSIGN_OR_RETURN(
      const ProcessUnit* score_calibration_process_unit,
      metadata_extractor.FindFirstProcessUnit(
          tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions));
  if (score_calibration_process_unit == nullptr) {
    return std::nullopt;
  }
  auto* score_calibration_options =
      score_calibration_process_unit->options_as_ScoreCalibrationOptions();
  // Get corresponding AssociatedFile.
  auto score_calibration_filename =
      metadata_extractor.FindFirstAssociatedFileName(
          tensor_metadata,
          tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
  if (score_calibration_filename.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        "Found ScoreCalibrationOptions but missing required associated "
        "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
        MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
  }
  ASSIGN_OR_RETURN(
      absl::string_view score_calibration_file,
      metadata_extractor.GetAssociatedFile(score_calibration_filename));
  ScoreCalibrationCalculatorOptions score_calibration_calculator_options;
  MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
      score_calibration_options->score_transformation(),
      score_calibration_options->default_score(), score_calibration_file,
      &score_calibration_calculator_options));
  return score_calibration_calculator_options;
}

absl::StatusOr<std::vector<int>> GetOutputTensorIndices(
    const Vector<Offset<TensorMetadata>>* tensor_metadatas) {
  std::vector<int> output_indices;
  if (tensor_metadatas->size() == 4) {
    output_indices = {
        core::FindTensorIndexByMetadataName(tensor_metadatas,
                                            kLocationTensorName),
        core::FindTensorIndexByMetadataName(tensor_metadatas,
                                            kCategoryTensorName),
        core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName),
        core::FindTensorIndexByMetadataName(tensor_metadatas,
                                            kNumberOfDetectionsTensorName)};
    // locations, categories, scores, and number of detections
    for (int i = 0; i < 4; i++) {
      int output_index = output_indices[i];
      // If tensor name is not found, set the default output indices.
      if (output_index == -1) {
        LOG(WARNING) << absl::StrFormat(
            "You don't seem to be matching tensor names in metadata list. The "
            "tensor name \"%s\" at index %d in the model metadata doesn't "
            "match "
            "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].",
            tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
            kCategoryTensorName, kScoreTensorName,
            kNumberOfDetectionsTensorName);
        output_indices = {
            kInModelNmsDefaultLocationsIndex, kInModelNmsDefaultCategoriesIndex,
            kInModelNmsDefaultScoresIndex, kInModelNmsDefaultNumResultsIndex};
        return output_indices;
      }
    }
  } else if (tensor_metadatas->size() == 2) {
    output_indices = {core::FindTensorIndexByMetadataName(tensor_metadatas,
                                                          kLocationTensorName),
                      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                                          kScoreTensorName)};
    // location, score
    for (int i = 0; i < 2; i++) {
      int output_index = output_indices[i];
      // If tensor name is not found, set the default output indices.
      if (output_index == -1) {
        LOG(WARNING) << absl::StrFormat(
            "You don't seem to be matching tensor names in metadata list. The "
            "tensor name \"%s\" at index %d in the model metadata doesn't "
            "match "
            "the available output names: [\"%s\", \"%s\"].",
            tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
            kScoreTensorName);
        output_indices = {kOutModelNmsDefaultLocationsIndex,
                          kOutModelNmsDefaultScoresIndex};
        return output_indices;
      }
    }
  } else {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected a model with 2 or 4 output tensors metadata, found %d.",
            tensor_metadatas->size()),
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  return output_indices;
}

// Builds PostProcessingSpecs from DetectorOptions and model metadata for
// configuring the post-processing calculators.
absl::StatusOr<PostProcessingSpecs> BuildPostProcessingSpecs(
    const proto::DetectorOptions& options, bool in_model_nms,
    const ModelMetadataExtractor* metadata_extractor) {
  const auto* output_tensors_metadata =
      metadata_extractor->GetOutputTensorMetadata();
  PostProcessingSpecs specs;
  specs.max_results = options.max_results();
  ASSIGN_OR_RETURN(specs.output_tensor_indices,
                   GetOutputTensorIndices(output_tensors_metadata));
  // Extracts mandatory BoundingBoxProperties and performs sanity checks on the
  // fly.
  ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
                   GetBoundingBoxProperties(*output_tensors_metadata->Get(
                       specs.output_tensor_indices[0])));
  if (bounding_box_properties->index() == nullptr) {
    specs.bounding_box_corners_order = {0, 1, 2, 3};
  } else {
    auto bounding_box_index = bounding_box_properties->index();
    specs.bounding_box_corners_order = {
        bounding_box_index->Get(0),
        bounding_box_index->Get(1),
        bounding_box_index->Get(2),
        bounding_box_index->Get(3),
    };
  }
  // Builds label map (if available) from metadata.
  // For models with in-model-nms, the label map is stored in the Category
  // tensor, which uses TENSOR_VALUE_LABELS. For models with out-of-model-nms,
  // the label map is stored in the Score tensor, which uses TENSOR_AXIS_LABELS.
  ASSIGN_OR_RETURN(
      specs.label_items,
      GetLabelItemsIfAny(
          *metadata_extractor,
          *output_tensors_metadata->Get(specs.output_tensor_indices[1]),
          in_model_nms ? tflite::AssociatedFileType_TENSOR_VALUE_LABELS
                       : tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
          options.display_names_locale()));
  // Obtains allow/deny categories.
  specs.is_allowlist = !options.category_allowlist().empty();
  ASSIGN_OR_RETURN(
      specs.allow_or_deny_categories,
      GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items));

  // Sets score threshold.
  if (options.has_score_threshold()) {
    specs.score_threshold = options.score_threshold();
  } else {
    ASSIGN_OR_RETURN(
        specs.score_threshold,
        GetScoreThreshold(
            *metadata_extractor,
            *output_tensors_metadata->Get(
                specs.output_tensor_indices
                    [in_model_nms ? kInModelNmsDefaultScoresIndex
                                  : kOutModelNmsDefaultScoresIndex])));
  }
  if (in_model_nms) {
    // Builds score calibration options (if available) from metadata.
    ASSIGN_OR_RETURN(
        specs.score_calibration_options,
        GetScoreCalibrationOptionsIfAny(
            *metadata_extractor,
            *output_tensors_metadata->Get(
                specs.output_tensor_indices[kInModelNmsDefaultScoresIndex])));
  }
  return specs;
}

// Builds PostProcessingSpecs from DetectorOptions and model metadata for
// configuring the post-processing calculators for models with
// non-maximum-suppression.
absl::StatusOr<PostProcessingSpecs> BuildInModelNmsPostProcessingSpecs(
    const proto::DetectorOptions& options,
    const ModelMetadataExtractor* metadata_extractor) {
  // Checks output tensor metadata is present and consistent with model.
  auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr ||
      output_tensors_metadata->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (4) and "
                        "output tensors metadata (%d).",
                        output_tensors_metadata == nullptr
                            ? 0
                            : output_tensors_metadata->size()),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }
  return BuildPostProcessingSpecs(options, /*in_model_nms=*/true,
                                  metadata_extractor);
}

// Fills in the TensorsToDetectionsCalculatorOptions based on
// PostProcessingSpecs.
void ConfigureInModelNmsTensorsToDetectionsCalculator(
    const PostProcessingSpecs& specs,
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
  options->set_num_classes(specs.label_items.size());
  options->set_num_coords(4);
  options->set_min_score_thresh(specs.score_threshold);
  if (specs.max_results != -1) {
    options->set_max_results(specs.max_results);
  }
  if (specs.is_allowlist) {
    options->mutable_allow_classes()->Assign(
        specs.allow_or_deny_categories.begin(),
        specs.allow_or_deny_categories.end());
  } else {
    options->mutable_ignore_classes()->Assign(
        specs.allow_or_deny_categories.begin(),
        specs.allow_or_deny_categories.end());
  }

  const auto& output_indices = specs.output_tensor_indices;
  // Assigns indices to each of the model output tensors.
  auto* tensor_mapping = options->mutable_tensor_mapping();
  tensor_mapping->set_detections_tensor_index(output_indices[0]);
  tensor_mapping->set_classes_tensor_index(output_indices[1]);
  tensor_mapping->set_scores_tensor_index(output_indices[2]);
  tensor_mapping->set_num_detections_tensor_index(output_indices[3]);

  // Assigns the bounding box corner order.
  auto box_boundaries_indices = options->mutable_box_boundaries_indices();
  box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]);
  box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]);
  box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]);
  box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]);
}

// Builds PostProcessingSpecs from DetectorOptions and model metadata for
// configuring the post-processing calculators for models without
// non-maximum-suppression.
absl::StatusOr<PostProcessingSpecs> BuildOutModelNmsPostProcessingSpecs(
    const proto::DetectorOptions& options,
    const ModelMetadataExtractor* metadata_extractor) {
  // Checks output tensor metadata is present and consistent with model.
  auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr ||
      output_tensors_metadata->size() != 2) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (2) and "
                        "output tensors metadata (%d).",
                        output_tensors_metadata == nullptr
                            ? 0
                            : output_tensors_metadata->size()),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }
  return BuildPostProcessingSpecs(options, /*in_model_nms=*/false,
                                  metadata_extractor);
}

// Configures the TensorsToDetectionsCalculator for models without
// non-maximum-suppression in the tflite model. The required config parameters
// are extracted from the ObjectDetectorMetadata
// (metadata/object_detector_metadata_schema.fbs).
absl::Status ConfigureOutModelNmsTensorsToDetectionsCalculator(
    const ModelMetadataExtractor* metadata_extractor,
    const PostProcessingSpecs& specs,
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
  bool found_detector_metadata = false;
  if (metadata_extractor->GetCustomMetadataList() != nullptr &&
      metadata_extractor->GetCustomMetadataList()->size() > 0) {
    for (const auto* custom_metadata :
         *metadata_extractor->GetCustomMetadataList()) {
      if (custom_metadata->name()->str() == kDetectorMetadataName) {
        found_detector_metadata = true;
        const auto* tensors_decoding_options =
            GetObjectDetectorOptions(custom_metadata->data()->data())
                ->tensors_decoding_options();
        // Here we don't set the max results for TensorsToDetectionsCalculator.
        // For models without nms, the results are filtered by max_results in
        // NonMaxSuppressionCalculator.
        options->set_num_classes(tensors_decoding_options->num_classes());
        options->set_num_boxes(tensors_decoding_options->num_boxes());
        options->set_num_coords(tensors_decoding_options->num_coords());
        options->set_keypoint_coord_offset(
            tensors_decoding_options->keypoint_coord_offset());
        options->set_num_keypoints(tensors_decoding_options->num_keypoints());
        options->set_num_values_per_keypoint(
            tensors_decoding_options->num_values_per_keypoint());
        options->set_x_scale(tensors_decoding_options->x_scale());
        options->set_y_scale(tensors_decoding_options->y_scale());
        options->set_w_scale(tensors_decoding_options->w_scale());
        options->set_h_scale(tensors_decoding_options->h_scale());
        options->set_apply_exponential_on_box_size(
            tensors_decoding_options->apply_exponential_on_box_size());
        options->set_sigmoid_score(tensors_decoding_options->sigmoid_score());
        break;
      }
    }
  }
  if (!found_detector_metadata) {
    return absl::InvalidArgumentError(
        "TensorsDecodingOptions is not found in the object detector "
        "metadata.");
  }
  // Options not configured through metadata.
  options->set_box_format(
      mediapipe::TensorsToDetectionsCalculatorOptions::YXHW);
  options->set_min_score_thresh(specs.score_threshold);
  if (specs.is_allowlist) {
    options->mutable_allow_classes()->Assign(
        specs.allow_or_deny_categories.begin(),
        specs.allow_or_deny_categories.end());
  } else {
    options->mutable_ignore_classes()->Assign(
        specs.allow_or_deny_categories.begin(),
        specs.allow_or_deny_categories.end());
  }

  const auto& output_indices = specs.output_tensor_indices;
  // Assigns indices to each of the model output tensors.
  auto* tensor_mapping = options->mutable_tensor_mapping();
  tensor_mapping->set_detections_tensor_index(output_indices[0]);
  tensor_mapping->set_scores_tensor_index(output_indices[1]);
  return absl::OkStatus();
}

// Configures the SsdAnchorsCalculator for models without
// non-maximum-suppression in the tflite model. The required config parameters
// are extracted from the ObjectDetectorMetadata
// (metadata/object_detector_metadata_schema.fbs).
absl::Status ConfigureSsdAnchorsCalculator(
    const ModelMetadataExtractor* metadata_extractor,
    mediapipe::SsdAnchorsCalculatorOptions* options) {
  bool found_detector_metadata = false;
  if (metadata_extractor->GetCustomMetadataList() != nullptr &&
      metadata_extractor->GetCustomMetadataList()->size() > 0) {
    for (const auto* custom_metadata :
         *metadata_extractor->GetCustomMetadataList()) {
      if (custom_metadata->name()->str() == kDetectorMetadataName) {
        found_detector_metadata = true;
        const auto* ssd_anchors_options =
            GetObjectDetectorOptions(custom_metadata->data()->data())
                ->ssd_anchors_options();
        for (const auto* ssd_anchor :
             *ssd_anchors_options->fixed_anchors_schema()->anchors()) {
          auto* fixed_anchor = options->add_fixed_anchors();
          fixed_anchor->set_y_center(ssd_anchor->y_center());
          fixed_anchor->set_x_center(ssd_anchor->x_center());
          fixed_anchor->set_h(ssd_anchor->height());
          fixed_anchor->set_w(ssd_anchor->width());
        }
        break;
      }
    }
  }
  if (!found_detector_metadata) {
    return absl::InvalidArgumentError(
        "SsdAnchorsOptions is not found in the object detector "
        "metadata.");
  }
  return absl::OkStatus();
}

// Sets the default IoU-based non-maximum-suppression configs, and sets the
// min_suppression_threshold and max_results for detection models without
// non-maximum-suppression.
void ConfigureNonMaxSuppressionCalculator(
    const proto::DetectorOptions& detector_options,
    mediapipe::NonMaxSuppressionCalculatorOptions* options) {
  options->set_min_suppression_threshold(
      detector_options.min_suppression_threshold());
  options->set_overlap_type(
      mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION);
  options->set_algorithm(
      mediapipe::NonMaxSuppressionCalculatorOptions::DEFAULT);
  options->set_max_num_detections(detector_options.max_results());
}

// Sets the labels from PostProcessingSpecs.
void ConfigureDetectionLabelIdToTextCalculator(
    PostProcessingSpecs& specs,
    mediapipe::DetectionLabelIdToTextCalculatorOptions* options) {
  *options->mutable_label_items() = std::move(specs.label_items);
}

// Splits the vector of 4 output tensors from model inference and calibrates
// the score tensors according to the metadata, if any. Then concatenates the
// tensors back into a vector of 4 tensors.
absl::StatusOr<Source<std::vector<Tensor>>> CalibrateScores(
    Source<std::vector<Tensor>> model_output_tensors,
    const proto::DetectionPostprocessingGraphOptions& options, Graph& graph) {
  // Split tensors.
  auto* split_tensor_vector_node =
      &graph.AddNode("SplitTensorVectorCalculator");
  auto& split_tensor_vector_options =
      split_tensor_vector_node
          ->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
  for (int i = 0; i < 4; ++i) {
    auto* range = split_tensor_vector_options.add_ranges();
    range->set_begin(i);
    range->set_end(i + 1);
  }
  model_output_tensors >> split_tensor_vector_node->In(0);

  // Add score calibration calculator.
  auto* score_calibration_node = &graph.AddNode("ScoreCalibrationCalculator");
  score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
      .CopyFrom(options.score_calibration_options());
  const auto& tensor_mapping =
      options.tensors_to_detections_options().tensor_mapping();
  split_tensor_vector_node->Out(tensor_mapping.classes_tensor_index()) >>
      score_calibration_node->In(kIndicesTag);
  split_tensor_vector_node->Out(tensor_mapping.scores_tensor_index()) >>
      score_calibration_node->In(kScoresTag);

  // Re-concatenate tensors.
  auto* concatenate_tensor_vector_node =
      &graph.AddNode("ConcatenateTensorVectorCalculator");
  for (int i = 0; i < 4; ++i) {
    if (i == tensor_mapping.scores_tensor_index()) {
      score_calibration_node->Out(kCalibratedScoresTag) >>
          concatenate_tensor_vector_node->In(i);
    } else {
      split_tensor_vector_node->Out(i) >> concatenate_tensor_vector_node->In(i);
    }
  }
  model_output_tensors =
      concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
  return model_output_tensors;
}

}  // namespace

absl::Status ConfigureDetectionPostprocessingGraph(
    const tasks::core::ModelResources& model_resources,
    const proto::DetectorOptions& detector_options,
    proto::DetectionPostprocessingGraphOptions& options) {
  MP_RETURN_IF_ERROR(SanityCheckOptions(detector_options));
  const auto& model = *model_resources.GetTfLiteModel();
  bool in_model_nms = false;
  if (model.subgraphs()->size() != 1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Expected a model with a single subgraph, found %d.",
                        model.subgraphs()->size()),
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  if (model.subgraphs()->Get(0)->outputs()->size() == 2) {
    in_model_nms = false;
  } else if (model.subgraphs()->Get(0)->outputs()->size() == 4) {
    in_model_nms = true;
  } else {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected a model with 2 or 4 output tensors, found %d.",
            model.subgraphs()->Get(0)->outputs()->size()),
        MediaPipeTasksStatus::kInvalidArgumentError);
  }

  const ModelMetadataExtractor* metadata_extractor =
      model_resources.GetMetadataExtractor();
  if (in_model_nms) {
    ASSIGN_OR_RETURN(auto post_processing_specs,
                     BuildInModelNmsPostProcessingSpecs(detector_options,
                                                        metadata_extractor));
    ConfigureInModelNmsTensorsToDetectionsCalculator(
        post_processing_specs, options.mutable_tensors_to_detections_options());
    ConfigureDetectionLabelIdToTextCalculator(
        post_processing_specs,
        options.mutable_detection_label_ids_to_text_options());
    if (post_processing_specs.score_calibration_options.has_value()) {
      *options.mutable_score_calibration_options() =
          std::move(*post_processing_specs.score_calibration_options);
    }
  } else {
    ASSIGN_OR_RETURN(auto post_processing_specs,
                     BuildOutModelNmsPostProcessingSpecs(detector_options,
                                                         metadata_extractor));
    MP_RETURN_IF_ERROR(ConfigureOutModelNmsTensorsToDetectionsCalculator(
        metadata_extractor, post_processing_specs,
        options.mutable_tensors_to_detections_options()));
    MP_RETURN_IF_ERROR(ConfigureSsdAnchorsCalculator(
        metadata_extractor, options.mutable_ssd_anchors_options()));
    ConfigureNonMaxSuppressionCalculator(
        detector_options, options.mutable_non_max_suppression_options());
    ConfigureDetectionLabelIdToTextCalculator(
        post_processing_specs,
        options.mutable_detection_label_ids_to_text_options());
  }

  return absl::OkStatus();
}

// A DetectionPostprocessingGraph converts raw tensors into
// std::vector<Detection>.
//
// Inputs:
//   TENSORS - std::vector<Tensor>
//     The output tensors of an InferenceCalculator. The tensors vector may
//     have size 4 or size 2: a vector of size 4 is expected from models with
//     the DETECTION_POSTPROCESS op [1] in the tflite graph, and a vector of
//     size 2 from models without it.
// [1]:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
// Outputs:
//   DETECTIONS - std::vector<Detection>
//     The postprocessed detection results.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureDetectionPostprocessingGraph()' function. See header
// file for more details.
class DetectionPostprocessingGraph : public mediapipe::Subgraph {
 public:
  absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
      mediapipe::SubgraphContext* sc) override {
    Graph graph;
    ASSIGN_OR_RETURN(
        auto output_streams,
        BuildDetectionPostprocessing(
            *sc->MutableOptions<proto::DetectionPostprocessingGraphOptions>(),
            graph.In(kTensorsTag).Cast<std::vector<Tensor>>(), graph));
    output_streams.detections >>
        graph.Out(kDetectionsTag).Cast<std::vector<Detection>>();
    return graph.GetConfig();
  }

 private:
  // Adds an on-device detection postprocessing graph into the provided
  // builder::Graph instance. The detection postprocessing graph takes
  // tensors (std::vector<mediapipe::Tensor>) as input and returns one output
  // stream:
  //   - Detection results as a std::vector<Detection>.
  //
  // graph_options: the on-device DetectionPostprocessingGraphOptions.
  // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<DetectionPostprocessingOutputStreams>
  BuildDetectionPostprocessing(
      proto::DetectionPostprocessingGraphOptions& graph_options,
      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
    std::optional<Source<std::vector<Detection>>> detections;
    if (!graph_options.has_non_max_suppression_options()) {
      // Calculators to perform score calibration, if specified in the options.
      if (graph_options.has_score_calibration_options()) {
        ASSIGN_OR_RETURN(tensors_in,
                         CalibrateScores(tensors_in, graph_options, graph));
      }
      // Calculator to convert output tensors to a detection proto vector.
      auto& tensors_to_detections =
          graph.AddNode("TensorsToDetectionsCalculator");
      tensors_to_detections
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
          .Swap(graph_options.mutable_tensors_to_detections_options());
      tensors_in >> tensors_to_detections.In(kTensorsTag);
      detections = tensors_to_detections.Out(kDetectionsTag)
                       .Cast<std::vector<Detection>>();
    } else {
      // Generates a single side packet containing a vector of SSD anchors.
      auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator");
      ssd_anchor.GetOptions<mediapipe::SsdAnchorsCalculatorOptions>().Swap(
          graph_options.mutable_ssd_anchors_options());
      auto anchors =
          ssd_anchor.SideOut("").Cast<std::vector<mediapipe::Anchor>>();
      // Convert raw output tensors to detections.
      auto& tensors_to_detections =
          graph.AddNode("TensorsToDetectionsCalculator");
      tensors_to_detections
          .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>()
          .Swap(graph_options.mutable_tensors_to_detections_options());
      anchors >> tensors_to_detections.SideIn(kAnchorsTag);
      tensors_in >> tensors_to_detections.In(kTensorsTag);
      detections = tensors_to_detections.Out(kDetectionsTag)
                       .Cast<std::vector<mediapipe::Detection>>();
      // Non-maximum suppression removes redundant object detections.
      auto& non_maximum_suppression =
          graph.AddNode("NonMaxSuppressionCalculator");
      non_maximum_suppression
          .GetOptions<mediapipe::NonMaxSuppressionCalculatorOptions>()
          .Swap(graph_options.mutable_non_max_suppression_options());
      *detections >> non_maximum_suppression.In("");
      detections =
          non_maximum_suppression.Out("").Cast<std::vector<Detection>>();
    }

    // Calculator to assign detection labels.
    auto& detection_label_id_to_text =
        graph.AddNode("DetectionLabelIdToTextCalculator");
    detection_label_id_to_text
        .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>()
        .Swap(graph_options.mutable_detection_label_ids_to_text_options());
    *detections >> detection_label_id_to_text.In("");
    return {
        {detection_label_id_to_text.Out("").Cast<std::vector<Detection>>()}};
  }
};

// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly.
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
  ::mediapipe::tasks::components::processors::DetectionPostprocessingGraph); // NOLINT
// clang-format on

}  // namespace processors
}  // namespace components
}  // namespace tasks
}  // namespace mediapipe
@@ -0,0 +1,62 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_

namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {

// Configures a DetectionPostprocessingGraph using the provided model
// resources and DetectorOptions.
//
// Example usage:
//
//   auto& postprocessing =
//       graph.AddNode("mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
//   MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
//       model_resources,
//       detector_options,
//       postprocessing.GetOptions<DetectionPostprocessingGraphOptions>()));
//
// The resulting DetectionPostprocessingGraph has the following I/O:
// Inputs:
//   TENSORS - std::vector<Tensor>
//     The output tensors of an InferenceCalculator. The tensors vector may
//     have size 4 or size 2: a vector of size 4 is expected from models with
//     the DETECTION_POSTPROCESS op [1] in the tflite graph, and a vector of
//     size 2 from models without it.
//   [1]:
//   https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
// Outputs:
//   DETECTIONS - std::vector<Detection>
//     The postprocessed detection results.
absl::Status ConfigureDetectionPostprocessingGraph(
    const tasks::core::ModelResources& model_resources,
    const proto::DetectorOptions& detector_options,
    proto::DetectionPostprocessingGraphOptions& options);

}  // namespace processors
}  // namespace components
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_DETECTION_POSTPROCESSING_GRAPH_H_
@@ -0,0 +1,570 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"

#include <vector>

#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/graph_runner.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "tensorflow/lite/test_util.h"

namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::core::ModelResources;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::Pointwise;
using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr absl::string_view kTestDataDirectory =
    "/mediapipe/tasks/testdata/vision";
constexpr absl::string_view kMobileSsdWithMetadata =
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
constexpr absl::string_view kMobileSsdWithDummyScoreCalibration =
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration."
    "tflite";
constexpr absl::string_view kEfficientDetWithoutNms =
    "efficientdet_lite0_fp16_no_nms.tflite";

constexpr char kTestModelResourcesTag[] = "test_model_resources";

constexpr absl::string_view kTensorsTag = "TENSORS";
constexpr absl::string_view kDetectionsTag = "DETECTIONS";
constexpr absl::string_view kTensorsName = "tensors";
constexpr absl::string_view kDetectionsName = "detections";

// Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
    absl::string_view model_name) {
  auto external_file = std::make_unique<core::proto::ExternalFile>();
  external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name));
  return ModelResources::Create(kTestModelResourcesTag,
                                std::move(external_file));
}

class ConfigureTest : public tflite::testing::Test {};

TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.set_max_results(0);

  proto::DetectionPostprocessingGraphOptions options_out;
  auto status = ConfigureDetectionPostprocessingGraph(*model_resources,
                                                      options_in, options_out);

  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
}

TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.add_category_allowlist("foo");
  options_in.add_category_denylist("bar");

  proto::DetectionPostprocessingGraphOptions options_out;
  auto status = ConfigureDetectionPostprocessingGraph(*model_resources,
                                                      options_in, options_out);

  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
}

TEST_F(ConfigureTest, SucceedsWithMaxResults) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.set_max_results(3);

  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));

  EXPECT_THAT(
      options_out,
      Approximately(Partially(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 max_results: 3
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb"))));
}

TEST_F(ConfigureTest, SucceedsWithMaxResultsWithoutModelNms) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources, CreateModelResourcesForModel(
                                                    kEfficientDetWithoutNms));
  proto::DetectorOptions options_in;
  options_in.set_max_results(3);

  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  EXPECT_THAT(options_out, Approximately(Partially(EqualsProto(
                               R"pb(tensors_to_detections_options {
                                      min_score_thresh: -3.4028235e+38
                                      num_classes: 90
                                      num_boxes: 19206
                                      num_coords: 4
                                      x_scale: 1
                                      y_scale: 1
                                      w_scale: 1
                                      h_scale: 1
                                      keypoint_coord_offset: 0
                                      num_keypoints: 0
                                      num_values_per_keypoint: 2
                                      apply_exponential_on_box_size: true
                                      sigmoid_score: false
                                      tensor_mapping {
                                        detections_tensor_index: 1
                                        scores_tensor_index: 0
                                      }
                                      box_format: YXHW
                                    }
                                    non_max_suppression_options {
                                      max_num_detections: 3
                                      min_suppression_threshold: 0
                                      overlap_type: INTERSECTION_OVER_UNION
                                      algorithm: DEFAULT
                                    }
                               )pb"))));
  EXPECT_THAT(
      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
}

TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.set_score_threshold(0.5);

  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  EXPECT_THAT(
      options_out,
      Approximately(Partially(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: 0.5
                 num_classes: 90
                 num_coords: 4
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb"))));
  EXPECT_THAT(
      options_out.detection_label_ids_to_text_options().label_items_size(), 90);
}

TEST_F(ConfigureTest, SucceedsWithAllowlist) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.add_category_allowlist("bicycle");
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  // Clear the label-ids-to-text options and compare the rest of the options.
  options_out.clear_detection_label_ids_to_text_options();
  EXPECT_THAT(
      options_out,
      Approximately(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 allow_classes: 1
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb")));
}

TEST_F(ConfigureTest, SucceedsWithDenylist) {
  MP_ASSERT_OK_AND_ASSIGN(auto model_resources,
                          CreateModelResourcesForModel(kMobileSsdWithMetadata));
  proto::DetectorOptions options_in;
  options_in.add_category_denylist("person");
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  // Clear the label-ids-to-text options and compare the rest of the options.
  options_out.clear_detection_label_ids_to_text_options();
  EXPECT_THAT(
      options_out,
      Approximately(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 ignore_classes: 0
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
          )pb")));
}

TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
  MP_ASSERT_OK_AND_ASSIGN(
      auto model_resources,
      CreateModelResourcesForModel(kMobileSsdWithDummyScoreCalibration));
  proto::DetectorOptions options_in;
  proto::DetectionPostprocessingGraphOptions options_out;
  MP_ASSERT_OK(ConfigureDetectionPostprocessingGraph(*model_resources,
                                                     options_in, options_out));
  // Clear the label-ids-to-text options.
  options_out.clear_detection_label_ids_to_text_options();
  // Check sigmoids size and first element.
  ASSERT_EQ(options_out.score_calibration_options().sigmoids_size(), 89);
  EXPECT_THAT(options_out.score_calibration_options().sigmoids()[0],
              EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb"));
  options_out.mutable_score_calibration_options()->clear_sigmoids();
  // Compare the rest of the options.
  EXPECT_THAT(
      options_out,
      Approximately(EqualsProto(
          R"pb(tensors_to_detections_options {
                 min_score_thresh: -3.4028235e+38
                 num_classes: 90
                 num_coords: 4
                 tensor_mapping {
                   detections_tensor_index: 0
                   classes_tensor_index: 1
                   scores_tensor_index: 2
                   num_detections_tensor_index: 3
                 }
                 box_boundaries_indices { ymin: 0 xmin: 1 ymax: 2 xmax: 3 }
               }
               score_calibration_options {
                 score_transformation: IDENTITY
                 default_score: 0.5
               }
          )pb")));
}

class PostprocessingTest : public tflite::testing::Test {
 protected:
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      absl::string_view model_name, const proto::DetectorOptions& options) {
    ASSIGN_OR_RETURN(auto model_resources,
                     CreateModelResourcesForModel(model_name));

    Graph graph;
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "DetectionPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
        *model_resources, options,
|
||||
postprocessing
|
||||
.GetOptions<proto::DetectionPostprocessingGraphOptions>()));
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(
|
||||
std::string(kTensorsName)) >>
|
||||
postprocessing.In(kTensorsTag);
|
||||
postprocessing.Out(kDetectionsTag).SetName(std::string(kDetectionsName)) >>
|
||||
graph[Output<std::vector<Detection>>(kDetectionsTag)];
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
||||
std::string(kDetectionsName)));
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
||||
return poller;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AddTensor(const std::vector<T>& tensor,
|
||||
const Tensor::ElementType& element_type,
|
||||
const Tensor::Shape& shape) {
|
||||
tensors_->emplace_back(element_type, shape);
|
||||
auto view = tensors_->back().GetCpuWriteView();
|
||||
T* buffer = view.buffer<T>();
|
||||
std::copy(tensor.begin(), tensor.end(), buffer);
|
||||
}
|
||||
|
||||
absl::Status Run(int timestamp = 0) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
||||
std::string(kTensorsName),
|
||||
Adopt(tensors_.release()).At(Timestamp(timestamp))));
|
||||
// Reset tensors for future calls.
|
||||
tensors_ = absl::make_unique<std::vector<Tensor>>();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
||||
|
||||
Packet packet;
|
||||
if (!poller.Next(&packet)) {
|
||||
return absl::InternalError("Unable to get output packet");
|
||||
}
|
||||
auto result = packet.Get<T>();
|
||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
CalculatorGraph calculator_graph_;
|
||||
std::unique_ptr<std::vector<Tensor>> tensors_ =
|
||||
absl::make_unique<std::vector<Tensor>>();
|
||||
};
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||
// Build graph.
|
||||
proto::DetectorOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto poller,
|
||||
BuildGraph(kMobileSsdWithMetadata, options));
|
||||
|
||||
// Build input tensors.
|
||||
constexpr int kBboxesNum = 5;
|
||||
// Location tensor.
|
||||
std::vector<float> location_tensor(kBboxesNum * 4, 0);
|
||||
for (int i = 0; i < kBboxesNum; ++i) {
|
||||
location_tensor[i * 4] = 0.1f;
|
||||
location_tensor[i * 4 + 1] = 0.1f;
|
||||
location_tensor[i * 4 + 2] = 0.4f;
|
||||
location_tensor[i * 4 + 3] = 0.5f;
|
||||
}
|
||||
// Category tensor.
|
||||
std::vector<float> category_tensor(kBboxesNum, 0);
|
||||
for (int i = 0; i < kBboxesNum; ++i) {
|
||||
category_tensor[i] = i + 1;
|
||||
}
|
||||
|
||||
// Score tensor. Post processed tensor scores are in descending order.
|
||||
std::vector<float> score_tensor(kBboxesNum, 0);
|
||||
for (int i = 0; i < kBboxesNum; ++i) {
|
||||
score_tensor[i] = static_cast<float>(kBboxesNum - i) / kBboxesNum;
|
||||
}
|
||||
|
||||
// Number of detections tensor.
|
||||
std::vector<float> num_detections_tensor(1, 0);
|
||||
num_detections_tensor[0] = kBboxesNum;
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(location_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 4});
|
||||
AddTensor(category_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum});
|
||||
AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum});
|
||||
AddTensor(num_detections_tensor, Tensor::ElementType::kFloat32, {1});
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
|
||||
IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
|
||||
R"pb(
|
||||
label: "bicycle"
|
||||
score: 1
|
||||
location_data {
|
||||
format: RELATIVE_BOUNDING_BOX
|
||||
relative_bounding_box {
|
||||
xmin: 0.1
|
||||
ymin: 0.1
|
||||
width: 0.4
|
||||
height: 0.3
|
||||
}
|
||||
}
|
||||
)pb")),
|
||||
Approximately(EqualsProto(
|
||||
R"pb(
|
||||
label: "car"
|
||||
score: 0.8
|
||||
location_data {
|
||||
format: RELATIVE_BOUNDING_BOX
|
||||
relative_bounding_box {
|
||||
xmin: 0.1
|
||||
ymin: 0.1
|
||||
width: 0.4
|
||||
height: 0.3
|
||||
}
|
||||
}
|
||||
)pb")),
|
||||
Approximately(EqualsProto(
|
||||
R"pb(
|
||||
label: "motorcycle"
|
||||
score: 0.6
|
||||
location_data {
|
||||
format: RELATIVE_BOUNDING_BOX
|
||||
relative_bounding_box {
|
||||
xmin: 0.1
|
||||
ymin: 0.1
|
||||
width: 0.4
|
||||
height: 0.3
|
||||
}
|
||||
}
|
||||
)pb")))));
|
||||
}
|
||||
|
||||
TEST_F(PostprocessingTest, SucceedsWithOutModelNms) {
|
||||
// Build graph.
|
||||
proto::DetectorOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto poller,
|
||||
BuildGraph(kEfficientDetWithoutNms, options));
|
||||
|
||||
// Build input tensors.
|
||||
constexpr int kBboxesNum = 19206;
|
||||
constexpr int kBicycleBboxIdx = 1000;
|
||||
constexpr int kCarBboxIdx = 2000;
|
||||
constexpr int kMotoCycleBboxIdx = 4000;
|
||||
// Location tensor.
|
||||
std::vector<float> location_tensor(kBboxesNum * 4, 0);
|
||||
for (int i = 0; i < kBboxesNum; ++i) {
|
||||
location_tensor[i * 4] = 0.5f;
|
||||
location_tensor[i * 4 + 1] = 0.5f;
|
||||
location_tensor[i * 4 + 2] = 0.001f;
|
||||
location_tensor[i * 4 + 3] = 0.001f;
|
||||
}
|
||||
|
||||
// Detected three objects.
|
||||
location_tensor[kBicycleBboxIdx * 4] = 0.7f;
|
||||
location_tensor[kBicycleBboxIdx * 4 + 1] = 0.8f;
|
||||
location_tensor[kBicycleBboxIdx * 4 + 2] = 0.2f;
|
||||
location_tensor[kBicycleBboxIdx * 4 + 3] = 0.1f;
|
||||
|
||||
location_tensor[kCarBboxIdx * 4] = 0.1f;
|
||||
location_tensor[kCarBboxIdx * 4 + 1] = 0.1f;
|
||||
location_tensor[kCarBboxIdx * 4 + 2] = 0.1f;
|
||||
location_tensor[kCarBboxIdx * 4 + 3] = 0.1f;
|
||||
|
||||
location_tensor[kMotoCycleBboxIdx * 4] = 0.2f;
|
||||
location_tensor[kMotoCycleBboxIdx * 4 + 1] = 0.8f;
|
||||
location_tensor[kMotoCycleBboxIdx * 4 + 2] = 0.1f;
|
||||
location_tensor[kMotoCycleBboxIdx * 4 + 3] = 0.2f;
|
||||
|
||||
// Score tensor.
|
||||
constexpr int kClassesNum = 90;
|
||||
std::vector<float> score_tensor(kBboxesNum * kClassesNum, 1.f / kClassesNum);
|
||||
|
||||
// Detected three objects.
|
||||
score_tensor[kBicycleBboxIdx * kClassesNum + 1] = 1.0f; // bicycle.
|
||||
score_tensor[kCarBboxIdx * kClassesNum + 2] = 0.9f; // car.
|
||||
score_tensor[kMotoCycleBboxIdx * kClassesNum + 3] = 0.8f; // motorcycle.
|
||||
|
||||
// Send tensors and get results.
|
||||
AddTensor(score_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 90});
|
||||
AddTensor(location_tensor, Tensor::ElementType::kFloat32, {1, kBboxesNum, 4});
|
||||
MP_ASSERT_OK(Run());
|
||||
|
||||
// Validate results.
|
||||
EXPECT_THAT(GetResult<std::vector<Detection>>(poller),
|
||||
IsOkAndHolds(ElementsAre(Approximately(EqualsProto(
|
||||
R"pb(
|
||||
label: "bicycle"
|
||||
score: 1
|
||||
location_data {
|
||||
format: RELATIVE_BOUNDING_BOX
|
||||
relative_bounding_box {
|
||||
xmin: 0.8137423
|
||||
ymin: 0.067235775
|
||||
width: 0.117221
|
||||
height: 0.064774655
|
||||
}
|
||||
}
|
||||
)pb")),
|
||||
Approximately(EqualsProto(
|
||||
R"pb(
|
||||
label: "car"
|
||||
score: 0.9
|
||||
location_data {
|
||||
format: RELATIVE_BOUNDING_BOX
|
||||
relative_bounding_box {
|
||||
xmin: 0.53849804
|
||||
ymin: 0.08949606
|
||||
width: 0.05861056
|
||||
height: 0.11722109
|
||||
}
|
||||
}
|
||||
)pb")),
|
||||
Approximately(EqualsProto(
|
||||
R"pb(
|
||||
label: "motorcycle"
|
||||
score: 0.8
|
||||
location_data {
|
||||
format: RELATIVE_BOUNDING_BOX
|
||||
relative_bounding_box {
|
||||
xmin: 0.13779688
|
||||
ymin: 0.26394117
|
||||
width: 0.16322193
|
||||
height: 0.07384467
|
||||
}
|
||||
}
|
||||
)pb")))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@@ -23,6 +23,11 @@ mediapipe_proto_library(
    srcs = ["classifier_options.proto"],
)

mediapipe_proto_library(
    name = "detector_options_proto",
    srcs = ["detector_options.proto"],
)

mediapipe_proto_library(
    name = "classification_postprocessing_graph_options_proto",
    srcs = ["classification_postprocessing_graph_options.proto"],
@@ -35,6 +40,20 @@ mediapipe_proto_library(
    ],
)

mediapipe_proto_library(
    name = "detection_postprocessing_graph_options_proto",
    srcs = ["detection_postprocessing_graph_options.proto"],
    deps = [
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_proto",
        "//mediapipe/calculators/tflite:ssd_anchors_calculator_proto",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_proto",
        "//mediapipe/calculators/util:non_max_suppression_calculator_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
    ],
)

mediapipe_proto_library(
    name = "embedder_options_proto",
    srcs = ["embedder_options.proto"],
@@ -0,0 +1,49 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package mediapipe.tasks.components.processors.proto;

import "mediapipe/calculators/tensor/tensors_to_detections_calculator.proto";
import "mediapipe/calculators/tflite/ssd_anchors_calculator.proto";
import "mediapipe/calculators/util/detection_label_id_to_text_calculator.proto";
import "mediapipe/calculators/util/non_max_suppression_calculator.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";

message DetectionPostprocessingGraphOptions {
  // Optional SsdAnchorsCalculatorOptions for models without
  // non-maximum-suppression in tflite model graph.
  optional mediapipe.SsdAnchorsCalculatorOptions ssd_anchors_options = 1;

  // Optional TensorsToDetectionsCalculatorOptions for models without
  // non-maximum-suppression in tflite model graph.
  optional mediapipe.TensorsToDetectionsCalculatorOptions
      tensors_to_detections_options = 2;

  // Optional NonMaxSuppressionCalculatorOptions for models without
  // non-maximum-suppression in tflite model graph.
  optional mediapipe.NonMaxSuppressionCalculatorOptions
      non_max_suppression_options = 3;

  // Optional score calibration options for models with non-maximum-suppression
  // in tflite model graph.
  optional ScoreCalibrationCalculatorOptions score_calibration_options = 4;

  // Optional detection label id to text calculator options.
  optional mediapipe.DetectionLabelIdToTextCalculatorOptions
      detection_label_ids_to_text_options = 5;
}
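The snippet below is a minimal, hand-written sketch (not part of this commit) of how a caller could populate these options directly in C++ for a model whose TFLite graph already performs non-maximum suppression. The numeric values are placeholders; in practice ConfigureDetectionPostprocessingGraph derives this message from the model metadata instead.

    #include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"

    namespace proto = ::mediapipe::tasks::components::processors::proto;

    proto::DetectionPostprocessingGraphOptions options;
    // Convert the model's detection output tensors into Detection protos.
    auto* t2d = options.mutable_tensors_to_detections_options();
    t2d->set_num_classes(90);        // placeholder class count
    t2d->set_num_coords(4);          // one bounding box per detection
    t2d->set_min_score_thresh(0.5);  // placeholder score threshold
    // Enable mapping of label ids to text, filled from the metadata label map.
    options.mutable_detection_label_ids_to_text_options();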
@@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.components.processors.proto;

option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "DetectorOptionsProto";

// Shared options used by all detection tasks.
message DetectorOptions {
  // The locale to use for display names specified through the TFLite Model
  // Metadata, if any. Defaults to English.
  optional string display_names_locale = 1 [default = "en"];

  // The maximum number of top-scored detection results to return. If < 0,
  // all available results will be returned. If 0, an invalid argument error is
  // returned.
  optional int32 max_results = 2 [default = -1];

  // Score threshold, overrides the ones provided in the model metadata
  // (if any). Results below this value are rejected.
  optional float score_threshold = 3;

  // Overlapping threshold for the non-maximum-suppression calculator. Only
  // used for models without built-in non-maximum-suppression, i.e., models
  // that don't use the Detection_Postprocess TFLite Op.
  optional float min_suppression_threshold = 6;

  // Optional allowlist of category names. If non-empty, detections whose
  // category name is not in this set will be filtered out. Duplicate or unknown
  // category names are ignored. Mutually exclusive with category_denylist.
  repeated string category_allowlist = 4;

  // Optional denylist of category names. If non-empty, detections whose
  // category name is in this set will be filtered out. Duplicate or unknown
  // category names are ignored. Mutually exclusive with category_allowlist.
  repeated string category_denylist = 5;
}
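As a rough usage sketch (mirroring the tests in this commit, with placeholder values): a task graph fills a DetectorOptions proto from its user-facing options and hands it, together with the model resources, to ConfigureDetectionPostprocessingGraph, which derives the calculator options from the model metadata. The `model_resources` object below is assumed to come from the usual core::ModelResources setup.

    proto::DetectorOptions detector_options;
    detector_options.set_max_results(5);                 // keep the 5 best detections
    detector_options.set_score_threshold(0.3);           // overrides the metadata threshold
    detector_options.add_category_allowlist("bicycle");  // optional category filter

    proto::DetectionPostprocessingGraphOptions graph_options;
    MP_RETURN_IF_ERROR(ConfigureDetectionPostprocessingGraph(
        *model_resources, detector_options, graph_options));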
@@ -54,12 +54,7 @@ cc_library(
    name = "object_detector_graph",
    srcs = ["object_detector_graph.cc"],
    deps = [
        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator",
        "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
        "//mediapipe/calculators/util:detection_projection_calculator",
        "//mediapipe/calculators/util:detection_transformation_calculator",
        "//mediapipe/calculators/util:detections_deduplicate_calculator",
@@ -71,19 +66,15 @@ cc_library(
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
        "//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
        "//mediapipe/tasks/cc/components/processors:detection_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:detection_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/processors/proto:detector_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:label_map_util",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
@@ -99,7 +99,20 @@ struct ObjectDetectorOptions {
//  - only RGB inputs are supported (`channels` is required to be 3).
//  - if type is kTfLiteFloat32, NormalizationOptions are required to be
//    attached to the metadata for input normalization.
// Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e:
// Output tensors could be 2 output tensors or 4 output tensors.
// The 2 output tensors must represent locations and scores, respectively.
// (kTfLiteFloat32)
//  - locations tensor of size `[num_results x num_coords]`. The num_coords is
//    the number of coordinates a location result represents. Usually in the
//    form: [4 + 2 * keypoint_num], where 4 location values encode the bounding
//    box (y_center, x_center, height, width) and the additional keypoints are in
//    (y, x) order.
// (kTfLiteFloat32)
//  - scores tensor of size `[num_results x num_classes]`. The values of a
//    result represent the classification probability belonging to the class at
//    the index, which is denoted in the label file of corresponding tensor
//    metadata in the model file.
// The 4 output tensors must come from `DetectionPostProcess` op, i.e:
// (kTfLiteFloat32)
//  - locations tensor of size `[num_results x 4]`, the inner array
//    representing bounding boxes in the form [top, left, right, bottom].
@@ -13,16 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <optional>
#include <type_traits>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
@@ -31,19 +25,15 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/processors/detection_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/detector_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/label_map_util.h"

namespace mediapipe {
namespace tasks {
@@ -56,42 +46,18 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::BoundingBoxProperties;
using ::tflite::ContentProperties;
using ::tflite::ContentProperties_BoundingBoxProperties;
using ::tflite::EnumNameContentProperties;
using ::tflite::ProcessUnit;
using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>;
using ObjectDetectorOptionsProto =
    object_detector::proto::ObjectDetectorOptions;
using TensorsSource =
    mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;

constexpr int kDefaultLocationsIndex = 0;
constexpr int kDefaultCategoriesIndex = 1;
constexpr int kDefaultScoresIndex = 2;
constexpr int kDefaultNumResultsIndex = 3;

constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();

constexpr char kLocationTensorName[] = "location";
constexpr char kCategoryTensorName[] = "category";
constexpr char kScoreTensorName[] = "score";
constexpr char kNumberOfDetectionsTensorName[] = "number of detections";

constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kImageTag[] = "IMAGE";
constexpr char kIndicesTag[] = "INDICES";
constexpr char kMatrixTag[] = "MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorTag[] = "TENSORS";

// Struct holding the different output streams produced by the object detection
@@ -101,34 +67,6 @@ struct ObjectDetectionOutputStreams {
  Source<Image> image;
};

// Parameters used for configuring the post-processing calculators.
struct PostProcessingSpecs {
  // The maximum number of detection results to return.
  int max_results;
  // Indices of the output tensors to match the output tensors to the correct
  // index order of the output tensors: [location, categories, scores,
  // num_detections].
  std::vector<int> output_tensor_indices;
  // For each pack of 4 coordinates returned by the model, this denotes the
  // order in which to get the left, top, right and bottom coordinates.
  std::vector<unsigned int> bounding_box_corners_order;
  // This is populated by reading the label files from the TFLite Model
  // Metadata: if no such files are available, this is left empty and the
  // ObjectDetector will only be able to populate the `index` field of the
  // detection results.
  LabelItems label_items;
  // Score threshold. Detections with a confidence below this value are
  // discarded. If none is provided via metadata or options, -FLT_MAX is set as
  // default value.
  float score_threshold;
  // Set of category indices to be allowed/denied.
  absl::flat_hash_set<int> allow_or_deny_categories;
  // Indicates `allow_or_deny_categories` is an allowlist or a denylist.
  bool is_allowlist;
  // Score calibration options, if any.
  std::optional<ScoreCalibrationCalculatorOptions> score_calibration_options;
};

absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) {
  if (options.max_results() == 0) {
    return CreateStatusWithPayload(
@@ -147,310 +85,6 @@ absl::Status SanityCheckOptions(const ObjectDetectorOptionsProto& options) {
  return absl::OkStatus();
}

absl::StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
    const TensorMetadata& tensor_metadata) {
  if (tensor_metadata.content() == nullptr ||
      tensor_metadata.content()->content_properties() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found none.",
            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  ContentProperties type = tensor_metadata.content()->content_properties_type();
  if (type != ContentProperties_BoundingBoxProperties) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties for tensor %s, found %s.",
            tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
            EnumNameContentProperties(type)),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  const BoundingBoxProperties* properties =
      tensor_metadata.content()->content_properties_as_BoundingBoxProperties();

  // Mobile SSD only supports "BOUNDARIES" bounding box type.
  if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
            tflite::EnumNameBoundingBoxType(properties->type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  // Mobile SSD only supports "RATIO" coordinates type.
  if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Mobile SSD only supports CoordinateType RATIO, found %s",
            tflite::EnumNameCoordinateType(properties->coordinate_type())),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  // Index is optional, but must contain 4 values if present.
  if (properties->index() != nullptr && properties->index()->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Expected BoundingBoxProperties index to contain 4 values, found "
            "%d",
            properties->index()->size()),
        MediaPipeTasksStatus::kMetadataInvalidContentPropertiesError);
  }

  return properties;
}

absl::StatusOr<LabelItems> GetLabelItemsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata, absl::string_view locale) {
  const std::string labels_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
  if (labels_filename.empty()) {
    LabelItems empty_label_items;
    return empty_label_items;
  }
  ASSIGN_OR_RETURN(absl::string_view labels_file,
                   metadata_extractor.GetAssociatedFile(labels_filename));
  const std::string display_names_filename =
      ModelMetadataExtractor::FindFirstAssociatedFileName(
          tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS,
          locale);
  absl::string_view display_names_file;
  if (!display_names_filename.empty()) {
    ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
                                             display_names_filename));
  }
  return mediapipe::BuildLabelMapFromFiles(labels_file, display_names_file);
}

absl::StatusOr<float> GetScoreThreshold(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(
      const ProcessUnit* score_thresholding_process_unit,
      metadata_extractor.FindFirstProcessUnit(
          tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
  if (score_thresholding_process_unit == nullptr) {
    return kDefaultScoreThreshold;
  }
  return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
      ->global_score_threshold();
}

absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
    const ObjectDetectorOptionsProto& config, const LabelItems& label_items) {
  absl::flat_hash_set<int> category_indices;
  // Exit early if no denylist/allowlist.
  if (config.category_denylist_size() == 0 &&
      config.category_allowlist_size() == 0) {
    return category_indices;
  }
  if (label_items.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Using `category_allowlist` or `category_denylist` requires "
        "labels to be present in the TFLite Model Metadata but none was found.",
        MediaPipeTasksStatus::kMetadataMissingLabelsError);
  }
  const auto& category_list = config.category_allowlist_size() > 0
                                  ? config.category_allowlist()
                                  : config.category_denylist();
  for (const auto& category_name : category_list) {
    int index = -1;
    for (int i = 0; i < label_items.size(); ++i) {
      if (label_items.at(i).name() == category_name) {
        index = i;
        break;
      }
    }
    // Ignores duplicate or unknown categories.
    if (index < 0) {
      continue;
    }
    category_indices.insert(index);
  }
  return category_indices;
}

absl::StatusOr<std::optional<ScoreCalibrationCalculatorOptions>>
GetScoreCalibrationOptionsIfAny(
    const ModelMetadataExtractor& metadata_extractor,
    const TensorMetadata& tensor_metadata) {
  // Get ScoreCalibrationOptions, if any.
  ASSIGN_OR_RETURN(
      const ProcessUnit* score_calibration_process_unit,
      metadata_extractor.FindFirstProcessUnit(
          tensor_metadata, tflite::ProcessUnitOptions_ScoreCalibrationOptions));
  if (score_calibration_process_unit == nullptr) {
    return std::nullopt;
  }
  auto* score_calibration_options =
      score_calibration_process_unit->options_as_ScoreCalibrationOptions();
  // Get corresponding AssociatedFile.
  auto score_calibration_filename =
      metadata_extractor.FindFirstAssociatedFileName(
          tensor_metadata,
          tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
  if (score_calibration_filename.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        "Found ScoreCalibrationOptions but missing required associated "
        "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
        MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError);
  }
  ASSIGN_OR_RETURN(
      absl::string_view score_calibration_file,
      metadata_extractor.GetAssociatedFile(score_calibration_filename));
  ScoreCalibrationCalculatorOptions score_calibration_calculator_options;
  MP_RETURN_IF_ERROR(ConfigureScoreCalibration(
      score_calibration_options->score_transformation(),
      score_calibration_options->default_score(), score_calibration_file,
      &score_calibration_calculator_options));
  return score_calibration_calculator_options;
}

std::vector<int> GetOutputTensorIndices(
    const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
        tensor_metadatas) {
  std::vector<int> output_indices = {
      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                          kLocationTensorName),
      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                          kCategoryTensorName),
      core::FindTensorIndexByMetadataName(tensor_metadatas, kScoreTensorName),
      core::FindTensorIndexByMetadataName(tensor_metadatas,
                                          kNumberOfDetectionsTensorName)};
  // locations, categories, scores, and number of detections
  for (int i = 0; i < 4; i++) {
    int output_index = output_indices[i];
    // If tensor name is not found, set the default output indices.
    if (output_index == -1) {
      LOG(WARNING) << absl::StrFormat(
          "You don't seem to be matching tensor names in metadata list. The "
          "tensor name \"%s\" at index %d in the model metadata doesn't "
          "match "
          "the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].",
          tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
          kCategoryTensorName, kScoreTensorName, kNumberOfDetectionsTensorName);
      output_indices = {kDefaultLocationsIndex, kDefaultCategoriesIndex,
                        kDefaultScoresIndex, kDefaultNumResultsIndex};
      return output_indices;
    }
  }
  return output_indices;
}

// Builds PostProcessingSpecs from ObjectDetectorOptionsProto and model metadata
// for configuring the post-processing calculators.
absl::StatusOr<PostProcessingSpecs> BuildPostProcessingSpecs(
    const ObjectDetectorOptionsProto& options,
    const ModelMetadataExtractor* metadata_extractor) {
  // Checks output tensor metadata is present and consistent with model.
  auto* output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
  if (output_tensors_metadata == nullptr ||
      output_tensors_metadata->size() != 4) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Mismatch between number of output tensors (4) and "
                        "output tensors metadata (%d).",
                        output_tensors_metadata == nullptr
                            ? 0
                            : output_tensors_metadata->size()),
        MediaPipeTasksStatus::kMetadataInconsistencyError);
  }
  PostProcessingSpecs specs;
  specs.max_results = options.max_results();
  specs.output_tensor_indices = GetOutputTensorIndices(output_tensors_metadata);
  // Extracts mandatory BoundingBoxProperties and performs sanity checks on the
  // fly.
  ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
                   GetBoundingBoxProperties(*output_tensors_metadata->Get(
                       specs.output_tensor_indices[0])));
  if (bounding_box_properties->index() == nullptr) {
    specs.bounding_box_corners_order = {0, 1, 2, 3};
  } else {
    auto bounding_box_index = bounding_box_properties->index();
    specs.bounding_box_corners_order = {
        bounding_box_index->Get(0),
        bounding_box_index->Get(1),
        bounding_box_index->Get(2),
        bounding_box_index->Get(3),
    };
  }
  // Builds label map (if available) from metadata.
  ASSIGN_OR_RETURN(specs.label_items,
                   GetLabelItemsIfAny(*metadata_extractor,
                                      *output_tensors_metadata->Get(
                                          specs.output_tensor_indices[1]),
                                      options.display_names_locale()));
  // Obtains allow/deny categories.
  specs.is_allowlist = !options.category_allowlist().empty();
  ASSIGN_OR_RETURN(
      specs.allow_or_deny_categories,
      GetAllowOrDenyCategoryIndicesIfAny(options, specs.label_items));
  // Sets score threshold.
  if (options.has_score_threshold()) {
    specs.score_threshold = options.score_threshold();
  } else {
    ASSIGN_OR_RETURN(specs.score_threshold,
                     GetScoreThreshold(*metadata_extractor,
                                       *output_tensors_metadata->Get(
                                           specs.output_tensor_indices[2])));
  }
  // Builds score calibration options (if available) from metadata.
  ASSIGN_OR_RETURN(
      specs.score_calibration_options,
      GetScoreCalibrationOptionsIfAny(
          *metadata_extractor,
          *output_tensors_metadata->Get(specs.output_tensor_indices[2])));
  return specs;
}

// Fills in the TensorsToDetectionsCalculatorOptions based on
// PostProcessingSpecs.
void ConfigureTensorsToDetectionsCalculator(
    const PostProcessingSpecs& specs,
    mediapipe::TensorsToDetectionsCalculatorOptions* options) {
  options->set_num_classes(specs.label_items.size());
  options->set_num_coords(4);
  options->set_min_score_thresh(specs.score_threshold);
  if (specs.max_results != -1) {
    options->set_max_results(specs.max_results);
  }
  if (specs.is_allowlist) {
    options->mutable_allow_classes()->Assign(
        specs.allow_or_deny_categories.begin(),
        specs.allow_or_deny_categories.end());
  } else {
    options->mutable_ignore_classes()->Assign(
        specs.allow_or_deny_categories.begin(),
        specs.allow_or_deny_categories.end());
  }

  const auto& output_indices = specs.output_tensor_indices;
  // Assigns indices to each the model output tensor.
  auto* tensor_mapping = options->mutable_tensor_mapping();
  tensor_mapping->set_detections_tensor_index(output_indices[0]);
  tensor_mapping->set_classes_tensor_index(output_indices[1]);
  tensor_mapping->set_scores_tensor_index(output_indices[2]);
  tensor_mapping->set_num_detections_tensor_index(output_indices[3]);

  // Assigns the bounding box corner order.
  auto box_boundaries_indices = options->mutable_box_boundaries_indices();
  box_boundaries_indices->set_xmin(specs.bounding_box_corners_order[0]);
  box_boundaries_indices->set_ymin(specs.bounding_box_corners_order[1]);
  box_boundaries_indices->set_xmax(specs.bounding_box_corners_order[2]);
  box_boundaries_indices->set_ymax(specs.bounding_box_corners_order[3]);
}

}  // namespace

// A "mediapipe.tasks.vision.ObjectDetectorGraph" performs object detection.
@@ -530,7 +164,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
      const core::ModelResources& model_resources, Source<Image> image_in,
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
    MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
    // Checks that the model has 4 outputs.
    auto& model = *model_resources.GetTfLiteModel();
    if (model.subgraphs()->size() != 1) {
      return CreateStatusWithPayload(
@@ -539,13 +172,6 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
                          model.subgraphs()->size()),
          MediaPipeTasksStatus::kInvalidArgumentError);
    }
    if (model.subgraphs()->Get(0)->outputs()->size() != 4) {
      return CreateStatusWithPayload(
          absl::StatusCode::kInvalidArgument,
          absl::StrFormat("Expected a model with 4 output tensors, found %d.",
                          model.subgraphs()->Get(0)->outputs()->size()),
          MediaPipeTasksStatus::kInvalidArgumentError);
    }
    // Checks that metadata is available.
    auto* metadata_extractor = model_resources.GetMetadataExtractor();
    if (metadata_extractor->GetModelMetadata() == nullptr ||
@@ -577,70 +203,36 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
    auto& inference = AddInference(
        model_resources, task_options.base_options().acceleration(), graph);
    preprocessing.Out(kTensorTag) >> inference.In(kTensorTag);

    // Adds post processing calculators.
    ASSIGN_OR_RETURN(
        auto post_processing_specs,
        BuildPostProcessingSpecs(task_options, metadata_extractor));
    // Calculators to perform score calibration, if specified in the metadata.
    TensorsSource calibrated_tensors =
    TensorsSource model_output_tensors =
        inference.Out(kTensorTag).Cast<std::vector<Tensor>>();
    if (post_processing_specs.score_calibration_options.has_value()) {
      // Split tensors.
      auto* split_tensor_vector_node =
          &graph.AddNode("SplitTensorVectorCalculator");
      auto& split_tensor_vector_options =
          split_tensor_vector_node
              ->GetOptions<mediapipe::SplitVectorCalculatorOptions>();
      for (int i = 0; i < 4; ++i) {
        auto* range = split_tensor_vector_options.add_ranges();
        range->set_begin(i);
        range->set_end(i + 1);
      }
      calibrated_tensors >> split_tensor_vector_node->In(0);

      // Add score calibration calculator.
      auto* score_calibration_node =
          &graph.AddNode("ScoreCalibrationCalculator");
      score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
          .CopyFrom(*post_processing_specs.score_calibration_options);
      split_tensor_vector_node->Out(
          post_processing_specs.output_tensor_indices[1]) >>
          score_calibration_node->In(kIndicesTag);
      split_tensor_vector_node->Out(
          post_processing_specs.output_tensor_indices[2]) >>
          score_calibration_node->In(kScoresTag);

      // Re-concatenate tensors.
      auto* concatenate_tensor_vector_node =
          &graph.AddNode("ConcatenateTensorVectorCalculator");
      for (int i = 0; i < 4; ++i) {
        if (i == post_processing_specs.output_tensor_indices[2]) {
          score_calibration_node->Out(kCalibratedScoresTag) >>
              concatenate_tensor_vector_node->In(i);
        } else {
          split_tensor_vector_node->Out(i) >>
              concatenate_tensor_vector_node->In(i);
        }
      }
      calibrated_tensors =
          concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
    }
    // Calculator to convert output tensors to a detection proto vector.
    // Connects TensorsToDetectionsCalculator's input stream to the output
    // tensors produced by the inference subgraph.
    auto& tensors_to_detections =
        graph.AddNode("TensorsToDetectionsCalculator");
    ConfigureTensorsToDetectionsCalculator(
        post_processing_specs,
        &tensors_to_detections
             .GetOptions<mediapipe::TensorsToDetectionsCalculatorOptions>());
    calibrated_tensors >> tensors_to_detections.In(kTensorTag);
    // Add Detection postprocessing graph to convert tensors to detections.
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors.DetectionPostprocessingGraph");
    components::processors::proto::DetectorOptions detector_options;
    detector_options.set_max_results(task_options.max_results());
    detector_options.set_score_threshold(task_options.score_threshold());
    detector_options.set_display_names_locale(
        task_options.display_names_locale());
    detector_options.mutable_category_allowlist()->CopyFrom(
        task_options.category_allowlist());
    detector_options.mutable_category_denylist()->CopyFrom(
        task_options.category_denylist());
    // TODO: expose min suppression threshold in
    // ObjectDetectorOptions.
    detector_options.set_min_suppression_threshold(0.3);
    MP_RETURN_IF_ERROR(
        components::processors::ConfigureDetectionPostprocessingGraph(
            model_resources, detector_options,
            postprocessing
                .GetOptions<components::processors::proto::
                                DetectionPostprocessingGraphOptions>()));
    model_output_tensors >> postprocessing.In(kTensorTag);
    auto detections = postprocessing.Out(kDetectionsTag);

    // Calculator to projects detections back to the original coordinate system.
    auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
    tensors_to_detections.Out(kDetectionsTag) >>
        detection_projection.In(kDetectionsTag);
    detections >> detection_projection.In(kDetectionsTag);
    preprocessing.Out(kMatrixTag) >>
        detection_projection.In(kProjectionMatrixTag);

@@ -652,22 +244,13 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
        detection_transformation.In(kDetectionsTag);
    preprocessing.Out(kImageSizeTag) >>
        detection_transformation.In(kImageSizeTag);

    // Calculator to assign detection labels.
    auto& detection_label_id_to_text =
        graph.AddNode("DetectionLabelIdToTextCalculator");
    auto& detection_label_id_to_text_opts =
        detection_label_id_to_text
            .GetOptions<mediapipe::DetectionLabelIdToTextCalculatorOptions>();
    *detection_label_id_to_text_opts.mutable_label_items() =
        std::move(post_processing_specs.label_items);
    detection_transformation.Out(kPixelDetectionsTag) >>
        detection_label_id_to_text.In("");
    auto detections_in_pixel =
        detection_transformation.Out(kPixelDetectionsTag);

    // Deduplicate Detections with same bounding box coordinates.
    auto& detections_deduplicate =
        graph.AddNode("DetectionsDeduplicateCalculator");
    detection_label_id_to_text.Out("") >> detections_deduplicate.In("");
    detections_in_pixel >> detections_deduplicate.In("");

    // Outputs the labeled detections and the processed image as the subgraph
    // output streams.
@@ -76,15 +76,18 @@ using ::testing::HasSubstr;
using ::testing::Optional;
using DetectionProto = mediapipe::Detection;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kMobileSsdWithMetadata[] =
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
constexpr char kMobileSsdWithDummyScoreCalibration[] =
constexpr absl::string_view kTestDataDirectory{
    "/mediapipe/tasks/testdata/vision/"};
constexpr absl::string_view kMobileSsdWithMetadata{
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"};
constexpr absl::string_view kMobileSsdWithDummyScoreCalibration{
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration."
    "tflite";
    "tflite"};
// The model has different output tensor order.
constexpr char kEfficientDetWithMetadata[] =
    "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
constexpr absl::string_view kEfficientDetWithMetadata{
    "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite"};
constexpr absl::string_view kEfficientDetWithoutNms{
    "efficientdet_lite0_fp16_no_nms.tflite"};

// Checks that the two provided `Detection` proto vectors are equal, with a
// tolerancy on floating-point scores to account for numerical instabilities.
@@ -451,6 +454,51 @@ TEST_F(ImageModeTest, SucceedsEfficientDetModel) {
          })pb")}));
}

TEST_F(ImageModeTest, SucceedsEfficientDetNoNmsModel) {
  MP_ASSERT_OK_AND_ASSIGN(Image image,
                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                       "cats_and_dogs.jpg")));
  auto options = std::make_unique<ObjectDetectorOptions>();
  options->max_results = 4;
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kEfficientDetWithoutNms);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                          ObjectDetector::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto results, object_detector->Detect(image));
  MP_ASSERT_OK(object_detector->Close());
  ExpectApproximatelyEqual(
      results,
      ConvertToDetectionResult(
          {ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "dog"
             score: 0.733542
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 636 ymin: 160 width: 282 height: 451 }
             })pb"),
           ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "cat"
             score: 0.699751
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 870 ymin: 411 width: 208 height: 187 }
             })pb"),
           ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "dog"
             score: 0.682425
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 386 ymin: 216 width: 256 height: 376 }
             })pb"),
           ParseTextProtoOrDie<DetectionProto>(R"pb(
             label: "cat"
             score: 0.646585
             location_data {
               format: BOUNDING_BOX
               bounding_box { xmin: 83 ymin: 399 width: 347 height: 198 }
             })pb")}));
}

TEST_F(ImageModeTest, SucceedsWithoutImageResizing) {
  MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath(
                                           "./", kTestDataDirectory,
mediapipe/tasks/testdata/vision/BUILD (vendored, 2 changes)

@@ -41,6 +41,7 @@ mediapipe_files(srcs = [
    "conv2d_input_channel_1.tflite",
    "deeplabv3.tflite",
    "dense.tflite",
    "efficientdet_lite0_fp16_no_nms.tflite",
    "face_detection_full_range.tflite",
    "face_detection_full_range_sparse.tflite",
    "face_detection_short_range.tflite",
@@ -167,6 +168,7 @@ filegroup(
    "conv2d_input_channel_1.tflite",
    "deeplabv3.tflite",
    "dense.tflite",
    "efficientdet_lite0_fp16_no_nms.tflite",
    "face_detection_full_range.tflite",
    "face_detection_full_range_sparse.tflite",
    "face_detection_short_range.tflite",
third_party/external_files.bzl (vendored, 4 changes)

@@ -276,8 +276,8 @@ def external_files():

    http_file(
        name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_tflite",
        sha256 = "bcda125c96d3767bca894c8cbe7bc458379c9974c9fd8bdc6204e7124a74082a",
        urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682456096034465"],
        sha256 = "237a58389081333e5cf4154e42b593ce7dd357445536fcaf4ca5bc51c2c50f1c",
        urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682476299542472"],
    )

    http_file(