Fix test utils bug when two file names have same ending.
PiperOrigin-RevId: 485780917
This commit is contained in:
parent
1b5da09a92
commit
ddf37d014e
|
@ -27,6 +27,8 @@ from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
|||
from mediapipe.tasks.python.metadata import metadata as _metadata
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||
|
||||
|
||||
class Tokenizer(enum.Enum):
|
||||
BERT_TOKENIZER = 0
|
||||
|
@ -810,7 +812,8 @@ class MetadataDisplayerTest(MetadataTest):
|
|||
actual_json = _metadata.convert_to_json(actual_buffer)
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
||||
golden_json_file_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
|
||||
with open(golden_json_file_path, "r") as f:
|
||||
expected = f.read()
|
||||
self.assertEqual(actual_json, expected)
|
||||
|
@ -821,7 +824,8 @@ class MetadataDisplayerTest(MetadataTest):
|
|||
actual = displayer.get_metadata_json()
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
||||
golden_json_file_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
|
||||
expected = _read_file(golden_json_file_path, "r")
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
@ -848,7 +852,8 @@ class MetadataUtilTest(MetadataTest):
|
|||
metadata_json = _metadata.convert_to_json(metadata_buf)
|
||||
|
||||
# Verifies the generated json file.
|
||||
golden_json_file_path = test_utils.get_test_data_path("golden_json.json")
|
||||
golden_json_file_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
|
||||
expected = _read_file(golden_json_file_path, "r")
|
||||
self.assertEqual(metadata_json, expected)
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# ==============================================================================
|
||||
"""Tests for metadata_writer.image_classifier."""
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
@ -23,18 +25,25 @@ from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
|
|||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||
_FLOAT_MODEL = test_utils.get_test_data_path(
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite")
|
||||
os.path.join(_TEST_DATA_DIR,
|
||||
"mobilenet_v2_1.0_224_without_metadata.tflite"))
|
||||
_QUANT_MODEL = test_utils.get_test_data_path(
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite")
|
||||
_LABEL_FILE = test_utils.get_test_data_path("labels.txt")
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
|
||||
os.path.join(_TEST_DATA_DIR,
|
||||
"mobilenet_v2_1.0_224_quant_without_metadata.tflite"))
|
||||
_LABEL_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "labels.txt"))
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "score_calibration.txt"))
|
||||
_SCORE_CALIBRATION_FILENAME = "score_calibration.txt"
|
||||
_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2
|
||||
_NORM_MEAN = 127.5
|
||||
_NORM_STD = 127.5
|
||||
_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json")
|
||||
_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json")
|
||||
_FLOAT_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "mobilenet_v2_1.0_224.json"))
|
||||
_QUANT_JSON = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "mobilenet_v2_1.0_224_quant.json"))
|
||||
|
||||
|
||||
class ImageClassifierTest(parameterized.TestCase):
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ==============================================================================
|
||||
"""Tests for metadata info classes."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from absl.testing import absltest
|
||||
|
@ -26,13 +27,15 @@ from mediapipe.tasks.python.metadata import metadata as _metadata
|
|||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt")
|
||||
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, "score_calibration.txt"))
|
||||
|
||||
|
||||
class GeneralMdTest(absltest.TestCase):
|
||||
|
||||
_EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path(
|
||||
"general_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "general_meta.json"))
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
general_md = metadata_info.GeneralMd(
|
||||
|
@ -59,7 +62,7 @@ class GeneralMdTest(absltest.TestCase):
|
|||
class AssociatedFileMdTest(absltest.TestCase):
|
||||
|
||||
_EXPECTED_META_JSON = test_utils.get_test_data_path(
|
||||
"associated_file_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "associated_file_meta.json"))
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
file_md = metadata_info.AssociatedFileMd(
|
||||
|
@ -92,11 +95,11 @@ class TensorMdTest(parameterized.TestCase):
|
|||
_LABEL_FILE_EN = "labels.txt"
|
||||
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
|
||||
_EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"feature_tensor_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "feature_tensor_meta.json"))
|
||||
_EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"image_tensor_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "image_tensor_meta.json"))
|
||||
_EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"bounding_box_tensor_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "bounding_box_tensor_meta.json"))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
|
@ -142,11 +145,11 @@ class InputImageTensorMdTest(parameterized.TestCase):
|
|||
_NORM_STD = (127.5, 127.5, 127.5)
|
||||
_COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB
|
||||
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"input_image_tensor_float_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "input_image_tensor_float_meta.json"))
|
||||
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"input_image_tensor_uint8_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "input_image_tensor_uint8_meta.json"))
|
||||
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"input_image_tensor_unsupported_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "input_image_tensor_unsupported_meta.json"))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
|
@ -196,11 +199,12 @@ class ClassificationTensorMdTest(parameterized.TestCase):
|
|||
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
|
||||
_CALIBRATION_DEFAULT_SCORE = 0.2
|
||||
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"classification_tensor_float_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "classification_tensor_float_meta.json"))
|
||||
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"classification_tensor_uint8_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "classification_tensor_uint8_meta.json"))
|
||||
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"classification_tensor_unsupported_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR,
|
||||
"classification_tensor_unsupported_meta.json"))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
|
@ -243,9 +247,9 @@ class ClassificationTensorMdTest(parameterized.TestCase):
|
|||
class ScoreCalibrationMdTest(absltest.TestCase):
|
||||
_DEFAULT_VALUE = 0.2
|
||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"score_calibration_tensor_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "score_calibration_tensor_meta.json"))
|
||||
_EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path(
|
||||
"score_calibration_file_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "score_calibration_file_meta.json"))
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
score_calibration_md = metadata_info.ScoreCalibrationMd(
|
||||
|
@ -310,7 +314,7 @@ class ScoreCalibrationMdTest(absltest.TestCase):
|
|||
class ScoreThresholdingMdTest(absltest.TestCase):
|
||||
_DEFAULT_GLOBAL_THRESHOLD = 0.5
|
||||
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
|
||||
"score_thresholding_meta.json")
|
||||
os.path.join(_TEST_DATA_DIR, "score_thresholding_meta.json"))
|
||||
|
||||
def test_create_metadata_should_succeed(self):
|
||||
score_thresholding_md = metadata_info.ScoreThresholdingMd(
|
||||
|
|
|
@ -21,9 +21,12 @@ from absl.testing import absltest
|
|||
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/metadata'
|
||||
|
||||
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
|
||||
'mobilenet_v1_0.25_224_1_default_1.tflite')
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt')
|
||||
os.path.join(_TEST_DATA_DIR, 'mobilenet_v1_0.25_224_1_default_1.tflite'))
|
||||
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
|
||||
|
||||
|
||||
class LabelsTest(absltest.TestCase):
|
||||
|
@ -85,8 +88,7 @@ class ScoreCalibrationTest(absltest.TestCase):
|
|||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Expected empty lines or 3 or 4 parameters per line in score '
|
||||
'calibration file, but got 2.'
|
||||
):
|
||||
'calibration file, but got 2.'):
|
||||
metadata_writer.ScoreCalibration.create_from_file(
|
||||
metadata_writer.ScoreCalibration.transformation_types.LOG,
|
||||
test_file)
|
||||
|
|
|
@ -42,13 +42,15 @@ def test_srcdir():
|
|||
raise RuntimeError("Missing TEST_SRCDIR environment.")
|
||||
|
||||
|
||||
def get_test_data_path(file_or_dirname: str) -> str:
|
||||
def get_test_data_path(file_or_dirname_path: str) -> str:
|
||||
"""Returns full test data path."""
|
||||
for (directory, subdirs, files) in os.walk(test_srcdir()):
|
||||
for f in subdirs + files:
|
||||
if f.endswith(file_or_dirname):
|
||||
return os.path.join(directory, f)
|
||||
raise ValueError("No %s in test directory" % file_or_dirname)
|
||||
path = os.path.join(directory, f)
|
||||
if path.endswith(file_or_dirname_path):
|
||||
return path
|
||||
raise ValueError("No %s in test directory: %s." %
|
||||
(file_or_dirname_path, test_srcdir()))
|
||||
|
||||
|
||||
def create_calibration_file(file_dir: str,
|
||||
|
|
|
@ -14,10 +14,12 @@
|
|||
"""Tests for image classifier."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image
|
||||
|
@ -48,6 +50,7 @@ _ALLOW_LIST = ['cheeseburger', 'guacamole']
|
|||
_DENY_LIST = ['cheeseburger']
|
||||
_SCORE_THRESHOLD = 0.5
|
||||
_MAX_RESULTS = 3
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||
|
@ -124,8 +127,10 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
|
@ -220,7 +225,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
|
||||
# NormalizedRect around the soccer ball.
|
||||
roi = _NormalizedRect(
|
||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||
|
@ -409,7 +415,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
|
||||
# NormalizedRect around the soccer ball.
|
||||
roi = _NormalizedRect(
|
||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||
|
@ -482,7 +489,8 @@ class ImageClassifierTest(parameterized.TestCase):
|
|||
def test_classify_async_succeeds_with_region_of_interest(self):
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
|
||||
# NormalizedRect around the soccer ball.
|
||||
roi = _NormalizedRect(
|
||||
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
"""Tests for image segmenter."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from typing import List
|
||||
from unittest import mock
|
||||
|
||||
|
@ -43,6 +44,7 @@ _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
|
|||
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
||||
_MASK_MAGNIFICATION_FACTOR = 10
|
||||
_MASK_SIMILARITY_THRESHOLD = 0.98
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
||||
|
@ -71,12 +73,16 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
super().setUp()
|
||||
# Load the test input image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||
# Loads ground truth segmentation file.
|
||||
gt_segmentation_data = cv2.imread(
|
||||
test_utils.get_test_data_path(_SEGMENTATION_FILE), cv2.IMREAD_GRAYSCALE)
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)),
|
||||
cv2.IMREAD_GRAYSCALE)
|
||||
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
"""Tests for object detector."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
|
@ -87,6 +88,7 @@ _ALLOW_LIST = ['cat', 'dog']
|
|||
_DENY_LIST = ['cat']
|
||||
_SCORE_THRESHOLD = 0.3
|
||||
_MAX_RESULTS = 3
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
|
@ -99,8 +101,10 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
|
@ -395,5 +399,6 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
detector.detect_async(self.test_image, timestamp)
|
||||
detector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user