Add custom metadata for object detection model with out-of-graph nms.

PiperOrigin-RevId: 527083453
This commit is contained in:
MediaPipe Team 2023-04-25 14:56:38 -07:00 committed by Copybara-Service
parent 17f5b95387
commit 507ed0d91d
15 changed files with 19798 additions and 77 deletions

View File

@ -328,7 +328,7 @@ class ObjectDetector(classifier.Classifier):
converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,) converter.target_spec.supported_ops = (tf.lite.OpsSet.TFLITE_BUILTINS,)
tflite_model = converter.convert() tflite_model = converter.convert()
writer = object_detector_writer.MetadataWriter.create( writer = object_detector_writer.MetadataWriter.create_for_models_with_nms(
tflite_model, tflite_model,
self._model_spec.mean_rgb, self._model_spec.mean_rgb,
self._model_spec.stddev_rgb, self._model_spec.stddev_rgb,

View File

@ -36,3 +36,13 @@ flatbuffer_py_library(
name = "image_segmenter_metadata_schema_py", name = "image_segmenter_metadata_schema_py",
srcs = ["image_segmenter_metadata_schema.fbs"], srcs = ["image_segmenter_metadata_schema.fbs"],
) )
# FlatBuffers C++ library target for the object detector metadata schema.
flatbuffer_cc_library(
    name = "object_detector_metadata_schema_cc",
    srcs = ["object_detector_metadata_schema.fbs"],
)

# FlatBuffers Python library target for the same schema file.
flatbuffer_py_library(
    name = "object_detector_metadata_schema_py",
    srcs = ["object_detector_metadata_schema.fbs"],
)

View File

@ -0,0 +1,98 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace mediapipe.tasks;
// ObjectDetectorOptions.min_parser_version indicates the minimum necessary
// object detector metadata parser version to fully understand all fields in a
// given metadata flatbuffer. This min_parser_version is specific for the
// object detector metadata defined in this schema file.
//
// New fields and types will have associated comments with the schema version
// for which they were added.
//
// Schema Semantic version: 1.0.0
// This indicates the flatbuffer compatibility. The number will be bumped when a
// breaking change is applied to the schema, such as removing fields or adding
// new fields to the middle of a table.
file_identifier "V001";
// History:
// 1.0.0 - Initial version.
// A fixed size anchor, described by its center point and its size.
table FixedAnchor {
  // Center of the anchor along the x and y axes.
  x_center: float;
  y_center: float;
  // Size of the anchor.
  width: float;
  height: float;
}
// The schema for a list of anchors with fixed size.
table FixedAnchorsSchema {
  // The fixed size anchors, stored as a vector of FixedAnchor tables.
  anchors: [FixedAnchor];
}
// The ssd anchors options used in the object detector.
table SsdAnchorsOptions {
  // The explicit list of fixed size anchors. This is the only anchors
  // representation defined in this schema version.
  fixed_anchors_schema: FixedAnchorsSchema;
}
// The options for decoding the raw model output tensors. The options are mostly
// used in TensorsToDetectionsCalculatorOptions.
table TensorsDecodingOptions {
  // The number of output classes predicted by the detection model.
  num_classes: int;
  // The number of output boxes predicted by the detection model.
  num_boxes: int;
  // The number of output values per boxes predicted by the detection
  // model. The values contain bounding boxes, keypoints, etc.
  num_coords: int;
  // The offset of keypoint coordinates in the location tensor.
  keypoint_coord_offset: int;
  // The number of predicted keypoints.
  num_keypoints: int;
  // The dimension of each keypoint, e.g. number of values predicted for each
  // keypoint.
  num_values_per_keypoint: int;
  // Parameters for decoding SSD detection model.
  x_scale: float;
  y_scale: float;
  w_scale: float;
  h_scale: float;
  // Whether to apply exponential on box size.
  apply_exponential_on_box_size: bool;
  // Whether to apply sigmoid function on the score.
  sigmoid_score: bool;
}
// The root table of the object detector metadata: a parser version marker plus
// the anchors and tensor decoding options of the model.
table ObjectDetectorOptions {
  // TODO: automatically populate min parser string.
  // The minimum necessary object detector metadata parser version to fully
  // understand all fields in a given metadata flatbuffer. This field is
  // automatically populated by the MetadataPopulator when the metadata is
  // populated into a TFLite model. This min_parser_version is specific for the
  // object detector metadata defined in this schema file.
  // NOTE(review): the TODO above suggests the automatic population described
  // here is not implemented yet - confirm before relying on this field.
  min_parser_version:string;
  // The options of ssd anchors configs used by the detection model.
  ssd_anchors_options:SsdAnchorsOptions;
  // The tensors decoding options to convert raw tensors to detection results.
  tensors_decoding_options:TensorsDecodingOptions;
}
root_type ObjectDetectorOptions;

View File

@ -67,7 +67,15 @@ py_library(
py_library( py_library(
name = "object_detector", name = "object_detector",
srcs = ["object_detector.py"], srcs = ["object_detector.py"],
deps = [":metadata_writer"], data = ["//mediapipe/tasks/metadata:object_detector_metadata_schema.fbs"],
deps = [
":metadata_info",
":metadata_writer",
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:object_detector_metadata_schema_py",
"//mediapipe/tasks/python/metadata",
"@flatbuffers//:runtime_py",
],
) )
py_library( py_library(

View File

@ -17,6 +17,7 @@
import abc import abc
import collections import collections
import csv import csv
import enum
import os import os
from typing import List, Optional, Type, Union from typing import List, Optional, Type, Union
@ -1004,6 +1005,84 @@ class DetectionOutputTensorsMd:
return self._output_mds return self._output_mds
class RawDetectionOutputTensorsOrder(enum.Enum):
  """Output tensors order for detection models without postprocessing.

  The order of the output tensors cannot be determined automatically for
  models without postprocessing, so it has to be specified explicitly for the
  metadata writer.
  """

  # No order was given. RawDetectionOutputTensorsMd rejects this value with a
  # ValueError, so callers have to pick one of the two orders below.
  UNSPECIFIED = 0
  # The first tensor is score, and the second tensor is location.
  SCORE_LOCATION = 1
  # The first tensor is location, and the second tensor is score.
  LOCATION_SCORE = 2
class RawDetectionOutputTensorsMd:
  """Holds the output tensor metadata of detection models without postprocessing."""

  _LOCATION_NAME = "location"
  _LOCATION_DESCRIPTION = "The locations of the detected boxes."
  _SCORE_NAME = "score"
  _SCORE_DESCRIPTION = "The scores of the detected boxes."
  _CONTENT_VALUE_DIM = 2

  def __init__(
      self,
      model_buffer: bytearray,
      label_files: Optional[List[LabelFileMd]] = None,
      output_tensors_order: RawDetectionOutputTensorsOrder = RawDetectionOutputTensorsOrder.UNSPECIFIED,
  ) -> None:
    """Initializes the instance of RawDetectionOutputTensorsMd.

    Builds one location and one score tensor metadata, arranges them according
    to `output_tensors_order` and binds each of them to the name of the model
    output tensor found at the same position.

    Args:
      model_buffer: A valid flatbuffer loaded from the TFLite model file.
      label_files: information of the label files [1] in the classification
        tensor.
      output_tensors_order: the order of the output tensors. Must be
        SCORE_LOCATION or LOCATION_SCORE.

    Raises:
      ValueError: if `output_tensors_order` is neither SCORE_LOCATION nor
        LOCATION_SCORE (this includes the default UNSPECIFIED), or if the
        number of model output tensors differs from the number of metadata
        entries.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L9
    """
    # Pair every output tensor index with its name, as read from the model.
    output_indices = writer_utils.get_output_tensor_indices(model_buffer)
    output_names = writer_utils.get_output_tensor_names(model_buffer)
    model_outputs = list(zip(output_indices, output_names))

    locations = LocationTensorMd(
        name=self._LOCATION_NAME,
        description=self._LOCATION_DESCRIPTION,
    )
    scores = ClassificationTensorMd(
        name=self._SCORE_NAME,
        description=self._SCORE_DESCRIPTION,
        label_files=label_files,
    )

    if output_tensors_order == RawDetectionOutputTensorsOrder.LOCATION_SCORE:
      ordered_mds = [locations, scores]
    elif output_tensors_order == RawDetectionOutputTensorsOrder.SCORE_LOCATION:
      ordered_mds = [scores, locations]
    else:
      raise ValueError(
          f"Unsupported OutputTensorsOrder value: {output_tensors_order}"
      )
    self._output_mds = ordered_mds

    expected_count = len(self._output_mds)
    if expected_count != len(model_outputs):
      raise ValueError(
          "The size of TFLite output should be " + str(expected_count)
      )

    # Both sequences have the same length here, so they can be walked in step.
    for tensor_md, (_, tensor_name) in zip(self._output_mds, model_outputs):
      tensor_md.tensor_name = tensor_name

  @property
  def output_mds(self) -> List[TensorMd]:
    """The tensor metadata entries, in the requested output order."""
    return self._output_mds
class TensorGroupMd: class TensorGroupMd:
"""A container for a group of tensor metadata information.""" """A container for a group of tensor metadata information."""

View File

@ -632,7 +632,7 @@ class MetadataWriter(object):
score_calibration: Optional[ScoreCalibration] = None, score_calibration: Optional[ScoreCalibration] = None,
group_name: str = _DETECTION_GROUP_NAME, group_name: str = _DETECTION_GROUP_NAME,
) -> 'MetadataWriter': ) -> 'MetadataWriter':
"""Adds a detection head metadata for detection output tensor. """Adds a detection head metadata for detection output tensor of models with postprocessing.
Args: Args:
labels: an instance of Labels helper class. labels: an instance of Labels helper class.
@ -661,6 +661,33 @@ class MetadataWriter(object):
self._output_group_mds.append(group_md) self._output_group_mds.append(group_md)
return self return self
def add_raw_detection_output(
    self,
    labels: Optional[Labels] = None,
    output_tensors_order: metadata_info.RawDetectionOutputTensorsOrder = metadata_info.RawDetectionOutputTensorsOrder.UNSPECIFIED,
) -> 'MetadataWriter':
  """Adds a detection head metadata for detection output tensor of models without postprocessing.

  Args:
    labels: an instance of Labels helper class.
    output_tensors_order: the order of the output tensors. For models of
      out-of-graph non-maximum-suppression only.

  Returns:
    The current Writer instance to allow chained operation.

  Raises:
    ValueError: if the created detection output metadata does not consist of
      exactly two entries. Errors raised while building
      `metadata_info.RawDetectionOutputTensorsMd` propagate unchanged.
  """
  label_files = self._create_label_file_md(labels)
  detection_output_mds = metadata_info.RawDetectionOutputTensorsMd(
      self._model_buffer,
      label_files=label_files,
      output_tensors_order=output_tensors_order,
  ).output_mds
  # Outputs are location and score. Validate before touching the writer state
  # so that a failure does not leave partially registered output metadata.
  if len(detection_output_mds) != 2:
    raise ValueError('The size of detections output should be 2.')
  self._output_mds.extend(detection_output_mds)
  return self
def add_segmentation_output( def add_segmentation_output(
self, self,
labels: Optional[Labels] = None, labels: Optional[Labels] = None,

View File

@ -14,8 +14,14 @@
# ============================================================================== # ==============================================================================
"""Writes metadata and label file to the Object Detector models.""" """Writes metadata and label file to the Object Detector models."""
import dataclasses
from typing import List, Optional from typing import List, Optional
import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import object_detector_metadata_schema_py_generated as _detector_metadata_fb
from mediapipe.tasks.python.metadata import metadata
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
_MODEL_NAME = "ObjectDetector" _MODEL_NAME = "ObjectDetector"
@ -25,12 +31,187 @@ _MODEL_DESCRIPTION = (
"stream." "stream."
) )
# Metadata Schema file for object detector.
_FLATC_METADATA_SCHEMA_FILE = metadata.get_path_to_datafile(
"../../../metadata/object_detector_metadata_schema.fbs",
)
# Metadata name in custom metadata field. The metadata name is used to get
# object detector metadata from SubGraphMetadata.custom_metadata and shouldn't
# be changed.
_METADATA_NAME = "DETECTOR_METADATA"
@dataclasses.dataclass
class FixedAnchor:
  """A fixed size anchor, described by its center point and its size."""

  # Center of the anchor along the x and y axes.
  x_center: float
  y_center: float
  # Size of the anchor. Annotated as Optional, but no default is declared, so
  # both values still have to be passed to the constructor.
  # NOTE(review): unclear how a None size is meant to be serialized - confirm.
  width: Optional[float]
  height: Optional[float]
@dataclasses.dataclass
class FixedAnchorsSchema:
  """The schema for a list of anchors with fixed size."""

  # The fixed size anchors.
  anchors: List[FixedAnchor]
@dataclasses.dataclass
class SsdAnchorsOptions:
  """The ssd anchors options used in object detector model."""

  # The explicit list of fixed size anchors. ObjectDetectorOptionsMd raises a
  # ValueError when this is None.
  fixed_anchors_schema: Optional[FixedAnchorsSchema]
@dataclasses.dataclass
class TensorsDecodingOptions:
  """The decoding options to convert model output tensors to detections."""

  # The number of output classes predicted by the detection model.
  num_classes: int
  # The number of output boxes predicted by the detection model.
  num_boxes: int
  # The number of output values per boxes predicted by the detection
  # model. The values contain bounding boxes, keypoints, etc.
  num_coords: int
  # The offset of keypoint coordinates in the location tensor.
  keypoint_coord_offset: int
  # The number of predicted keypoints.
  num_keypoints: int
  # The dimension of each keypoint, e.g. number of values predicted for each
  # keypoint.
  num_values_per_keypoint: int
  # Parameters for decoding SSD detection model.
  x_scale: float
  y_scale: float
  w_scale: float
  h_scale: float
  # Whether to apply exponential on box size.
  apply_exponential_on_box_size: bool
  # Whether to apply sigmoid function on the score.
  sigmoid_score: bool
# Create an individual method for getting the metadata json file, so that it can
# be used as a standalone util.
def convert_to_json(metadata_buffer: bytearray) -> str:
  """Renders a metadata buffer as a JSON string.

  The object detector schema file is registered for the custom metadata entry
  named by `_METADATA_NAME` when delegating to `metadata.convert_to_json`.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: error occurred when parsing the metadata schema file.
  """
  detector_schema_by_name = {_METADATA_NAME: _FLATC_METADATA_SCHEMA_FILE}
  return metadata.convert_to_json(
      metadata_buffer, custom_metadata_schema=detector_schema_by_name
  )
class ObjectDetectorOptionsMd(metadata_info.CustomMetadataMd):
  """Custom metadata holding the anchors and decoding options of an object detector."""

  _METADATA_FILE_IDENTIFIER = b"V001"

  def __init__(
      self,
      ssd_anchors_options: SsdAnchorsOptions,
      tensors_decoding_options: TensorsDecodingOptions,
  ) -> None:
    """Creates an ObjectDetectorOptionsMd object.

    Args:
      ssd_anchors_options: the ssd anchors options associated to the object
        detector model.
      tensors_decoding_options: the tensors decoding options used to decode the
        object detector model output.

    Raises:
      ValueError: if `ssd_anchors_options.fixed_anchors_schema` is None.
    """
    if ssd_anchors_options.fixed_anchors_schema is None:
      raise ValueError(
          "Currently only support FixedAnchorsSchema, which cannot be found"
          " in ssd_anchors_options."
      )
    self.ssd_anchors_options = ssd_anchors_options
    self.tensors_decoding_options = tensors_decoding_options
    super().__init__(name=_METADATA_NAME)

  def create_metadata(self) -> _metadata_fb.CustomMetadataT:
    """Creates the object detector options metadata.

    Returns:
      A Flatbuffers Python object of the custom metadata including object
      detector options metadata.
    """
    fb = _detector_metadata_fb

    # Copy every anchor into its flatbuffer object counterpart.
    anchor_objects = []
    for source_anchor in self.ssd_anchors_options.fixed_anchors_schema.anchors:
      anchor_object = fb.FixedAnchorT()
      anchor_object.xCenter = source_anchor.x_center
      anchor_object.yCenter = source_anchor.y_center
      anchor_object.width = source_anchor.width
      anchor_object.height = source_anchor.height
      anchor_objects.append(anchor_object)
    anchors_schema_object = fb.FixedAnchorsSchemaT()
    anchors_schema_object.anchors = anchor_objects
    anchors_options_object = fb.SsdAnchorsOptionsT()
    anchors_options_object.fixedAnchorsSchema = anchors_schema_object

    # Copy the decoding options: (flatbuffer attribute, dataclass attribute).
    decoding_source = self.tensors_decoding_options
    decoding_object = fb.TensorsDecodingOptionsT()
    attribute_pairs = (
        ("numClasses", "num_classes"),
        ("numBoxes", "num_boxes"),
        ("numCoords", "num_coords"),
        ("keypointCoordOffset", "keypoint_coord_offset"),
        ("numKeypoints", "num_keypoints"),
        ("numValuesPerKeypoint", "num_values_per_keypoint"),
        ("xScale", "x_scale"),
        ("yScale", "y_scale"),
        ("wScale", "w_scale"),
        ("hScale", "h_scale"),
        ("applyExponentialOnBoxSize", "apply_exponential_on_box_size"),
        ("sigmoidScore", "sigmoid_score"),
    )
    for target_attribute, source_attribute in attribute_pairs:
      setattr(
          decoding_object,
          target_attribute,
          getattr(decoding_source, source_attribute),
      )

    options_object = fb.ObjectDetectorOptionsT()
    options_object.ssdAnchorsOptions = anchors_options_object
    options_object.tensorsDecodingOptions = decoding_object

    # Serialize the options, tagging the buffer with the file identifier.
    builder = flatbuffers.Builder(0)
    builder.Finish(
        options_object.Pack(builder), self._METADATA_FILE_IDENTIFIER
    )

    # Wrap the serialized options into a named custom metadata entry.
    custom_entry = _metadata_fb.CustomMetadataT()
    custom_entry.name = self.name
    custom_entry.data = builder.Output()
    return custom_entry
class MetadataWriter(metadata_writer.MetadataWriterBase): class MetadataWriter(metadata_writer.MetadataWriterBase):
"""MetadataWriter to write the metadata into the object detector.""" """MetadataWriter to write the metadata into the object detector."""
@classmethod @classmethod
def create( def create_for_models_with_nms(
cls, cls,
model_buffer: bytearray, model_buffer: bytearray,
input_norm_mean: List[float], input_norm_mean: List[float],
@ -38,7 +219,9 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
labels: metadata_writer.Labels, labels: metadata_writer.Labels,
score_calibration: Optional[metadata_writer.ScoreCalibration] = None, score_calibration: Optional[metadata_writer.ScoreCalibration] = None,
) -> "MetadataWriter": ) -> "MetadataWriter":
"""Creates MetadataWriter to write the metadata for image classifier. """Creates MetadataWriter to write the metadata for object detector with postprocessing in the model.
This method create a metadata writer for the models with postprocessing [1].
The parameters required in this method are mandatory when using MediaPipe The parameters required in this method are mandatory when using MediaPipe
Tasks. Tasks.
@ -54,18 +237,20 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
Args: Args:
model_buffer: A valid flatbuffer loaded from the TFLite model file. model_buffer: A valid flatbuffer loaded from the TFLite model file.
input_norm_mean: the mean value used in the input tensor normalization input_norm_mean: the mean value used in the input tensor normalization
[1]. [2].
input_norm_std: the std value used in the input tensor normalizarion [1]. input_norm_std: the std value used in the input tensor normalizarion [2].
labels: an instance of Labels helper class used in the output labels: an instance of Labels helper class used in the output
classification tensor [2]. classification tensor [3].
score_calibration: A container of the score calibration operation [3] in score_calibration: A container of the score calibration operation [4] in
the classification tensor. Optional if the model does not use score the classification tensor. Optional if the model does not use score
calibration. calibration.
[1]: [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
[2]: [2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
[3]: [3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456
Returns: Returns:
@ -76,3 +261,70 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
writer.add_image_input(input_norm_mean, input_norm_std) writer.add_image_input(input_norm_mean, input_norm_std)
writer.add_detection_output(labels, score_calibration) writer.add_detection_output(labels, score_calibration)
return cls(writer) return cls(writer)
@classmethod
def create_for_models_without_nms(
    cls,
    model_buffer: bytearray,
    input_norm_mean: List[float],
    input_norm_std: List[float],
    labels: metadata_writer.Labels,
    ssd_anchors_options: SsdAnchorsOptions,
    tensors_decoding_options: TensorsDecodingOptions,
    output_tensors_order: metadata_info.RawDetectionOutputTensorsOrder = metadata_info.RawDetectionOutputTensorsOrder.UNSPECIFIED,
) -> "MetadataWriter":
  """Creates MetadataWriter to write the metadata for object detector without postprocessing in the model.

  This method creates a metadata writer for the models without postprocessing
  [1].

  The parameters required in this method are mandatory when using MediaPipe
  Tasks.

  Example usage:
    metadata_writer = (
        object_detector.MetadataWriter.create_for_models_without_nms(
            model_buffer, ...
        )
    )
    tflite_content, json_content = metadata_writer.populate()

  When calling `populate` function in this class, it returns TfLite content
  and JSON content. Note that only the output TFLite is used for deployment.
  The output JSON content is used to interpret the metadata content.

  Args:
    model_buffer: A valid flatbuffer loaded from the TFLite model file.
    input_norm_mean: the mean value used in the input tensor normalization
      [2].
    input_norm_std: the std value used in the input tensor normalization [2].
    labels: an instance of Labels helper class used in the output
      classification tensor [3].
    ssd_anchors_options: the ssd anchors options associated to the object
      detector model.
    tensors_decoding_options: the tensors decoding options used to decode the
      object detector model output.
    output_tensors_order: the order of the output tensors.

  [1]:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
  [2]:
    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
  [3]:
    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99

  Returns:
    A MetadataWriter object.
  """
  base_writer = metadata_writer.MetadataWriter(model_buffer)
  base_writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
  base_writer.add_image_input(input_norm_mean, input_norm_std)
  base_writer.add_raw_detection_output(
      labels, output_tensors_order=output_tensors_order
  )
  # The anchors and decoding options travel as custom metadata.
  base_writer.add_custom_metadata(
      ObjectDetectorOptionsMd(
          ssd_anchors_options=ssd_anchors_options,
          tensors_decoding_options=tensors_decoding_options,
      )
  )
  return cls(base_writer)
def populate(self) -> "tuple[bytearray, str]":
  """Populates the metadata and returns the model content with its JSON.

  The second value returned by the base class is dropped; the JSON is
  regenerated from the populated model through `convert_to_json`, which
  registers the object detector schema for the custom metadata.

  Returns:
    A tuple of the populated model content and the metadata in JSON format.
  """
  populated_model, _unused = super().populate()
  return populated_model, convert_to_json(
      metadata.get_metadata_buffer(populated_model)
  )

View File

@ -86,6 +86,7 @@ py_test(
deps = [ deps = [
"//mediapipe/tasks/metadata:metadata_schema_py", "//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/python/metadata", "//mediapipe/tasks/python/metadata",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_info",
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
"//mediapipe/tasks/python/metadata/metadata_writers:object_detector", "//mediapipe/tasks/python/metadata/metadata_writers:object_detector",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",

View File

@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Tests for metadata_writer.object_detector.""" """Tests for metadata_writer.object_detector."""
import csv
import os import os
from absl.testing import absltest from absl.testing import absltest
@ -21,6 +22,7 @@ from absl.testing import parameterized
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
from mediapipe.tasks.python.metadata import metadata from mediapipe.tasks.python.metadata import metadata
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import object_detector from mediapipe.tasks.python.metadata.metadata_writers import object_detector
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@ -48,6 +50,39 @@ _JSON_FOR_SCORE_CALIBRATION = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "coco_ssd_mobilenet_v1_score_calibration.json") os.path.join(_TEST_DATA_DIR, "coco_ssd_mobilenet_v1_score_calibration.json")
) )
_EFFICIENTDET_LITE0_ANCHORS_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "efficientdet_lite0_fp16_no_nms_anchors.csv")
)
def read_ssd_anchors_from_csv(file_path):
  """Reads fixed SSD anchors from a CSV file.

  Every non-empty line has to hold exactly four comma-separated numbers, read
  in the column order (y_center, x_center, height, width). Empty lines are
  skipped instead of being carried along as `None` entries, which previously
  failed with a TypeError as soon as the anchors were built.

  Args:
    file_path: path to the anchors CSV file.

  Returns:
    An `object_detector.SsdAnchorsOptions` whose fixed anchors schema holds
    one `FixedAnchor` per non-empty line, in file order.

  Raises:
    ValueError: if a non-empty line does not hold exactly four values, or if
      a value cannot be parsed as a float.
  """
  anchors = []
  # newline="" is what the csv module asks for when it is handed a file object.
  with open(file_path, "r", newline="") as anchors_file:
    for row in csv.reader(anchors_file, delimiter=","):
      if not row:
        continue
      if len(row) != 4:
        raise ValueError(
            "Expected empty lines or 4 parameters per line in "
            f"anchors file, but got {len(row)}."
        )
      y_center, x_center, height, width = (float(value) for value in row)
      anchors.append(
          object_detector.FixedAnchor(
              x_center=x_center,
              y_center=y_center,
              width=width,
              height=height,
          )
      )
  return object_detector.SsdAnchorsOptions(
      fixed_anchors_schema=object_detector.FixedAnchorsSchema(anchors)
  )
class MetadataWriterTest(parameterized.TestCase, absltest.TestCase): class MetadataWriterTest(parameterized.TestCase, absltest.TestCase):
@ -61,37 +96,41 @@ class MetadataWriterTest(parameterized.TestCase, absltest.TestCase):
) )
with open(model_path, "rb") as f: with open(model_path, "rb") as f:
model_buffer = f.read() model_buffer = f.read()
writer = object_detector.MetadataWriter.create( writer = (
model_buffer, object_detector.MetadataWriter.create_for_models_with_nms(
[_NORM_MEAN], model_buffer,
[_NORM_STD], [_NORM_MEAN],
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE), [_NORM_STD],
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
)
) )
_, metadata_json = writer.populate() _, metadata_json = writer.populate()
expected_json_path = test_utils.get_test_data_path( expected_json_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, model_name + ".json") os.path.join(_TEST_DATA_DIR, model_name + ".json")
) )
with open(expected_json_path, "r") as f: with open(expected_json_path, "r") as f:
expected_json = f.read() expected_json = f.read().strip()
self.assertEqual(metadata_json, expected_json) self.assertEqual(metadata_json, expected_json)
def test_create_with_score_calibration_should_succeed(self): def test_create_with_score_calibration_should_succeed(self):
with open(_MODEL_COCO, "rb") as f: with open(_MODEL_COCO, "rb") as f:
model_buffer = f.read() model_buffer = f.read()
writer = object_detector.MetadataWriter.create( writer = (
model_buffer, object_detector.MetadataWriter.create_for_models_with_nms(
[_NORM_MEAN], model_buffer,
[_NORM_STD], [_NORM_MEAN],
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE), [_NORM_STD],
score_calibration=metadata_writer.ScoreCalibration.create_from_file( labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
metadata_fb.ScoreTransformationType.INVERSE_LOGISTIC, score_calibration=metadata_writer.ScoreCalibration.create_from_file(
_SCORE_CALIBRATION_FILE, metadata_fb.ScoreTransformationType.INVERSE_LOGISTIC,
_SCORE_CALIBRATION_DEFAULT_SCORE, _SCORE_CALIBRATION_FILE,
), _SCORE_CALIBRATION_DEFAULT_SCORE,
),
)
) )
tflite_content, metadata_json = writer.populate() tflite_content, metadata_json = writer.populate()
with open(_JSON_FOR_SCORE_CALIBRATION, "r") as f: with open(_JSON_FOR_SCORE_CALIBRATION, "r") as f:
expected_json = f.read() expected_json = f.read().strip()
self.assertEqual(metadata_json, expected_json) self.assertEqual(metadata_json, expected_json)
displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content) displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)

View File

@ -32,7 +32,7 @@ mediapipe_files(srcs = [
"deeplabv3_with_activation.json", "deeplabv3_with_activation.json",
"deeplabv3_without_labels.json", "deeplabv3_without_labels.json",
"deeplabv3_without_metadata.tflite", "deeplabv3_without_metadata.tflite",
"efficientdet_lite0_v1.json", "efficientdet_lite0_fp16_no_nms.tflite",
"efficientdet_lite0_v1.tflite", "efficientdet_lite0_v1.tflite",
"labelmap.txt", "labelmap.txt",
"mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-custom-metadata.tflite",
@ -64,6 +64,8 @@ exports_files([
"classification_tensor_float_meta.json", "classification_tensor_float_meta.json",
"classification_tensor_uint8_meta.json", "classification_tensor_uint8_meta.json",
"classification_tensor_unsupported_meta.json", "classification_tensor_unsupported_meta.json",
"efficientdet_lite0_fp16_no_nms_anchors.csv",
"efficientdet_lite0_fp16_no_nms.json",
"feature_tensor_meta.json", "feature_tensor_meta.json",
"image_tensor_meta.json", "image_tensor_meta.json",
"input_image_tensor_float_meta.json", "input_image_tensor_float_meta.json",
@ -94,6 +96,7 @@ filegroup(
"bert_text_classifier_no_metadata.tflite", "bert_text_classifier_no_metadata.tflite",
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
"deeplabv3_without_metadata.tflite", "deeplabv3_without_metadata.tflite",
"efficientdet_lite0_fp16_no_nms.tflite",
"efficientdet_lite0_v1.tflite", "efficientdet_lite0_v1.tflite",
"mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-custom-metadata.tflite",
"mobile_ica_8bit-with-large-min-parser-version.tflite", "mobile_ica_8bit-with-large-min-parser-version.tflite",
@ -126,6 +129,7 @@ filegroup(
"deeplabv3.json", "deeplabv3.json",
"deeplabv3_with_activation.json", "deeplabv3_with_activation.json",
"deeplabv3_without_labels.json", "deeplabv3_without_labels.json",
"efficientdet_lite0_fp16_no_nms_anchors.csv",
"efficientdet_lite0_v1.json", "efficientdet_lite0_v1.json",
"external_file", "external_file",
"feature_tensor_meta.json", "feature_tensor_meta.json",

View File

@ -56,23 +56,20 @@
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {}
}
}, },
{ {
"name": "category", "name": "category",
"description": "The categories of the detected boxes.", "description": "The categories of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {},
},
"range": { "range": {
"min": 2, "min": 2,
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {},
},
"associated_files": [ "associated_files": [
{ {
"name": "labels.txt", "name": "labels.txt",
@ -86,8 +83,7 @@
"description": "The scores of the detected boxes.", "description": "The scores of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {},
},
"range": { "range": {
"min": 2, "min": 2,
"max": 2 "max": 2
@ -102,8 +98,7 @@
} }
} }
], ],
"stats": { "stats": {},
},
"associated_files": [ "associated_files": [
{ {
"name": "score_calibration.txt", "name": "score_calibration.txt",
@ -117,11 +112,9 @@
"description": "The number of the detected boxes.", "description": "The number of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {}
}
}, },
"stats": { "stats": {}
}
} }
], ],
"output_tensor_groups": [ "output_tensor_groups": [

File diff suppressed because it is too large Load Diff

View File

@ -42,15 +42,13 @@
"description": "The scores of the detected boxes.", "description": "The scores of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {},
},
"range": { "range": {
"min": 2, "min": 2,
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {}
}
}, },
{ {
"name": "location", "name": "location",
@ -71,34 +69,29 @@
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {}
}
}, },
{ {
"name": "number of detections", "name": "number of detections",
"description": "The number of the detected boxes.", "description": "The number of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {}
}
}, },
"stats": { "stats": {}
}
}, },
{ {
"name": "category", "name": "category",
"description": "The categories of the detected boxes.", "description": "The categories of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {},
},
"range": { "range": {
"min": 2, "min": 2,
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {},
},
"associated_files": [ "associated_files": [
{ {
"name": "labels.txt", "name": "labels.txt",

View File

@ -56,23 +56,20 @@
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {}
}
}, },
{ {
"name": "category", "name": "category",
"description": "The categories of the detected boxes.", "description": "The categories of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {},
},
"range": { "range": {
"min": 2, "min": 2,
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {},
},
"associated_files": [ "associated_files": [
{ {
"name": "labels.txt", "name": "labels.txt",
@ -86,26 +83,22 @@
"description": "The scores of the detected boxes.", "description": "The scores of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {},
},
"range": { "range": {
"min": 2, "min": 2,
"max": 2 "max": 2
} }
}, },
"stats": { "stats": {}
}
}, },
{ {
"name": "number of detections", "name": "number of detections",
"description": "The number of the detected boxes.", "description": "The number of the detected boxes.",
"content": { "content": {
"content_properties_type": "FeatureProperties", "content_properties_type": "FeatureProperties",
"content_properties": { "content_properties": {}
}
}, },
"stats": { "stats": {}
}
} }
], ],
"output_tensor_groups": [ "output_tensor_groups": [

View File

@ -198,8 +198,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_coco_ssd_mobilenet_v1_score_calibration_json", name = "com_google_mediapipe_coco_ssd_mobilenet_v1_score_calibration_json",
sha256 = "f377600be924c29697477f9d739db9db5d712aec4a644548526912858db6a082", sha256 = "a850674f9043bfc775527fee7f1b639f7fe0fb56e8d3ed2b710247967c888b09",
urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1677522739770755"], urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_score_calibration.json?generation=1682456086898538"],
) )
http_file( http_file(
@ -262,10 +262,28 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/dynamic_input_classifier.tflite?generation=1680543275416843"], urls = ["https://storage.googleapis.com/mediapipe-assets/dynamic_input_classifier.tflite?generation=1680543275416843"],
) )
http_file(
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_anchors_csv",
sha256 = "284475a0f16e34afcc6c0fe68b05bd871aca5b20c83db0870c6a36dd63827176",
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms_anchors.csv?generation=1682456090001817"],
)
http_file(
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_json",
sha256 = "dc3b333e41c43fb49ace048c25c18d0e34df78fb5ee77edbe169264368f78b92",
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.json?generation=1682456092938505"],
)
http_file(
name = "com_google_mediapipe_efficientdet_lite0_fp16_no_nms_tflite",
sha256 = "bcda125c96d3767bca894c8cbe7bc458379c9974c9fd8bdc6204e7124a74082a",
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_fp16_no_nms.tflite?generation=1682456096034465"],
)
http_file( http_file(
name = "com_google_mediapipe_efficientdet_lite0_v1_json", name = "com_google_mediapipe_efficientdet_lite0_v1_json",
sha256 = "7a9e1fb625a6130a251e612637fc546cfc8cfabfadc7dbdade44c87f1d8996ca", sha256 = "ef9706696a3ea5d87f4324ac56e877a92033d33e522c4b7d5a416fbcab24d8fc",
urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_v1.json?generation=1677522746026682"], urls = ["https://storage.googleapis.com/mediapipe-assets/efficientdet_lite0_v1.json?generation=1682456098581704"],
) )
http_file( http_file(
@ -1158,8 +1176,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_ssd_mobilenet_v1_no_metadata_json", name = "com_google_mediapipe_ssd_mobilenet_v1_no_metadata_json",
sha256 = "89157590b736cf3f3247aa9c8be3570c2856f4981a1e9476117e7c629e7c4825", sha256 = "ae5a5971a1c3df705307448ef97c854d846b7e6f2183fb51015bd5af5d7deb0f",
urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1_no_metadata.json?generation=1677522786336455"], urls = ["https://storage.googleapis.com/mediapipe-assets/ssd_mobilenet_v1_no_metadata.json?generation=1682456117002011"],
) )
http_file( http_file(