Add custom metadata for object detection model with out-of-graph nms.
PiperOrigin-RevId: 527083453
This commit is contained in:
parent
17f5b95387
commit
507ed0d91d
|
@ -328,7 +328,7 @@ class ObjectDetector(classifier.Classifier):
|
|||
converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
|
||||
tflite_model = converter.convert()
|
||||
|
||||
writer = object_detector_writer.MetadataWriter.create(
|
||||
writer = object_detector_writer.MetadataWriter.create_for_models_with_nms(
|
||||
tflite_model,
|
||||
self._model_spec.mean_rgb,
|
||||
self._model_spec.stddev_rgb,
|
||||
|
|
|
@ -36,3 +36,13 @@ flatbuffer_py_library(
|
|||
name = "image_segmenter_metadata_schema_py",
|
||||
srcs = ["image_segmenter_metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "object_detector_metadata_schema_cc",
|
||||
srcs = ["object_detector_metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "object_detector_metadata_schema_py",
|
||||
srcs = ["object_detector_metadata_schema.fbs"],
|
||||
)
|
||||
|
|
98
mediapipe/tasks/metadata/object_detector_metadata_schema.fbs
Normal file
98
mediapipe/tasks/metadata/object_detector_metadata_schema.fbs
Normal file
|
@ -0,0 +1,98 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace mediapipe.tasks;
|
||||
|
||||
// ObjectDetectorOptions.min_parser_version indicates the minimum necessary
|
||||
// object detector metadata parser version to fully understand all fields in a
|
||||
// given metadata flatbuffer. This min_parser_version is specific for the
|
||||
// object detector metadata defined in this schema file.
|
||||
//
|
||||
// New fields and types will have associated comments with the schema version
|
||||
// for which they were added.
|
||||
//
|
||||
// Schema Semantic version: 1.0.0
|
||||
|
||||
// This indicates the flatbuffer compatibility. The number will bump up when a
|
||||
// breaking change is applied to the schema, such as removing fields or adding new
|
||||
// fields to the middle of a table.
|
||||
file_identifier "V001";
|
||||
|
||||
|
||||
// History:
|
||||
// 1.0.0 - Initial version.
|
||||
|
||||
// A fixed size anchor.
|
||||
table FixedAnchor {
|
||||
x_center: float;
|
||||
y_center: float;
|
||||
width: float;
|
||||
height: float;
|
||||
}
|
||||
|
||||
// The schema for a list of anchors with fixed size.
|
||||
table FixedAnchorsSchema {
|
||||
anchors: [FixedAnchor];
|
||||
}
|
||||
|
||||
// The ssd anchors options used in the object detector.
|
||||
table SsdAnchorsOptions {
|
||||
fixed_anchors_schema: FixedAnchorsSchema;
|
||||
}
|
||||
|
||||
// The options for decoding the raw model output tensors. The options are mostly
|
||||
// used in TensorsToDetectionsCalculatorOptions.
|
||||
table TensorsDecodingOptions {
|
||||
// The number of output classes predicted by the detection model.
|
||||
num_classes: int;
|
||||
// The number of output boxes predicted by the detection model.
|
||||
num_boxes: int;
|
||||
// The number of output values per box predicted by the detection
|
||||
// model. The values contain bounding boxes, keypoints, etc.
|
||||
num_coords: int;
|
||||
// The offset of keypoint coordinates in the location tensor.
|
||||
keypoint_coord_offset: int;
|
||||
// The number of predicted keypoints.
|
||||
num_keypoints: int;
|
||||
// The dimension of each keypoint, e.g. number of values predicted for each
|
||||
// keypoint.
|
||||
num_values_per_keypoint: int;
|
||||
// Parameters for decoding SSD detection model.
|
||||
x_scale: float;
|
||||
y_scale: float;
|
||||
w_scale: float;
|
||||
h_scale: float;
|
||||
// Whether to apply exponential on box size.
|
||||
apply_exponential_on_box_size: bool;
|
||||
// Whether to apply the sigmoid function on the score.
|
||||
sigmoid_score: bool;
|
||||
}
|
||||
|
||||
table ObjectDetectorOptions {
|
||||
// TODO: automatically populate min parser string.
|
||||
// The minimum necessary object detector metadata parser version to fully
|
||||
// understand all fields in a given metadata flatbuffer. This field is
|
||||
// automatically populated by the MetadataPopulator when the metadata is
|
||||
// populated into a TFLite model. This min_parser_version is specific for the
|
||||
// object detector metadata defined in this schema file.
|
||||
min_parser_version:string;
|
||||
|
||||
// The options of ssd anchors configs used by the detection model.
|
||||
ssd_anchors_options:SsdAnchorsOptions;
|
||||
|
||||
// The tensors decoding options to convert raw tensors to detection results.
|
||||
tensors_decoding_options:TensorsDecodingOptions;
|
||||
}
|
||||
|
||||
root_type ObjectDetectorOptions;
|
|
@ -67,7 +67,15 @@ py_library(
|
|||
py_library(
|
||||
name = "object_detector",
|
||||
srcs = ["object_detector.py"],
|
||||
deps = [":metadata_writer"],
|
||||
data = ["//mediapipe/tasks/metadata:object_detector_metadata_schema.fbs"],
|
||||
deps = [
|
||||
":metadata_info",
|
||||
":metadata_writer",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/metadata:object_detector_metadata_schema_py",
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
import abc
|
||||
import collections
|
||||
import csv
|
||||
import enum
|
||||
import os
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
|
@ -1004,6 +1005,84 @@ class DetectionOutputTensorsMd:
|
|||
return self._output_mds
|
||||
|
||||
|
||||
class RawDetectionOutputTensorsOrder(enum.Enum):
  """Order of the output tensors of detection models without postprocessing.

  The order of the output tensors cannot be determined for models without
  postprocessing, so it has to be specified explicitly for the metadata writer.
  """

  # No order has been chosen.
  UNSPECIFIED = 0
  # Tensor 0 is the score tensor; tensor 1 is the location tensor.
  SCORE_LOCATION = 1
  # Tensor 0 is the location tensor; tensor 1 is the score tensor.
  LOCATION_SCORE = 2
|
||||
|
||||
|
||||
class RawDetectionOutputTensorsMd:
  """Holds the output tensor metadata of detection models without postprocessing."""

  _LOCATION_NAME = "location"
  _LOCATION_DESCRIPTION = "The locations of the detected boxes."
  _SCORE_NAME = "score"
  _SCORE_DESCRIPTION = "The scores of the detected boxes."
  _CONTENT_VALUE_DIM = 2

  def __init__(
      self,
      model_buffer: bytearray,
      label_files: Optional[List[LabelFileMd]] = None,
      output_tensors_order: RawDetectionOutputTensorsOrder = RawDetectionOutputTensorsOrder.UNSPECIFIED,
  ) -> None:
    """Initializes the instance of RawDetectionOutputTensorsMd.

    Args:
      model_buffer: A valid flatbuffer loaded from the TFLite model file.
      label_files: information of the label files [1] in the classification
        tensor.
      output_tensors_order: the order of the output tensors. Must be either
        SCORE_LOCATION or LOCATION_SCORE.
      [1]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L9

    Raises:
      ValueError: if `output_tensors_order` is not one of the two supported
        orders, or if the number of output tensors in the model differs from
        the number of metadata entries (score and location).
    """
    # Pair every output tensor index of the tflite model with its name.
    output_indices = writer_utils.get_output_tensor_indices(model_buffer)
    output_names = writer_utils.get_output_tensor_names(model_buffer)
    indexed_names = list(zip(output_indices, output_names))

    location_md = LocationTensorMd(
        name=self._LOCATION_NAME,
        description=self._LOCATION_DESCRIPTION,
    )
    score_md = ClassificationTensorMd(
        name=self._SCORE_NAME,
        description=self._SCORE_DESCRIPTION,
        label_files=label_files,
    )

    score_first = (
        output_tensors_order == RawDetectionOutputTensorsOrder.SCORE_LOCATION
    )
    location_first = (
        output_tensors_order == RawDetectionOutputTensorsOrder.LOCATION_SCORE
    )
    if not (score_first or location_first):
      raise ValueError(
          f"Unsupported OutputTensorsOrder value: {output_tensors_order}"
      )
    self._output_mds = (
        [score_md, location_md] if score_first else [location_md, score_md]
    )

    expected_size = len(self._output_mds)
    if expected_size != len(indexed_names):
      raise ValueError(
          f"The size of TFLite output should be {expected_size}"
      )
    # Attach the model's tensor names to the metadata entries, in order.
    for output_md, (_, tensor_name) in zip(self._output_mds, indexed_names):
      output_md.tensor_name = tensor_name

  @property
  def output_mds(self) -> List[TensorMd]:
    """The output tensor metadata, in the configured tensor order."""
    return self._output_mds
|
||||
|
||||
|
||||
class TensorGroupMd:
|
||||
"""A container for a group of tensor metadata information."""
|
||||
|
||||
|
|
|
@ -632,7 +632,7 @@ class MetadataWriter(object):
|
|||
score_calibration: Optional[ScoreCalibration] = None,
|
||||
group_name: str = _DETECTION_GROUP_NAME,
|
||||
) -> 'MetadataWriter':
|
||||
"""Adds a detection head metadata for detection output tensor.
|
||||
"""Adds a detection head metadata for detection output tensor of models with postprocessing.
|
||||
|
||||
Args:
|
||||
labels: an instance of Labels helper class.
|
||||
|
@ -661,6 +661,33 @@ class MetadataWriter(object):
|
|||
self._output_group_mds.append(group_md)
|
||||
return self
|
||||
|
||||
def add_raw_detection_output(
    self,
    labels: Optional[Labels] = None,
    output_tensors_order: metadata_info.RawDetectionOutputTensorsOrder = metadata_info.RawDetectionOutputTensorsOrder.UNSPECIFIED,
) -> 'MetadataWriter':
  """Adds a detection head metadata for detection output tensor of models without postprocessing.

  Args:
    labels: an instance of Labels helper class.
    output_tensors_order: the order of the output tensors. For models of
      out-of-graph non-maximum-suppression only.

  Returns:
    The current Writer instance to allow chained operation.

  Raises:
    ValueError: if the number of detection output tensor metadata is not 2.
  """
  label_files = self._create_label_file_md(labels)
  detection_output_mds = metadata_info.RawDetectionOutputTensorsMd(
      self._model_buffer,
      label_files=label_files,
      output_tensors_order=output_tensors_order,
  ).output_mds
  # Validate before mutating the writer, so that a failure does not leave
  # partially added output metadata behind. The two outputs are the score and
  # location tensors, ordered according to `output_tensors_order`.
  if len(detection_output_mds) != 2:
    raise ValueError('The size of detections output should be 2.')
  self._output_mds.extend(detection_output_mds)
  return self
|
||||
|
||||
def add_segmentation_output(
|
||||
self,
|
||||
labels: Optional[Labels] = None,
|
||||
|
|
|
@ -14,8 +14,14 @@
|
|||
# ==============================================================================
|
||||
"""Writes metadata and label file to the Object Detector models."""
|
||||
|
||||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
import flatbuffers
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.metadata import object_detector_metadata_schema_py_generated as _detector_metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
_MODEL_NAME = "ObjectDetector"
|
||||
|
@ -25,12 +31,187 @@ _MODEL_DESCRIPTION = (
|
|||
"stream."
|
||||
)
|
||||
|
||||
# Metadata Schema file for object detector.
|
||||
_FLATC_METADATA_SCHEMA_FILE = metadata.get_path_to_datafile(
|
||||
"../../../metadata/object_detector_metadata_schema.fbs",
|
||||
)
|
||||
|
||||
# Metadata name in custom metadata field. The metadata name is used to get
|
||||
# object detector metadata from SubGraphMetadata.custom_metadata and shouldn't
|
||||
# be changed.
|
||||
_METADATA_NAME = "DETECTOR_METADATA"
|
||||
|
||||
|
||||
@dataclasses.dataclass
class FixedAnchor:
  """A fixed size anchor.

  Attributes:
    x_center: x coordinate of the anchor center.
    y_center: y coordinate of the anchor center.
    width: width of the anchor; annotated as optional.
    height: height of the anchor; annotated as optional.
  """

  x_center: float
  y_center: float
  width: Optional[float]
  height: Optional[float]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class FixedAnchorsSchema:
  """The schema for a list of anchors with fixed size.

  Attributes:
    anchors: the fixed size anchors.
  """

  anchors: List[FixedAnchor]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class SsdAnchorsOptions:
  """The ssd anchors options used in object detector model.

  Attributes:
    fixed_anchors_schema: the schema holding the fixed size anchors, or None
      when no fixed anchors are provided.
  """

  fixed_anchors_schema: Optional[FixedAnchorsSchema]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TensorsDecodingOptions:
  """The decoding options to convert model output tensors to detections.

  Attributes:
    num_classes: The number of output classes predicted by the detection
      model.
    num_boxes: The number of output boxes predicted by the detection model.
    num_coords: The number of output values per box predicted by the
      detection model. The values contain bounding boxes, keypoints, etc.
    keypoint_coord_offset: The offset of keypoint coordinates in the location
      tensor.
    num_keypoints: The number of predicted keypoints.
    num_values_per_keypoint: The dimension of each keypoint, e.g. number of
      values predicted for each keypoint.
    x_scale: Parameter for decoding SSD detection model.
    y_scale: Parameter for decoding SSD detection model.
    w_scale: Parameter for decoding SSD detection model.
    h_scale: Parameter for decoding SSD detection model.
    apply_exponential_on_box_size: Whether to apply exponential on box size.
    sigmoid_score: Whether to apply the sigmoid function on the score.
  """

  num_classes: int
  num_boxes: int
  num_coords: int
  keypoint_coord_offset: int
  num_keypoints: int
  num_values_per_keypoint: int
  x_scale: float
  y_scale: float
  w_scale: float
  h_scale: float
  apply_exponential_on_box_size: bool
  sigmoid_score: bool
|
||||
|
||||
|
||||
# Kept as a module-level function for getting the metadata json file, so that
# it can be used as a standalone util.
def convert_to_json(metadata_buffer: bytearray) -> str:
  """Converts the metadata into a json string.

  The object detector schema file is registered under the object detector
  custom metadata name when calling `metadata.convert_to_json`.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: error occurred when parsing the metadata schema file.
  """
  custom_schemas = {_METADATA_NAME: _FLATC_METADATA_SCHEMA_FILE}
  return metadata.convert_to_json(
      metadata_buffer, custom_metadata_schema=custom_schemas
  )
|
||||
|
||||
|
||||
class ObjectDetectorOptionsMd(metadata_info.CustomMetadataMd):
  """Object detector options metadata."""

  # Same identifier as the `file_identifier` declared in
  # object_detector_metadata_schema.fbs.
  _METADATA_FILE_IDENTIFIER = b"V001"

  def __init__(
      self,
      ssd_anchors_options: SsdAnchorsOptions,
      tensors_decoding_options: TensorsDecodingOptions,
  ) -> None:
    """Creates an ObjectDetectorOptionsMd object.

    Args:
      ssd_anchors_options: the ssd anchors options associated to the object
        detector model.
      tensors_decoding_options: the tensors decoding options used to decode the
        object detector model output.

    Raises:
      ValueError: if `ssd_anchors_options.fixed_anchors_schema` is None.
    """
    if ssd_anchors_options.fixed_anchors_schema is None:
      raise ValueError(
          "Currently only support FixedAnchorsSchema, which cannot be found"
          " in ssd_anchors_options."
      )
    self.ssd_anchors_options = ssd_anchors_options
    self.tensors_decoding_options = tensors_decoding_options
    super().__init__(name=_METADATA_NAME)

  def create_metadata(self) -> _metadata_fb.CustomMetadataT:
    """Creates the object detector options metadata.

    Returns:
      A Flatbuffers Python object of the custom metadata including object
      detector options metadata.
    """
    detector_options = _detector_metadata_fb.ObjectDetectorOptionsT()

    # Convert every anchor dataclass into its flatbuffers object counterpart.
    anchor_objects = []
    for anchor in self.ssd_anchors_options.fixed_anchors_schema.anchors:
      anchor_t = _detector_metadata_fb.FixedAnchorT()
      anchor_t.xCenter = anchor.x_center
      anchor_t.yCenter = anchor.y_center
      anchor_t.width = anchor.width
      anchor_t.height = anchor.height
      anchor_objects.append(anchor_t)

    # Set ssd_anchors_options.
    anchors_schema_t = _detector_metadata_fb.FixedAnchorsSchemaT()
    anchors_schema_t.anchors = anchor_objects
    anchors_options_t = _detector_metadata_fb.SsdAnchorsOptionsT()
    anchors_options_t.fixedAnchorsSchema = anchors_schema_t
    detector_options.ssdAnchorsOptions = anchors_options_t

    # Set tensors_decoding_options.
    decoding = self.tensors_decoding_options
    decoding_t = _detector_metadata_fb.TensorsDecodingOptionsT()
    decoding_t.numClasses = decoding.num_classes
    decoding_t.numBoxes = decoding.num_boxes
    decoding_t.numCoords = decoding.num_coords
    decoding_t.keypointCoordOffset = decoding.keypoint_coord_offset
    decoding_t.numKeypoints = decoding.num_keypoints
    decoding_t.numValuesPerKeypoint = decoding.num_values_per_keypoint
    decoding_t.xScale = decoding.x_scale
    decoding_t.yScale = decoding.y_scale
    decoding_t.wScale = decoding.w_scale
    decoding_t.hScale = decoding.h_scale
    decoding_t.applyExponentialOnBoxSize = (
        decoding.apply_exponential_on_box_size
    )
    decoding_t.sigmoidScore = decoding.sigmoid_score
    detector_options.tensorsDecodingOptions = decoding_t

    # Serialize the object detector options into a flatbuffer.
    builder = flatbuffers.Builder(0)
    builder.Finish(
        detector_options.Pack(builder), self._METADATA_FILE_IDENTIFIER
    )

    # Wrap the serialized options in a custom metadata entry.
    custom_metadata = _metadata_fb.CustomMetadataT()
    custom_metadata.name = self.name
    custom_metadata.data = builder.Output()
    return custom_metadata
|
||||
|
||||
|
||||
class MetadataWriter(metadata_writer.MetadataWriterBase):
|
||||
"""MetadataWriter to write the metadata into the object detector."""
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
def create_for_models_with_nms(
|
||||
cls,
|
||||
model_buffer: bytearray,
|
||||
input_norm_mean: List[float],
|
||||
|
@ -38,7 +219,9 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
|
|||
labels: metadata_writer.Labels,
|
||||
score_calibration: Optional[metadata_writer.ScoreCalibration] = None,
|
||||
) -> "MetadataWriter":
|
||||
"""Creates MetadataWriter to write the metadata for image classifier.
|
||||
"""Creates MetadataWriter to write the metadata for object detector with postprocessing in the model.
|
||||
|
||||
This method creates a metadata writer for the models with postprocessing [1].
|
||||
|
||||
The parameters required in this method are mandatory when using MediaPipe
|
||||
Tasks.
|
||||
|
@ -54,18 +237,20 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
|
|||
Args:
|
||||
model_buffer: A valid flatbuffer loaded from the TFLite model file.
|
||||
input_norm_mean: the mean value used in the input tensor normalization
|
||||
[1].
|
||||
input_norm_std: the std value used in the input tensor normalizarion [1].
|
||||
[2].
|
||||
input_norm_std: the std value used in the input tensor normalization [2].
|
||||
labels: an instance of Labels helper class used in the output
|
||||
classification tensor [2].
|
||||
score_calibration: A container of the score calibration operation [3] in
|
||||
classification tensor [3].
|
||||
score_calibration: A container of the score calibration operation [4] in
|
||||
the classification tensor. Optional if the model does not use score
|
||||
calibration.
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
|
||||
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
|
||||
[2]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
|
||||
[3]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
|
||||
[4]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
|
||||
|
||||
Returns:
|
||||
|
@ -76,3 +261,70 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
|
|||
writer.add_image_input(input_norm_mean, input_norm_std)
|
||||
writer.add_detection_output(labels, score_calibration)
|
||||
return cls(writer)
|
||||
|
||||
@classmethod
def create_for_models_without_nms(
    cls,
    model_buffer: bytearray,
    input_norm_mean: List[float],
    input_norm_std: List[float],
    labels: metadata_writer.Labels,
    ssd_anchors_options: SsdAnchorsOptions,
    tensors_decoding_options: TensorsDecodingOptions,
    output_tensors_order: metadata_info.RawDetectionOutputTensorsOrder = metadata_info.RawDetectionOutputTensorsOrder.UNSPECIFIED,
) -> "MetadataWriter":
  """Creates MetadataWriter to write the metadata for object detector without postprocessing in the model.

  This method creates a metadata writer for the models without postprocessing
  [1].

  The parameters required in this method are mandatory when using MediaPipe
  Tasks.

  Example usage:
    metadata_writer = (
        object_detector.MetadataWriter.create_for_models_without_nms(
            model_buffer, ...))
    tflite_content, json_content = metadata_writer.populate()

  When calling `populate` function in this class, it returns TfLite content
  and JSON content. Note that only the output TFLite is used for deployment.
  The output JSON content is used to interpret the metadata content.

  Args:
    model_buffer: A valid flatbuffer loaded from the TFLite model file.
    input_norm_mean: the mean value used in the input tensor normalization
      [2].
    input_norm_std: the std value used in the input tensor normalization [2].
    labels: an instance of Labels helper class used in the output
      classification tensor [3].
    ssd_anchors_options: the ssd anchors options associated to the object
      detector model.
    tensors_decoding_options: the tensors decoding options used to decode the
      object detector model output.
    output_tensors_order: the order of the output tensors.
    [1]:
      https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
    [2]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
    [3]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99

  Returns:
    A MetadataWriter object.
  """
  md_writer = metadata_writer.MetadataWriter(model_buffer)
  md_writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
  md_writer.add_image_input(input_norm_mean, input_norm_std)
  md_writer.add_raw_detection_output(
      labels, output_tensors_order=output_tensors_order
  )
  # The anchors and decoding options travel as custom metadata.
  md_writer.add_custom_metadata(
      ObjectDetectorOptionsMd(ssd_anchors_options, tensors_decoding_options)
  )
  return cls(md_writer)
|
||||
|
||||
def populate(self) -> "tuple[bytearray, str]":
  """Populates the metadata and returns the model together with its JSON.

  The JSON returned by the base class is discarded; it is regenerated from the
  populated model through this module's `convert_to_json`, which passes the
  object detector schema for the custom metadata.

  Returns:
    A tuple of (populated model buffer, metadata in JSON format).
  """
  populated_model, _ = super().populate()
  populated_metadata = metadata.get_metadata_buffer(populated_model)
  return populated_model, convert_to_json(populated_metadata)
|
||||
|
|
|
@ -86,6 +86,7 @@ py_test(
|
|||
deps = [
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ==============================================================================
|
||||
"""Tests for metadata_writer.object_detector."""
|
||||
|
||||
import csv
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
|
@ -21,6 +22,7 @@ from absl.testing import parameterized
|
|||
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import object_detector
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
@ -48,6 +50,39 @@ _JSON_FOR_SCORE_CALIBRATION = test_utils.get_test_data_path(
|
|||
os.path.join(_TEST_DATA_DIR, "coco_ssd_mobilenet_v1_score_calibration.json")
|
||||
)
|
||||
|
||||
_EFFICIENTDET_LITE0_ANCHORS_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "efficientdet_lite0_fp16_no_nms_anchors.csv")
|
||||
)
|
||||
|
||||
|
||||
def read_ssd_anchors_from_csv(file_path):
  """Reads the fixed ssd anchors from a csv file.

  Every non-empty row must contain exactly 4 values, which are mapped to the
  anchor fields as: column 0 -> y_center, column 1 -> x_center,
  column 2 -> height, column 3 -> width. Empty rows are skipped.

  Args:
    file_path: path to the anchors csv file.

  Returns:
    An `object_detector.SsdAnchorsOptions` holding one `FixedAnchor` per
    non-empty row, in file order.

  Raises:
    ValueError: if a non-empty row does not contain exactly 4 values, or if a
      value cannot be converted to float.
  """
  anchors = []
  # newline="" is the mode recommended by the csv module for file objects.
  with open(file_path, "r", newline="") as anchors_file:
    for row in csv.reader(anchors_file, delimiter=","):
      if not row:
        # Empty lines are allowed and carry no anchor. They used to be stored
        # as None and then indexed, which raised a TypeError.
        continue
      if len(row) != 4:
        raise ValueError(
            "Expected empty lines or 4 parameters per line in "
            f"anchors file, but got {len(row)}."
        )
      y_center, x_center, height, width = (float(value) for value in row)
      anchors.append(
          object_detector.FixedAnchor(
              x_center=x_center,
              y_center=y_center,
              width=width,
              height=height,
          )
      )
  return object_detector.SsdAnchorsOptions(
      fixed_anchors_schema=object_detector.FixedAnchorsSchema(anchors)
  )
|
||||
|
||||
|
||||
class MetadataWriterTest(parameterized.TestCase, absltest.TestCase):
|
||||
|
||||
|
@ -61,37 +96,41 @@ class MetadataWriterTest(parameterized.TestCase, absltest.TestCase):
|
|||
)
|
||||
with open(model_path, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = object_detector.MetadataWriter.create(
|
||||
model_buffer,
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||
writer = (
|
||||
object_detector.MetadataWriter.create_for_models_with_nms(
|
||||
model_buffer,
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||
)
|
||||
)
|
||||
_, metadata_json = writer.populate()
|
||||
expected_json_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, model_name + ".json")
|
||||
)
|
||||
with open(expected_json_path, "r") as f:
|
||||
expected_json = f.read()
|
||||
expected_json = f.read().strip()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
def test_create_with_score_calibration_should_succeed(self):
|
||||
with open(_MODEL_COCO, "rb") as f:
|
||||
model_buffer = f.read()
|
||||
writer = object_detector.MetadataWriter.create(
|
||||
model_buffer,
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||
score_calibration=metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_fb.ScoreTransformationType.INVERSE_LOGISTIC,
|
||||
_SCORE_CALIBRATION_FILE,
|
||||
_SCORE_CALIBRATION_DEFAULT_SCORE,
|
||||
),
|
||||
writer = (
|
||||
object_detector.MetadataWriter.create_for_models_with_nms(
|
||||
model_buffer,
|
||||
[_NORM_MEAN],
|
||||
[_NORM_STD],
|
||||
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
|
||||
score_calibration=metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_fb.ScoreTransformationType.INVERSE_LOGISTIC,
|
||||
_SCORE_CALIBRATION_FILE,
|
||||
_SCORE_CALIBRATION_DEFAULT_SCORE,
|
||||
),
|
||||
)
|
||||
)
|
||||
tflite_content, metadata_json = writer.populate()
|
||||
with open(_JSON_FOR_SCORE_CALIBRATION, "r") as f:
|
||||
expected_json = f.read()
|
||||
expected_json = f.read().strip()
|
||||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
|
||||
|
|
6
mediapipe/tasks/testdata/metadata/BUILD
vendored
6
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -32,7 +32,7 @@ mediapipe_files(srcs = [
|
|||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"efficientdet_lite0_fp16_no_nms.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"labelmap.txt",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
|
@ -64,6 +64,8 @@ exports_files([
|
|||
"classification_tensor_float_meta.json",
|
||||
"classification_tensor_uint8_meta.json",
|
||||
"classification_tensor_unsupported_meta.json",
|
||||
"efficientdet_lite0_fp16_no_nms_anchors.csv",
|
||||
"efficientdet_lite0_fp16_no_nms.json",
|
||||
"feature_tensor_meta.json",
|
||||
"image_tensor_meta.json",
|
||||
"input_image_tensor_float_meta.json",
|
||||
|
@ -94,6 +96,7 @@ filegroup(
|
|||
"bert_text_classifier_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"efficientdet_lite0_fp16_no_nms.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||
|
@ -126,6 +129,7 @@ filegroup(
|
|||
"deeplabv3.json",
|
||||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"efficientdet_lite0_fp16_no_nms_anchors.csv",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"external_file",
|
||||
"feature_tensor_meta.json",
|
||||
|
|
|
@ -56,23 +56,20 @@
|
|||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
},
|
||||
{
|
||||
"name": "category",
|
||||
"description": "The categories of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
},
|
||||
"content_properties": {},
|
||||
"range": {
|
||||
"min": 2,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
},
|
||||
"stats": {},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
|
@ -86,8 +83,7 @@
|
|||
"description": "The scores of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
},
|
||||
"content_properties": {},
|
||||
"range": {
|
||||
"min": 2,
|
||||
"max": 2
|
||||
|
@ -102,8 +98,7 @@
|
|||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
},
|
||||
"stats": {},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "score_calibration.txt",
|
||||
|
@ -117,11 +112,9 @@
|
|||
"description": "The number of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
"content_properties": {}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
}
|
||||
],
|
||||
"output_tensor_groups": [
|
||||
|
|
19206
mediapipe/tasks/testdata/metadata/efficientdet_lite0_fp16_no_nms_anchors.csv
vendored
Normal file
19206
mediapipe/tasks/testdata/metadata/efficientdet_lite0_fp16_no_nms_anchors.csv
vendored
Normal file
File diff suppressed because it is too large
Load Diff
|
@ -42,15 +42,13 @@
|
|||
"description": "The scores of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
},
|
||||
"content_properties": {},
|
||||
"range": {
|
||||
"min": 2,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
},
|
||||
{
|
||||
"name": "location",
|
||||
|
@ -71,34 +69,29 @@
|
|||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
},
|
||||
{
|
||||
"name": "number of detections",
|
||||
"description": "The number of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
"content_properties": {}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
},
|
||||
{
|
||||
"name": "category",
|
||||
"description": "The categories of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
},
|
||||
"content_properties": {},
|
||||
"range": {
|
||||
"min": 2,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
},
|
||||
"stats": {},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
|
|
|
@ -56,23 +56,20 @@
|
|||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
},
|
||||
{
|
||||
"name": "category",
|
||||
"description": "The categories of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
},
|
||||
"content_properties": {},
|
||||
"range": {
|
||||
"min": 2,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
},
|
||||
"stats": {},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
|
@ -86,26 +83,22 @@
|
|||
"description": "The scores of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
},
|
||||
"content_properties": {},
|
||||
"range": {
|
||||
"min": 2,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
},
|
||||
{
|
||||
"name": "number of detections",
|
||||
"description": "The number of the detected boxes.",
|
||||
"content": {
|
||||
"content_properties_type": "FeatureProperties",
|
||||
"content_properties": {
|
||||
}
|
||||
"content_properties": {}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
"stats": {}
|
||||
}
|
||||
],
|
||||
"output_tensor_groups": [
|
||||
|
|
30
third_party/external_files.bzl
vendored
30
third_party/external_files.bzl
vendored
|
@ -198,8 +198,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_coco_ssd_mobilenet_v1_score_calibration_json",
|
||||
sha256 = "f377600be924c29697477f9d739db9db5d712aec4a644548526912858db6a082",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1677522739770755"],
|
||||
sha256 = "a850674f9043bfc775527fee7f1b639f7fe0fb56e8d3ed2b710247967c888b09",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1682456086898538"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -262,10 +262,28 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/dynamic_input_classifier.tflite?generation=1680543275416843"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_anchors_csv",
|
||||
sha256 = "284475a0f16e34afcc6c0fe68b05bd871aca5b20c83db0870c6a36dd63827176",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms_anchors.csv?generation=1682456090001817"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_json",
|
||||
sha256 = "dc3b333e41c43fb49ace048c25c18d0e34df78fb5ee77edbe169264368f78b92",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.json?generation=1682456092938505"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_tflite",
|
||||
sha256 = "bcda125c96d3767bca894c8cbe7bc458379c9974c9fd8bdc6204e7124a74082a",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682456096034465"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_efficientdet_lite0_v1_json",
|
||||
sha256 = "7a9e1fb625a6130a251e612637fc546cfc8cfabfadc7dbdade44c87f1d8996ca",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_v1.json?generation=1677522746026682"],
|
||||
sha256 = "ef9706696a3ea5d87f4324ac56e877a92033d33e522c4b7d5a416fbcab24d8fc",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_v1.json?generation=1682456098581704"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
@ -1158,8 +1176,8 @@ def external_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_ssd_mobilenet_v1_no_metadata_json",
|
||||
sha256 = "89157590b736cf3f3247aa9c8be3570c2856f4981a1e9476117e7c629e7c4825",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1_no_metadata.json?generation=1677522786336455"],
|
||||
sha256 = "ae5a5971a1c3df705307448ef97c854d846b7e6f2183fb51015bd5af5d7deb0f",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1_no_metadata.json?generation=1682456117002011"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
|
|
Loading…
Reference in New Issue
Block a user