Add metadata writer for image segmentation.
PiperOrigin-RevId: 516671364
This commit is contained in:
parent
9a89b47572
commit
51d9640d88
|
@ -7,7 +7,9 @@ package(
|
|||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["metadata_schema.fbs"])
|
||||
exports_files(glob([
|
||||
"*.fbs",
|
||||
]))
|
||||
|
||||
# Generic schema for model metadata.
|
||||
flatbuffer_cc_library(
|
||||
|
@ -24,3 +26,13 @@ flatbuffer_py_library(
|
|||
name = "metadata_schema_py",
|
||||
srcs = ["metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "image_segmenter_metadata_schema_cc",
|
||||
srcs = ["image_segmenter_metadata_schema.fbs"],
|
||||
)
|
||||
|
||||
flatbuffer_py_library(
|
||||
name = "image_segmenter_metadata_schema_py",
|
||||
srcs = ["image_segmenter_metadata_schema.fbs"],
|
||||
)
|
||||
|
|
59
mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs
Normal file
59
mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs
Normal file
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace mediapipe.tasks;
|
||||
|
||||
// Image segmenter metadata contains information specific for the image
|
||||
// segmentation task. The metadata can be added in
|
||||
// SubGraphMetadata.custom_metadata [1] in model metadata.
|
||||
// [1]: https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L685
|
||||
|
||||
// ImageSegmenterOptions.min_parser_version indicates the minimum necessary
|
||||
// image segmenter metadata parser version to fully understand all fields in a
|
||||
// given metadata flatbuffer. This min_parser_version is specific for the
|
||||
// image segmenter metadata defined in this schema file.
|
||||
//
|
||||
// New fields and types will have associated comments with the schema version
|
||||
// for which they were added.
|
||||
//
|
||||
// Schema Semantic version: 1.0.0
|
||||
|
||||
// This indicates the flatbuffer compatibility. The number will bump up when a
|
||||
// breaking change is applied to the schema, such as removing fields or adding new
|
||||
// fields to the middle of a table.
|
||||
file_identifier "V001";
|
||||
|
||||
// History:
|
||||
// 1.0.0 - Initial version.
|
||||
|
||||
// Supported activation functions.
|
||||
enum Activation: byte {
|
||||
NONE = 0,
|
||||
SIGMOID = 1,
|
||||
SOFTMAX = 2
|
||||
}
|
||||
|
||||
table ImageSegmenterOptions {
|
||||
// The activation function of the output layer in the image segmenter.
|
||||
activation: Activation;
|
||||
|
||||
// The minimum necessary image segmenter metadata parser version to fully
|
||||
// understand all fields in a given metadata flatbuffer. This field is
|
||||
// automatically populated by the MetadataPopulator when the metadata is
|
||||
// populated into a TFLite model. This min_parser_version is specific for the
|
||||
// image segmenter metadata defined in this schema file.
|
||||
min_parser_version:string;
|
||||
}
|
||||
|
||||
root_type ImageSegmenterOptions;
|
|
@ -17,10 +17,13 @@
|
|||
import copy
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Dict, Optional
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
|
@ -789,13 +792,43 @@ class MetadataDisplayer(object):
|
|||
return []
|
||||
|
||||
|
||||
def _get_custom_metadata(metadata_buffer: bytes, name: str):
  """Gets the custom metadata in metadata_buffer based on the name.

  Only the first subgraph (`SubgraphMetadata(0)`) is searched, and the first
  entry whose name matches wins.

  Args:
    metadata_buffer: valid metadata buffer in bytes.
    name: custom metadata name.

  Returns:
    A tuple of (index of the entry in the subgraph's custom metadata list, raw
    bytes of the entry's data). Returns (None, None) if the custom metadata is
    not found.
  """
  model_metadata = _metadata_fb.ModelMetadata.GetRootAs(metadata_buffer)
  subgraph = model_metadata.SubgraphMetadata(0)
  if subgraph is None or subgraph.CustomMetadataIsNone():
    return None, None

  for i in range(subgraph.CustomMetadataLength()):
    custom_metadata = subgraph.CustomMetadata(i)
    entry_name = custom_metadata.Name()
    # The generated accessor returns None when the optional `name` field is
    # unset; an unnamed entry can never match, so skip it instead of crashing
    # on `None.decode(...)`.
    if entry_name is None:
      continue
    if entry_name.decode("utf-8") == name:
      return i, custom_metadata.DataAsNumpy().tobytes()
  return None, None
|
||||
|
||||
|
||||
# Create an individual method for getting the metadata json file, so that it can
|
||||
# be used as a standalone util.
|
||||
def convert_to_json(metadata_buffer):
|
||||
def convert_to_json(
|
||||
metadata_buffer, custom_metadata_schema: Optional[Dict[str, str]] = None
|
||||
) -> str:
|
||||
"""Converts the metadata into a json string.
|
||||
|
||||
Args:
|
||||
metadata_buffer: valid metadata buffer in bytes.
|
||||
custom_metadata_schema: A dict of custom metadata schema, in which key is
|
||||
custom metadata name [1], value is the filepath that defines custom
|
||||
metadata schema. For instance, custom_metadata_schema =
|
||||
{"SEGMENTER_METADATA": "metadata/vision_tasks_metadata_schema.fbs"}. [1]:
|
||||
https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L612
|
||||
|
||||
Returns:
|
||||
Metadata in JSON format.
|
||||
|
@ -803,7 +836,6 @@ def convert_to_json(metadata_buffer):
|
|||
Raises:
|
||||
ValueError: error occurred when parsing the metadata schema file.
|
||||
"""
|
||||
|
||||
opt = _pywrap_flatbuffers.IDLOptions()
|
||||
opt.strict_json = True
|
||||
parser = _pywrap_flatbuffers.Parser(opt)
|
||||
|
@ -811,7 +843,35 @@ def convert_to_json(metadata_buffer):
|
|||
metadata_schema_content = f.read()
|
||||
if not parser.parse(metadata_schema_content):
|
||||
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
|
||||
return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
|
||||
# Json content which may contain binary custom metadata.
|
||||
raw_json_content = _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
|
||||
if not custom_metadata_schema:
|
||||
return raw_json_content
|
||||
|
||||
json_data = json.loads(raw_json_content)
|
||||
# Gets the custom metadata by name and parse the binary custom metadata into
|
||||
# human readable json content.
|
||||
for name, schema_file in custom_metadata_schema.items():
|
||||
idx, custom_metadata = _get_custom_metadata(metadata_buffer, name)
|
||||
if not custom_metadata:
|
||||
logging.info(
|
||||
"No custom metadata with name %s in metadata flatbuffer.", name
|
||||
)
|
||||
continue
|
||||
_assert_file_exist(schema_file)
|
||||
with _open_file(schema_file, "rb") as f:
|
||||
custom_metadata_schema_content = f.read()
|
||||
if not parser.parse(custom_metadata_schema_content):
|
||||
raise ValueError(
|
||||
"Cannot parse custom metadata schema. Reason: " + parser.error
|
||||
)
|
||||
custom_metadata_json = _pywrap_flatbuffers.generate_text(
|
||||
parser, custom_metadata
|
||||
)
|
||||
json_meta = json_data["subgraph_metadata"][0]["custom_metadata"][idx]
|
||||
json_meta["name"] = name
|
||||
json_meta["data"] = json.loads(custom_metadata_json)
|
||||
return json.dumps(json_data, indent=2)
|
||||
|
||||
|
||||
def _assert_file_exist(filename):
|
||||
|
|
|
@ -50,6 +50,20 @@ py_library(
|
|||
deps = [":metadata_writer"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "image_segmenter",
|
||||
srcs = ["image_segmenter.py"],
|
||||
data = ["//mediapipe/tasks/metadata:image_segmenter_metadata_schema.fbs"],
|
||||
deps = [
|
||||
":metadata_info",
|
||||
":metadata_writer",
|
||||
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_py",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_py",
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"@flatbuffers//:runtime_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "object_detector",
|
||||
srcs = ["object_detector.py"],
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Writes metadata and label file to the image segmenter models."""
|
||||
import enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import flatbuffers
|
||||
from mediapipe.tasks.metadata import image_segmenter_metadata_schema_py_generated as _segmenter_metadata_fb
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
|
||||
|
||||
_MODEL_NAME = "ImageSegmenter"
|
||||
_MODEL_DESCRIPTION = (
|
||||
"Semantic image segmentation predicts whether each pixel "
|
||||
"of an image is associated with a certain class."
|
||||
)
|
||||
|
||||
# Metadata Schema file for image segmenter.
|
||||
_FLATC_METADATA_SCHEMA_FILE = metadata.get_path_to_datafile(
|
||||
"../../../metadata/image_segmenter_metadata_schema.fbs",
|
||||
)
|
||||
|
||||
# Metadata name in custom metadata field. The metadata name is used to get
|
||||
# image segmenter metadata from SubGraphMetadata.custom_metadata and
|
||||
# shouldn't be changed.
|
||||
_METADATA_NAME = "SEGMENTER_METADATA"
|
||||
|
||||
|
||||
class Activation(enum.Enum):
  """Activation functions that can be recorded for the segmenter output layer.

  The member value is what gets stored in the `activation` field of the image
  segmenter options flatbuffer.
  """

  # No activation.
  NONE = 0
  # Sigmoid activation.
  SIGMOID = 1
  # Softmax activation.
  SOFTMAX = 2
|
||||
|
||||
|
||||
# Create an individual method for getting the metadata json file, so that it can
|
||||
# be used as a standalone util.
|
||||
def convert_to_json(metadata_buffer: bytearray) -> str:
  """Renders a metadata buffer as JSON, decoding the segmenter options.

  Delegates to `metadata.convert_to_json`, registering the image segmenter
  schema file under the segmenter custom metadata name so that the custom
  metadata entry is expanded with that schema.

  Args:
    metadata_buffer: valid metadata buffer in bytes.

  Returns:
    Metadata in JSON format.

  Raises:
    ValueError: error occurred when parsing the metadata schema file.
  """
  segmenter_schema = {_METADATA_NAME: _FLATC_METADATA_SCHEMA_FILE}
  return metadata.convert_to_json(
      metadata_buffer, custom_metadata_schema=segmenter_schema
  )
|
||||
|
||||
|
||||
class ImageSegmenterOptionsMd(metadata_info.CustomMetadataMd):
  """Custom metadata container holding the image segmenter options."""

  # File identifier passed to the flatbuffer builder when finishing the
  # serialized options buffer.
  _METADATA_FILE_IDENTIFIER = b"V001"

  def __init__(self, activation: Activation) -> None:
    """Creates an ImageSegmenterOptionsMd object.

    Args:
      activation: activation function of the output layer in the image
        segmenter.
    """
    self.activation = activation
    super().__init__(name=_METADATA_NAME)

  def create_metadata(self) -> _metadata_fb.CustomMetadataT:
    """Builds the custom metadata entry carrying the segmenter options.

    Returns:
      A Flatbuffers Python object of the custom metadata whose data is the
      serialized image segmenter options flatbuffer.
    """
    # Fill in the options object from this container's state.
    options = _segmenter_metadata_fb.ImageSegmenterOptionsT()
    options.activation = self.activation.value

    # Serialize the options into a standalone flatbuffer.
    builder = flatbuffers.Builder(0)
    packed_options = options.Pack(builder)
    builder.Finish(packed_options, self._METADATA_FILE_IDENTIFIER)
    serialized_options = builder.Output()

    # Wrap the serialized options in a named custom metadata entry.
    result = _metadata_fb.CustomMetadataT()
    result.name = self.name
    result.data = serialized_options
    return result
|
||||
|
||||
|
||||
class MetadataWriter(metadata_writer.MetadataWriterBase):
  """MetadataWriter to write the metadata for image segmenter."""

  @classmethod
  def create(
      cls,
      model_buffer: bytearray,
      input_norm_mean: List[float],
      input_norm_std: List[float],
      labels: Optional[metadata_writer.Labels] = None,
      activation: Optional[Activation] = None,
  ) -> "MetadataWriter":
    """Creates MetadataWriter to write the metadata for image segmenter.

    The parameters required in this method are mandatory when using MediaPipe
    Tasks.

    Example usage:
      metadata_writer = image_segmenter.MetadataWriter.create(model_buffer, ...)
      tflite_content, json_content = metadata_writer.populate()

    When calling `populate` function in this class, it returns TfLite content
    and JSON content. Note that only the output TFLite is used for deployment.
    The output JSON content is used to interpret the metadata content.

    Args:
      model_buffer: A valid flatbuffer loaded from the TFLite model file.
      input_norm_mean: the mean value used in the input tensor normalization
        [1].
      input_norm_std: the std value used in the input tensor normalization [1].
      labels: an instance of Labels helper class used in the output category
        tensor [2].
      activation: activation function for the output layer. When None, no
        image segmenter custom metadata is added.
      [1]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
      [2]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116

    Returns:
      A MetadataWriter object.
    """
    writer = metadata_writer.MetadataWriter(model_buffer)
    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
    writer.add_image_input(input_norm_mean, input_norm_std)
    writer.add_segmentation_output(labels=labels)
    if activation is not None:
      option_md = ImageSegmenterOptionsMd(activation)
      writer.add_custom_metadata(option_md)
    return cls(writer)

  def populate(self) -> Tuple[bytearray, str]:
    """Populates the metadata into the model and renders it as JSON.

    Returns:
      A tuple of (model buffer with metadata populated, metadata in JSON
      format).
    """
    # The JSON returned by the base class is discarded and regenerated with
    # this module's `convert_to_json`, which passes the image segmenter schema
    # for the custom metadata entry.
    model_buf, _ = super().populate()
    metadata_buf = metadata.get_metadata_buffer(model_buf)
    json_content = convert_to_json(metadata_buf)
    return model_buf, json_content
|
|
@ -1030,6 +1030,52 @@ class TensorGroupMd:
|
|||
return group
|
||||
|
||||
|
||||
class SegmentationMaskMd(TensorMd):
  """Metadata container describing a segmentation mask output tensor."""

  # The output tensor has shape [1, ImageHeight, ImageWidth, N], with N being
  # the number of object types the segmentation model can recognize. It is in
  # effect a stack of grayscale bitmaps whose values give the probability that
  # a pixel belongs to a given object type, hence the content dimension range
  # [1, 2].
  _CONTENT_DIM_MIN = 1
  _CONTENT_DIM_MAX = 2

  def __init__(
      self,
      name: Optional[str] = None,
      description: Optional[str] = None,
      label_files: Optional[List[LabelFileMd]] = None,
  ):
    """Creates a SegmentationMaskMd object.

    Args:
      name: name of the tensor metadata.
      description: description of the tensor metadata.
      label_files: label file metadata to attach as associated files; None is
        treated as an empty list.
    """
    self.name = name
    self.description = description
    super().__init__(
        name=name,
        description=description,
        associated_files=label_files if label_files else [],
    )

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the metadata for the segmentation masks tensor."""
    tensor_metadata = super().create_metadata()

    # The masks are described as grayscale image content.
    image_properties = _metadata_fb.ImagePropertiesT()
    image_properties.colorSpace = _metadata_fb.ColorSpaceType.GRAYSCALE

    # Content dimension range. See
    # https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L323-L385
    content_range = _metadata_fb.ValueRangeT()
    content_range.min = self._CONTENT_DIM_MIN
    content_range.max = self._CONTENT_DIM_MAX

    tensor_content = _metadata_fb.ContentT()
    tensor_content.contentProperties = image_properties
    tensor_content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties
    )
    tensor_content.range = content_range

    tensor_metadata.content = tensor_content
    return tensor_metadata
|
||||
|
||||
|
||||
class CustomMetadataMd(abc.ABC):
|
||||
"""An abstract class of a container for the custom metadata information."""
|
||||
|
||||
|
|
|
@ -34,6 +34,10 @@ _INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input '
|
|||
'text to be processed.')
|
||||
_OUTPUT_CLASSIFICATION_NAME = 'score'
|
||||
_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.'
|
||||
_OUTPUT_SEGMENTATION_MASKS_NAME = 'segmentation_masks'
|
||||
_OUTPUT_SEGMENTATION_MASKS_DESCRIPTION = (
|
||||
'Masks over the target objects with high accuracy.'
|
||||
)
|
||||
# Detection tensor result to be grouped together.
|
||||
_DETECTION_GROUP_NAME = 'detection_result'
|
||||
# File name to export score calibration parameters.
|
||||
|
@ -657,6 +661,32 @@ class MetadataWriter(object):
|
|||
self._output_group_mds.append(group_md)
|
||||
return self
|
||||
|
||||
def add_segmentation_output(
    self,
    labels: Optional[Labels] = None,
    name: str = _OUTPUT_SEGMENTATION_MASKS_NAME,
    description: str = _OUTPUT_SEGMENTATION_MASKS_DESCRIPTION,
) -> 'MetadataWriter':
  """Registers segmentation mask metadata for an output tensor.

  Args:
    labels: an instance of Labels helper class.
    name: Metadata name of the tensor. Note that this is different from tensor
      name in the flatbuffer.
    description: human readable description of what the output is.

  Returns:
    The current Writer instance to allow chained operation.
  """
  mask_md = metadata_info.SegmentationMaskMd(
      name=name,
      description=description,
      label_files=self._create_label_file_md(labels),
  )
  self._output_mds.append(mask_md)
  return self
|
||||
|
||||
def add_feature_output(self,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None) -> 'MetadataWriter':
|
||||
|
|
|
@ -91,3 +91,18 @@ py_test(
|
|||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "image_segmenter_test",
|
||||
srcs = ["image_segmenter_test.py"],
|
||||
data = [
|
||||
"//mediapipe/tasks/testdata/metadata:data_files",
|
||||
"//mediapipe/tasks/testdata/metadata:model_files",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/python/metadata",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:image_segmenter",
|
||||
"//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for metadata_writer.image_segmenter."""
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from mediapipe.tasks.python.metadata import metadata
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import image_segmenter
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||
_MODEL_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3_without_metadata.tflite")
|
||||
)
|
||||
_LABEL_FILE_NAME = "labels.txt"
|
||||
_LABEL_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "segmenter_labelmap.txt")
|
||||
)
|
||||
_NORM_MEAN = 127.5
|
||||
_NORM_STD = 127.5
|
||||
_JSON_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3.json")
|
||||
)
|
||||
_JSON_FILE_WITHOUT_LABELS = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3_without_labels.json")
|
||||
)
|
||||
_JSON_FILE_WITH_ACTIVATION = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "deeplabv3_with_activation.json")
|
||||
)
|
||||
|
||||
|
||||
class ImageSegmenterTest(absltest.TestCase):
  """Checks the JSON and associated files produced by the segmenter writer."""

  def _load_model(self) -> bytearray:
    """Reads the metadata-free test model into a mutable buffer."""
    with open(_MODEL_FILE, "rb") as model_file:
      return bytearray(model_file.read())

  def _load_expected_json(self, path: str) -> str:
    """Reads a golden JSON file with surrounding whitespace stripped."""
    with open(path, "r") as json_file:
      return json_file.read().strip()

  def test_write_metadata(self):
    model = self._load_model()
    writer = image_segmenter.MetadataWriter.create(
        model,
        [_NORM_MEAN],
        [_NORM_STD],
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE),
    )
    tflite_content, metadata_json = writer.populate()
    self.assertEqual(metadata_json, self._load_expected_json(_JSON_FILE))

    # The label file packed into the model must match the source label file.
    displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content)
    packed_labels = displayer.get_associated_file_buffer(_LABEL_FILE_NAME)
    with open(_LABEL_FILE, "rb") as label_file:
      source_labels = label_file.read()
    self.assertEqual(packed_labels, source_labels)

  def test_write_metadata_without_labels(self):
    model = self._load_model()
    writer = image_segmenter.MetadataWriter.create(
        model,
        [_NORM_MEAN],
        [_NORM_STD],
    )
    _, metadata_json = writer.populate()
    self.assertEqual(
        metadata_json, self._load_expected_json(_JSON_FILE_WITHOUT_LABELS)
    )

  def test_write_metadata_with_activation(self):
    model = self._load_model()
    writer = image_segmenter.MetadataWriter.create(
        model,
        [_NORM_MEAN],
        [_NORM_STD],
        activation=image_segmenter.Activation.SIGMOID,
    )
    _, metadata_json = writer.populate()
    self.assertEqual(
        metadata_json, self._load_expected_json(_JSON_FILE_WITH_ACTIVATION)
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
|
@ -455,6 +455,27 @@ class TensorGroupMdMdTest(absltest.TestCase):
|
|||
self.assertEqual(metadata_json, expected_json)
|
||||
|
||||
|
||||
class SegmentationMaskMdTest(absltest.TestCase):
  """Checks the JSON rendering of segmentation mask tensor metadata."""

  _NAME = "segmentation_masks"
  _DESCRIPTION = "Masks over the target objects."
  _EXPECTED_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "segmentation_mask_meta.json")
  )

  def test_create_metadata_should_succeed(self):
    mask_md = metadata_info.SegmentationMaskMd(
        name=self._NAME, description=self._DESCRIPTION
    )
    tensor_metadata = mask_md.create_metadata()

    # Wrap the tensor metadata in a dummy model metadata so it can be rendered.
    dummy_model_metadata = _create_dummy_model_metadata_with_tensor(
        tensor_metadata
    )
    actual_json = _metadata.convert_to_json(dummy_model_metadata)

    with open(self._EXPECTED_JSON, "r") as golden_file:
      golden_json = golden_file.read()
    self.assertEqual(actual_json, golden_json)
|
||||
|
||||
|
||||
def _create_dummy_model_metadata_with_tensor(
|
||||
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
|
||||
# Create a dummy model using the tensor metadata.
|
||||
|
|
12
mediapipe/tasks/testdata/metadata/BUILD
vendored
12
mediapipe/tasks/testdata/metadata/BUILD
vendored
|
@ -28,6 +28,10 @@ mediapipe_files(srcs = [
|
|||
"category_tensor_float_meta.json",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_score_calibration.json",
|
||||
"deeplabv3.json",
|
||||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"labelmap.txt",
|
||||
|
@ -44,6 +48,8 @@ mediapipe_files(srcs = [
|
|||
"mobilenet_v2_1.0_224_without_metadata.tflite",
|
||||
"movie_review.tflite",
|
||||
"score_calibration.csv",
|
||||
"segmentation_mask_meta.json",
|
||||
"segmenter_labelmap.txt",
|
||||
"ssd_mobilenet_v1_no_metadata.json",
|
||||
"ssd_mobilenet_v1_no_metadata.tflite",
|
||||
"tensor_group_meta.json",
|
||||
|
@ -87,6 +93,7 @@ filegroup(
|
|||
"30k-clean.model",
|
||||
"bert_text_classifier_no_metadata.tflite",
|
||||
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite",
|
||||
"deeplabv3_without_metadata.tflite",
|
||||
"efficientdet_lite0_v1.tflite",
|
||||
"mobile_ica_8bit-with-custom-metadata.tflite",
|
||||
"mobile_ica_8bit-with-large-min-parser-version.tflite",
|
||||
|
@ -116,6 +123,9 @@ filegroup(
|
|||
"classification_tensor_uint8_meta.json",
|
||||
"classification_tensor_unsupported_meta.json",
|
||||
"coco_ssd_mobilenet_v1_score_calibration.json",
|
||||
"deeplabv3.json",
|
||||
"deeplabv3_with_activation.json",
|
||||
"deeplabv3_without_labels.json",
|
||||
"efficientdet_lite0_v1.json",
|
||||
"external_file",
|
||||
"feature_tensor_meta.json",
|
||||
|
@ -140,6 +150,8 @@ filegroup(
|
|||
"score_calibration_file_meta.json",
|
||||
"score_calibration_tensor_meta.json",
|
||||
"score_thresholding_meta.json",
|
||||
"segmentation_mask_meta.json",
|
||||
"segmenter_labelmap.txt",
|
||||
"sentence_piece_tokenizer_meta.json",
|
||||
"ssd_mobilenet_v1_no_metadata.json",
|
||||
"tensor_group_meta.json",
|
||||
|
|
66
mediapipe/tasks/testdata/metadata/deeplabv3.json
vendored
Normal file
66
mediapipe/tasks/testdata/metadata/deeplabv3.json
vendored
Normal file
|
@ -0,0 +1,66 @@
|
|||
{
|
||||
"name": "ImageSegmenter",
|
||||
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects with high accuracy.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {},
|
||||
"associated_files": [
|
||||
{
|
||||
"name": "labels.txt",
|
||||
"description": "Labels for categories that the model can recognize.",
|
||||
"type": "TENSOR_AXIS_LABELS"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
67
mediapipe/tasks/testdata/metadata/deeplabv3_with_activation.json
vendored
Normal file
67
mediapipe/tasks/testdata/metadata/deeplabv3_with_activation.json
vendored
Normal file
|
@ -0,0 +1,67 @@
|
|||
{
|
||||
"name": "ImageSegmenter",
|
||||
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects with high accuracy.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {}
|
||||
}
|
||||
],
|
||||
"custom_metadata": [
|
||||
{
|
||||
"name": "SEGMENTER_METADATA",
|
||||
"data": {
|
||||
"activation": "SIGMOID"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.5.0"
|
||||
}
|
59
mediapipe/tasks/testdata/metadata/deeplabv3_without_labels.json
vendored
Normal file
59
mediapipe/tasks/testdata/metadata/deeplabv3_without_labels.json
vendored
Normal file
|
@ -0,0 +1,59 @@
|
|||
{
|
||||
"name": "ImageSegmenter",
|
||||
"description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.",
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "image",
|
||||
"description": "Input image to be processed.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "RGB"
|
||||
}
|
||||
},
|
||||
"process_units": [
|
||||
{
|
||||
"options_type": "NormalizationOptions",
|
||||
"options": {
|
||||
"mean": [
|
||||
127.5
|
||||
],
|
||||
"std": [
|
||||
127.5
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"stats": {
|
||||
"max": [
|
||||
1.0
|
||||
],
|
||||
"min": [
|
||||
-1.0
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"output_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects with high accuracy.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"min_parser_version": "1.0.0"
|
||||
}
|
24
mediapipe/tasks/testdata/metadata/segmentation_mask_meta.json
vendored
Normal file
24
mediapipe/tasks/testdata/metadata/segmentation_mask_meta.json
vendored
Normal file
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"subgraph_metadata": [
|
||||
{
|
||||
"input_tensor_metadata": [
|
||||
{
|
||||
"name": "segmentation_masks",
|
||||
"description": "Masks over the target objects.",
|
||||
"content": {
|
||||
"content_properties_type": "ImageProperties",
|
||||
"content_properties": {
|
||||
"color_space": "GRAYSCALE"
|
||||
},
|
||||
"range": {
|
||||
"min": 1,
|
||||
"max": 2
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
21
mediapipe/tasks/testdata/metadata/segmenter_labelmap.txt
vendored
Normal file
21
mediapipe/tasks/testdata/metadata/segmenter_labelmap.txt
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
background
|
||||
aeroplane
|
||||
bicycle
|
||||
bird
|
||||
boat
|
||||
bottle
|
||||
bus
|
||||
car
|
||||
cat
|
||||
chair
|
||||
cow
|
||||
dining table
|
||||
dog
|
||||
horse
|
||||
motorbike
|
||||
person
|
||||
potted plant
|
||||
sheep
|
||||
sofa
|
||||
train
|
||||
tv
|
36
third_party/external_files.bzl
vendored
36
third_party/external_files.bzl
vendored
|
@ -208,12 +208,36 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/corrupted_mobilenet_v1_0.25_224_1_default_1.tflite?generation=1661875706780536"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_json",
|
||||
sha256 = "f299835bd9ea1cceb25fdf40a761a22716cbd20025cd67c365a860527f178b7f",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.json?generation=1678818040715103"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_tflite",
|
||||
sha256 = "5faed2c653905d3e22a8f6f29ee198da84e9b0e7936a207bf431f17f6b4d87ff",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1678775085237701"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_with_activation_json",
|
||||
sha256 = "a7633476d02f970db3cc30f5f027bcb608149e02207b2ccae36a4b69d730c82c",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_with_activation.json?generation=1678818047050984"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_without_labels_json",
|
||||
sha256 = "7d045a583a4046f17a52d2078b0175607a45ed0cc187558325f9c66534c08401",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_labels.json?generation=1678818050191996"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_deeplabv3_without_metadata_tflite",
|
||||
sha256 = "68a539782c2c6a72f8aac3724600124a85ed977162b44e84cbae5db717c933c6",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_metadata.tflite?generation=1678818053623010"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_dense_tflite",
|
||||
sha256 = "be9323068461b1cbf412692ee916be30dcb1a5fb59a9ee875d470bc340d9e869",
|
||||
|
@ -976,6 +1000,18 @@ def external_files():
|
|||
urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_input_rotation0.jpg?generation=1661875914048401"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_segmentation_mask_meta_json",
|
||||
sha256 = "4294d53b309c1fbe38a5184de4057576c3dec14e07d16491f1dd459ac9116ab3",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_mask_meta.json?generation=1678818065134737"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_segmenter_labelmap_txt",
|
||||
sha256 = "d9efa78274f1799ddbcab1f87263e19dae338c1697de47a5b270c9526c45d364",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/segmenter_labelmap.txt?generation=1678818068181025"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg",
|
||||
sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e",
|
||||
|
|
Loading…
Reference in New Issue
Block a user