Updated implementation and tests

This commit is contained in:
kinaryml 2022-09-29 03:40:56 -07:00
parent 1461bcf97d
commit bef2f6cced
7 changed files with 40 additions and 21 deletions

View File

@ -17,12 +17,3 @@
package(default_visibility = ["//mediapipe/tasks:internal"]) package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
py_library(
name = "segmenter_options",
srcs = ["segmenter_options.py"],
deps = [
"//mediapipe/tasks/cc/components:segmenter_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -0,0 +1,28 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
py_library(
name = "segmenter_options",
srcs = ["segmenter_options.py"],
deps = [
"//mediapipe/tasks/cc/components/proto:segmenter_options_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -17,7 +17,7 @@ import dataclasses
import enum import enum
from typing import Any, Optional from typing import Any, Optional
from mediapipe.tasks.cc.components import segmenter_options_pb2 from mediapipe.tasks.cc.components.proto import segmenter_options_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions

View File

@ -48,7 +48,7 @@ py_test(
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_util", "//mediapipe/tasks/python/test:test_util",
"//mediapipe/tasks/python/components:segmenter_options", "//mediapipe/tasks/python/components/proto:segmenter_options",
"//mediapipe/tasks/python/vision:image_segmenter", "//mediapipe/tasks/python/vision:image_segmenter",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"@absl_py//absl/testing:parameterized", "@absl_py//absl/testing:parameterized",

View File

@ -22,7 +22,7 @@ from absl.testing import parameterized
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module from mediapipe.python._framework_bindings import image_frame as image_frame_module
from mediapipe.tasks.python.components import segmenter_options from mediapipe.tasks.python.components.proto import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_util from mediapipe.tasks.python.test import test_util
from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision import image_segmenter
@ -66,7 +66,7 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_create_from_options_succeeds_with_valid_model_path(self): def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully. # Creates with options containing model file successfully.
base_options = _BaseOptions(file_name=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
@ -77,14 +77,14 @@ class ImageSegmenterTest(parameterized.TestCase):
ValueError, ValueError,
r"ExternalFile must specify at least one of 'file_content', " r"ExternalFile must specify at least one of 'file_content', "
r"'file_name' or 'file_descriptor_meta'."): r"'file_name' or 'file_descriptor_meta'."):
base_options = _BaseOptions(file_name='') base_options = _BaseOptions(model_asset_path='')
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options) _ImageSegmenter.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self): def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully. # Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f: with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(file_content=f.read()) base_options = _BaseOptions(model_asset_buffer=f.read())
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
@ -95,11 +95,11 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_succeeds_with_category_mask(self, model_file_type): def test_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter. # Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(file_name=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT: elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f: with open(self.model_path, 'rb') as f:
model_content = f.read() model_content = f.read()
base_options = _BaseOptions(file_content=model_content) base_options = _BaseOptions(model_asset_buffer=model_content)
else: else:
# Should never happen # Should never happen
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
@ -147,7 +147,7 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_succeeds_with_confidence_mask(self): def test_succeeds_with_confidence_mask(self):
# Creates segmenter. # Creates segmenter.
base_options = _BaseOptions(file_name=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode. # Run segmentation on the model in CATEGORY_MASK mode.
segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK) segmenter_options = _SegmenterOptions(output_type=_OutputType.CATEGORY_MASK)

View File

@ -47,7 +47,7 @@ py_library(
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_py_pb2",
"//mediapipe/tasks/python/components:segmenter_options", "//mediapipe/tasks/python/components/proto:segmenter_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",

View File

@ -22,7 +22,7 @@ from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.python._framework_bindings import task_runner as task_runner_module from mediapipe.python._framework_bindings import task_runner as task_runner_module
from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_options_pb2
from mediapipe.tasks.python.components import segmenter_options from mediapipe.tasks.python.components.proto import segmenter_options
from mediapipe.tasks.python.core import base_options as base_options_module from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
@ -103,7 +103,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
file such as invalid file path. file such as invalid file path.
RuntimeError: If other types of error occurred. RuntimeError: If other types of error occurred.
""" """
base_options = _BaseOptions(file_name=model_path) base_options = _BaseOptions(model_asset_path=model_path)
options = ImageSegmenterOptions( options = ImageSegmenterOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE) base_options=base_options, running_mode=_RunningMode.IMAGE)
return cls.create_from_options(options) return cls.create_from_options(options)