Fix test utils bug when two file names have same ending.
PiperOrigin-RevId: 485780917
This commit is contained in:
parent
1b5da09a92
commit
ddf37d014e
|
@ -27,6 +27,8 @@ from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||||
from mediapipe.tasks.python.metadata import metadata as _metadata
|
from mediapipe.tasks.python.metadata import metadata as _metadata
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer(enum.Enum):
|
class Tokenizer(enum.Enum):
|
||||||
BERT_TOKENIZER = 0
|
BERT_TOKENIZER = 0
|
||||||
|
@ -810,7 +812,8 @@ class MetadataDisplayerTest(MetadataTest):
|
||||||
actual_json = _metadata.convert_to_json(actual_buffer)
|
actual_json = _metadata.convert_to_json(actual_buffer)
|
||||||
|
|
||||||
# Verifies the generated json file.
|
# Verifies the generated json file.
|
||||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
golden_json_file_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
|
||||||
with open(golden_json_file_path, "r") as f:
|
with open(golden_json_file_path, "r") as f:
|
||||||
expected = f.read()
|
expected = f.read()
|
||||||
self.assertEqual(actual_json, expected)
|
self.assertEqual(actual_json, expected)
|
||||||
|
@ -821,7 +824,8 @@ class MetadataDisplayerTest(MetadataTest):
|
||||||
actual = displayer.get_metadata_json()
|
actual = displayer.get_metadata_json()
|
||||||
|
|
||||||
# Verifies the generated json file.
|
# Verifies the generated json file.
|
||||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
golden_json_file_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
|
||||||
expected = _read_file(golden_json_file_path, "r")
|
expected = _read_file(golden_json_file_path, "r")
|
||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
@ -848,7 +852,8 @@ class MetadataUtilTest(MetadataTest):
|
||||||
metadata_json = _metadata.convert_to_json(metadata_buf)
|
metadata_json = _metadata.convert_to_json(metadata_buf)
|
||||||
|
|
||||||
# Verifies the generated json file.
|
# Verifies the generated json file.
|
||||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
golden_json_file_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
|
||||||
expected = _read_file(golden_json_file_path, "r")
|
expected = _read_file(golden_json_file_path, "r")
|
||||||
self.assertEqual(metadata_json, expected)
|
self.assertEqual(metadata_json, expected)
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for metadata_writer.image_classifier."""
|
"""Tests for metadata_writer.image_classifier."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
@ -23,18 +25,25 @@ from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||||
_FLOAT_MODEL = test_utils.get_test_data_path(
|
_FLOAT_MODEL = test_utils.get_test_data_path(
|
||||||
"mobilenet_v2_1.0_224_without_metadata.tflite")
|
os.path.join(_TEST_DATA_DIR,
|
||||||
|
"mobilenet_v2_1.0_224_without_metadata.tflite"))
|
||||||
_QUANT_MODEL = test_utils.get_test_data_path(
|
_QUANT_MODEL = test_utils.get_test_data_path(
|
||||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite")
|
os.path.join(_TEST_DATA_DIR,
|
||||||
_LABEL_FILE = test_utils.get_test_data_path("labels.txt")
|
"mobilenet_v2_1.0_224_quant_without_metadata.tflite"))
|
||||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
|
_LABEL_FILE = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "labels.txt"))
|
||||||
|
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "score_calibration.txt"))
|
||||||
_SCORE_CALIBRATION_FILENAME = "score_calibration.txt"
|
_SCORE_CALIBRATION_FILENAME = "score_calibration.txt"
|
||||||
_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2
|
_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2
|
||||||
_NORM_MEAN = 127.5
|
_NORM_MEAN = 127.5
|
||||||
_NORM_STD = 127.5
|
_NORM_STD = 127.5
|
||||||
_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json")
|
_FLOAT_JSON = test_utils.get_test_data_path(
|
||||||
_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json")
|
os.path.join(_TEST_DATA_DIR, "mobilenet_v2_1.0_224.json"))
|
||||||
|
_QUANT_JSON = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "mobilenet_v2_1.0_224_quant.json"))
|
||||||
|
|
||||||
|
|
||||||
class ImageClassifierTest(parameterized.TestCase):
|
class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for metadata info classes."""
|
"""Tests for metadata info classes."""
|
||||||
|
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
@ -26,13 +27,15 @@ from mediapipe.tasks.python.metadata import metadata as _metadata
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
|
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||||
|
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, "score_calibration.txt"))
|
||||||
|
|
||||||
|
|
||||||
class GeneralMdTest(absltest.TestCase):
|
class GeneralMdTest(absltest.TestCase):
|
||||||
|
|
||||||
_EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path(
|
_EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path(
|
||||||
"general_meta.json")
|
os.path.join(_TEST_DATA_DIR, "general_meta.json"))
|
||||||
|
|
||||||
def test_create_metadata_should_succeed(self):
|
def test_create_metadata_should_succeed(self):
|
||||||
general_md = metadata_info.GeneralMd(
|
general_md = metadata_info.GeneralMd(
|
||||||
|
@ -59,7 +62,7 @@ class GeneralMdTest(absltest.TestCase):
|
||||||
class AssociatedFileMdTest(absltest.TestCase):
|
class AssociatedFileMdTest(absltest.TestCase):
|
||||||
|
|
||||||
_EXPECTED_META_JSON = test_utils.get_test_data_path(
|
_EXPECTED_META_JSON = test_utils.get_test_data_path(
|
||||||
"associated_file_meta.json")
|
os.path.join(_TEST_DATA_DIR, "associated_file_meta.json"))
|
||||||
|
|
||||||
def test_create_metadata_should_succeed(self):
|
def test_create_metadata_should_succeed(self):
|
||||||
file_md = metadata_info.AssociatedFileMd(
|
file_md = metadata_info.AssociatedFileMd(
|
||||||
|
@ -92,11 +95,11 @@ class TensorMdTest(parameterized.TestCase):
|
||||||
_LABEL_FILE_EN = "labels.txt"
|
_LABEL_FILE_EN = "labels.txt"
|
||||||
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
|
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
|
||||||
_EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"feature_tensor_meta.json")
|
os.path.join(_TEST_DATA_DIR, "feature_tensor_meta.json"))
|
||||||
_EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"image_tensor_meta.json")
|
os.path.join(_TEST_DATA_DIR, "image_tensor_meta.json"))
|
||||||
_EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"bounding_box_tensor_meta.json")
|
os.path.join(_TEST_DATA_DIR, "bounding_box_tensor_meta.json"))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
{
|
{
|
||||||
|
@ -142,11 +145,11 @@ class InputImageTensorMdTest(parameterized.TestCase):
|
||||||
_NORM_STD = (127.5, 127.5, 127.5)
|
_NORM_STD = (127.5, 127.5, 127.5)
|
||||||
_COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB
|
_COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB
|
||||||
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"input_image_tensor_float_meta.json")
|
os.path.join(_TEST_DATA_DIR, "input_image_tensor_float_meta.json"))
|
||||||
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"input_image_tensor_uint8_meta.json")
|
os.path.join(_TEST_DATA_DIR, "input_image_tensor_uint8_meta.json"))
|
||||||
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"input_image_tensor_unsupported_meta.json")
|
os.path.join(_TEST_DATA_DIR, "input_image_tensor_unsupported_meta.json"))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
{
|
{
|
||||||
|
@ -196,11 +199,12 @@ class ClassificationTensorMdTest(parameterized.TestCase):
|
||||||
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
|
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
|
||||||
_CALIBRATION_DEFAULT_SCORE = 0.2
|
_CALIBRATION_DEFAULT_SCORE = 0.2
|
||||||
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"classification_tensor_float_meta.json")
|
os.path.join(_TEST_DATA_DIR, "classification_tensor_float_meta.json"))
|
||||||
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"classification_tensor_uint8_meta.json")
|
os.path.join(_TEST_DATA_DIR, "classification_tensor_uint8_meta.json"))
|
||||||
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"classification_tensor_unsupported_meta.json")
|
os.path.join(_TEST_DATA_DIR,
|
||||||
|
"classification_tensor_unsupported_meta.json"))
|
||||||
|
|
||||||
@parameterized.named_parameters(
|
@parameterized.named_parameters(
|
||||||
{
|
{
|
||||||
|
@ -243,9 +247,9 @@ class ClassificationTensorMdTest(parameterized.TestCase):
|
||||||
class ScoreCalibrationMdTest(absltest.TestCase):
|
class ScoreCalibrationMdTest(absltest.TestCase):
|
||||||
_DEFAULT_VALUE = 0.2
|
_DEFAULT_VALUE = 0.2
|
||||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"score_calibration_tensor_meta.json")
|
os.path.join(_TEST_DATA_DIR, "score_calibration_tensor_meta.json"))
|
||||||
_EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path(
|
_EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path(
|
||||||
"score_calibration_file_meta.json")
|
os.path.join(_TEST_DATA_DIR, "score_calibration_file_meta.json"))
|
||||||
|
|
||||||
def test_create_metadata_should_succeed(self):
|
def test_create_metadata_should_succeed(self):
|
||||||
score_calibration_md = metadata_info.ScoreCalibrationMd(
|
score_calibration_md = metadata_info.ScoreCalibrationMd(
|
||||||
|
@ -310,7 +314,7 @@ class ScoreCalibrationMdTest(absltest.TestCase):
|
||||||
class ScoreThresholdingMdTest(absltest.TestCase):
|
class ScoreThresholdingMdTest(absltest.TestCase):
|
||||||
_DEFAULT_GLOBAL_THRESHOLD = 0.5
|
_DEFAULT_GLOBAL_THRESHOLD = 0.5
|
||||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||||
"score_thresholding_meta.json")
|
os.path.join(_TEST_DATA_DIR, "score_thresholding_meta.json"))
|
||||||
|
|
||||||
def test_create_metadata_should_succeed(self):
|
def test_create_metadata_should_succeed(self):
|
||||||
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||||
|
|
|
@ -21,9 +21,12 @@ from absl.testing import absltest
|
||||||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||||
from mediapipe.tasks.python.test import test_utils
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/metadata'
|
||||||
|
|
||||||
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
|
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
|
||||||
'mobilenet_v1_0.25_224_1_default_1.tflite')
|
os.path.join(_TEST_DATA_DIR, 'mobilenet_v1_0.25_224_1_default_1.tflite'))
|
||||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt')
|
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
|
||||||
|
|
||||||
|
|
||||||
class LabelsTest(absltest.TestCase):
|
class LabelsTest(absltest.TestCase):
|
||||||
|
@ -85,8 +88,7 @@ class ScoreCalibrationTest(absltest.TestCase):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
'Expected empty lines or 3 or 4 parameters per line in score '
|
'Expected empty lines or 3 or 4 parameters per line in score '
|
||||||
'calibration file, but got 2.'
|
'calibration file, but got 2.'):
|
||||||
):
|
|
||||||
metadata_writer.ScoreCalibration.create_from_file(
|
metadata_writer.ScoreCalibration.create_from_file(
|
||||||
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||||
test_file)
|
test_file)
|
||||||
|
|
|
@ -42,13 +42,15 @@ def test_srcdir():
|
||||||
raise RuntimeError("Missing TEST_SRCDIR environment.")
|
raise RuntimeError("Missing TEST_SRCDIR environment.")
|
||||||
|
|
||||||
|
|
||||||
def get_test_data_path(file_or_dirname: str) -> str:
|
def get_test_data_path(file_or_dirname_path: str) -> str:
|
||||||
"""Returns full test data path."""
|
"""Returns full test data path."""
|
||||||
for (directory, subdirs, files) in os.walk(test_srcdir()):
|
for (directory, subdirs, files) in os.walk(test_srcdir()):
|
||||||
for f in subdirs + files:
|
for f in subdirs + files:
|
||||||
if f.endswith(file_or_dirname):
|
path = os.path.join(directory, f)
|
||||||
return os.path.join(directory, f)
|
if path.endswith(file_or_dirname_path):
|
||||||
raise ValueError("No %s in test directory" % file_or_dirname)
|
return path
|
||||||
|
raise ValueError("No %s in test directory: %s." %
|
||||||
|
(file_or_dirname_path, test_srcdir()))
|
||||||
|
|
||||||
|
|
||||||
def create_calibration_file(file_dir: str,
|
def create_calibration_file(file_dir: str,
|
||||||
|
|
|
@ -14,10 +14,12 @@
|
||||||
"""Tests for image classifier."""
|
"""Tests for image classifier."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import os
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image
|
from mediapipe.python._framework_bindings import image
|
||||||
|
@ -48,6 +50,7 @@ _ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||||
_DENY_LIST = ['cheeseburger']
|
_DENY_LIST = ['cheeseburger']
|
||||||
_SCORE_THRESHOLD = 0.5
|
_SCORE_THRESHOLD = 0.5
|
||||||
_MAX_RESULTS = 3
|
_MAX_RESULTS = 3
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
|
||||||
|
|
||||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
@ -124,8 +127,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.test_image = _Image.create_from_file(
|
self.test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
test_utils.get_test_data_path(
|
||||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||||
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
@ -220,7 +225,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Load the test image.
|
# Load the test image.
|
||||||
test_image = _Image.create_from_file(
|
test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
|
||||||
# NormalizedRect around the soccer ball.
|
# NormalizedRect around the soccer ball.
|
||||||
roi = _NormalizedRect(
|
roi = _NormalizedRect(
|
||||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||||
|
@ -409,7 +415,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
with _ImageClassifier.create_from_options(options) as classifier:
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
# Load the test image.
|
# Load the test image.
|
||||||
test_image = _Image.create_from_file(
|
test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
|
||||||
# NormalizedRect around the soccer ball.
|
# NormalizedRect around the soccer ball.
|
||||||
roi = _NormalizedRect(
|
roi = _NormalizedRect(
|
||||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||||
|
@ -482,7 +489,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
||||||
def test_classify_async_succeeds_with_region_of_interest(self):
|
def test_classify_async_succeeds_with_region_of_interest(self):
|
||||||
# Load the test image.
|
# Load the test image.
|
||||||
test_image = _Image.create_from_file(
|
test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
|
||||||
# NormalizedRect around the soccer ball.
|
# NormalizedRect around the soccer ball.
|
||||||
roi = _NormalizedRect(
|
roi = _NormalizedRect(
|
||||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
"""Tests for image segmenter."""
|
"""Tests for image segmenter."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import os
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
@ -43,6 +44,7 @@ _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
|
||||||
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
||||||
_MASK_MAGNIFICATION_FACTOR = 10
|
_MASK_MAGNIFICATION_FACTOR = 10
|
||||||
_MASK_SIMILARITY_THRESHOLD = 0.98
|
_MASK_SIMILARITY_THRESHOLD = 0.98
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
|
||||||
|
|
||||||
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
||||||
|
@ -71,12 +73,16 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
# Load the test input image.
|
# Load the test input image.
|
||||||
self.test_image = _Image.create_from_file(
|
self.test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||||
# Loads ground truth segmentation file.
|
# Loads ground truth segmentation file.
|
||||||
gt_segmentation_data = cv2.imread(
|
gt_segmentation_data = cv2.imread(
|
||||||
test_utils.get_test_data_path(_SEGMENTATION_FILE), cv2.IMREAD_GRAYSCALE)
|
test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)),
|
||||||
|
cv2.IMREAD_GRAYSCALE)
|
||||||
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
"""Tests for object detector."""
|
"""Tests for object detector."""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import os
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
@ -87,6 +88,7 @@ _ALLOW_LIST = ['cat', 'dog']
|
||||||
_DENY_LIST = ['cat']
|
_DENY_LIST = ['cat']
|
||||||
_SCORE_THRESHOLD = 0.3
|
_SCORE_THRESHOLD = 0.3
|
||||||
_MAX_RESULTS = 3
|
_MAX_RESULTS = 3
|
||||||
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
|
||||||
|
|
||||||
class ModelFileType(enum.Enum):
|
class ModelFileType(enum.Enum):
|
||||||
|
@ -99,8 +101,10 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.test_image = _Image.create_from_file(
|
self.test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
test_utils.get_test_data_path(
|
||||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||||
|
self.model_path = test_utils.get_test_data_path(
|
||||||
|
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
@ -395,5 +399,6 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
detector.detect_async(self.test_image, timestamp)
|
detector.detect_async(self.test_image, timestamp)
|
||||||
detector.close()
|
detector.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user