Internal update

PiperOrigin-RevId: 535376471
This commit is contained in:
MediaPipe Team 2023-05-25 14:00:49 -07:00 committed by Copybara-Service
parent 169bdf15b4
commit e8f2541cbd
2 changed files with 10 additions and 1 deletions

View File

@ -57,6 +57,7 @@ py_library(
deps = [ deps = [
":metadata_info", ":metadata_info",
":metadata_writer", ":metadata_writer",
":writer_utils",
"//mediapipe/tasks/metadata:image_segmenter_metadata_schema_py", "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_py",
"//mediapipe/tasks/metadata:metadata_schema_py", "//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/python/metadata", "//mediapipe/tasks/python/metadata",

View File

@ -22,6 +22,7 @@ from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_f
from mediapipe.tasks.python.metadata import metadata from mediapipe.tasks.python.metadata import metadata
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
_MODEL_NAME = "ImageSegmenter" _MODEL_NAME = "ImageSegmenter"
@ -148,10 +149,17 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
writer = metadata_writer.MetadataWriter(model_buffer) writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION) writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_image_input(input_norm_mean, input_norm_std) writer.add_image_input(input_norm_mean, input_norm_std)
writer.add_segmentation_output(labels=labels)
if activation is not None: if activation is not None:
option_md = ImageSegmenterOptionsMd(activation) option_md = ImageSegmenterOptionsMd(activation)
writer.add_custom_metadata(option_md) writer.add_custom_metadata(option_md)
num_output_tensors = writer_utils.get_subgraph(model_buffer).OutputsLength()
if num_output_tensors == 2:
# For image segmenter model with 2 output tensors, the first one is
# quality score, and the second one is matting mask.
writer.add_feature_output(
"quality score", "The quality score of matting result."
)
writer.add_segmentation_output(labels=labels)
return cls(writer) return cls(writer)
def populate(self) -> tuple[bytearray, str]: def populate(self) -> tuple[bytearray, str]: