Fix test utils bug when two file names have same ending.

PiperOrigin-RevId: 485780917
This commit is contained in:
MediaPipe Team 2022-11-02 22:30:37 -07:00 committed by Copybara-Service
parent 1b5da09a92
commit ddf37d014e
8 changed files with 83 additions and 42 deletions

View File

@ -27,6 +27,8 @@ from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata import metadata as _metadata from mediapipe.tasks.python.metadata import metadata as _metadata
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
class Tokenizer(enum.Enum): class Tokenizer(enum.Enum):
BERT_TOKENIZER = 0 BERT_TOKENIZER = 0
@ -810,7 +812,8 @@ class MetadataDisplayerTest(MetadataTest):
actual_json = _metadata.convert_to_json(actual_buffer) actual_json = _metadata.convert_to_json(actual_buffer)
# Verifies the generated json file. # Verifies the generated json file.
golden_json_file_path = test_utils.get_test_data_path("golden_json.json") golden_json_file_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
with open(golden_json_file_path, "r") as f: with open(golden_json_file_path, "r") as f:
expected = f.read() expected = f.read()
self.assertEqual(actual_json, expected) self.assertEqual(actual_json, expected)
@ -821,7 +824,8 @@ class MetadataDisplayerTest(MetadataTest):
actual = displayer.get_metadata_json() actual = displayer.get_metadata_json()
# Verifies the generated json file. # Verifies the generated json file.
golden_json_file_path = test_utils.get_test_data_path("golden_json.json") golden_json_file_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
expected = _read_file(golden_json_file_path, "r") expected = _read_file(golden_json_file_path, "r")
self.assertEqual(actual, expected) self.assertEqual(actual, expected)
@ -848,7 +852,8 @@ class MetadataUtilTest(MetadataTest):
metadata_json = _metadata.convert_to_json(metadata_buf) metadata_json = _metadata.convert_to_json(metadata_buf)
# Verifies the generated json file. # Verifies the generated json file.
golden_json_file_path = test_utils.get_test_data_path("golden_json.json") golden_json_file_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "golden_json.json"))
expected = _read_file(golden_json_file_path, "r") expected = _read_file(golden_json_file_path, "r")
self.assertEqual(metadata_json, expected) self.assertEqual(metadata_json, expected)

View File

@ -14,6 +14,8 @@
# ============================================================================== # ==============================================================================
"""Tests for metadata_writer.image_classifier.""" """Tests for metadata_writer.image_classifier."""
import os
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
@ -23,18 +25,25 @@ from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
_FLOAT_MODEL = test_utils.get_test_data_path( _FLOAT_MODEL = test_utils.get_test_data_path(
"mobilenet_v2_1.0_224_without_metadata.tflite") os.path.join(_TEST_DATA_DIR,
"mobilenet_v2_1.0_224_without_metadata.tflite"))
_QUANT_MODEL = test_utils.get_test_data_path( _QUANT_MODEL = test_utils.get_test_data_path(
"mobilenet_v2_1.0_224_quant_without_metadata.tflite") os.path.join(_TEST_DATA_DIR,
_LABEL_FILE = test_utils.get_test_data_path("labels.txt") "mobilenet_v2_1.0_224_quant_without_metadata.tflite"))
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt") _LABEL_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "labels.txt"))
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "score_calibration.txt"))
_SCORE_CALIBRATION_FILENAME = "score_calibration.txt" _SCORE_CALIBRATION_FILENAME = "score_calibration.txt"
_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2 _DEFAULT_SCORE_CALIBRATION_VALUE = 0.2
_NORM_MEAN = 127.5 _NORM_MEAN = 127.5
_NORM_STD = 127.5 _NORM_STD = 127.5
_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json") _FLOAT_JSON = test_utils.get_test_data_path(
_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json") os.path.join(_TEST_DATA_DIR, "mobilenet_v2_1.0_224.json"))
_QUANT_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "mobilenet_v2_1.0_224_quant.json"))
class ImageClassifierTest(parameterized.TestCase): class ImageClassifierTest(parameterized.TestCase):

View File

@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Tests for metadata info classes.""" """Tests for metadata info classes."""
import os
import tempfile import tempfile
from absl.testing import absltest from absl.testing import absltest
@ -26,13 +27,15 @@ from mediapipe.tasks.python.metadata import metadata as _metadata
from mediapipe.tasks.python.metadata.metadata_writers import metadata_info from mediapipe.tasks.python.metadata.metadata_writers import metadata_info
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt") _TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata"
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "score_calibration.txt"))
class GeneralMdTest(absltest.TestCase): class GeneralMdTest(absltest.TestCase):
_EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path( _EXPECTED_GENERAL_META_JSON = test_utils.get_test_data_path(
"general_meta.json") os.path.join(_TEST_DATA_DIR, "general_meta.json"))
def test_create_metadata_should_succeed(self): def test_create_metadata_should_succeed(self):
general_md = metadata_info.GeneralMd( general_md = metadata_info.GeneralMd(
@ -59,7 +62,7 @@ class GeneralMdTest(absltest.TestCase):
class AssociatedFileMdTest(absltest.TestCase): class AssociatedFileMdTest(absltest.TestCase):
_EXPECTED_META_JSON = test_utils.get_test_data_path( _EXPECTED_META_JSON = test_utils.get_test_data_path(
"associated_file_meta.json") os.path.join(_TEST_DATA_DIR, "associated_file_meta.json"))
def test_create_metadata_should_succeed(self): def test_create_metadata_should_succeed(self):
file_md = metadata_info.AssociatedFileMd( file_md = metadata_info.AssociatedFileMd(
@ -92,11 +95,11 @@ class TensorMdTest(parameterized.TestCase):
_LABEL_FILE_EN = "labels.txt" _LABEL_FILE_EN = "labels.txt"
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
_EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_FEATURE_TENSOR_JSON = test_utils.get_test_data_path(
"feature_tensor_meta.json") os.path.join(_TEST_DATA_DIR, "feature_tensor_meta.json"))
_EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_IMAGE_TENSOR_JSON = test_utils.get_test_data_path(
"image_tensor_meta.json") os.path.join(_TEST_DATA_DIR, "image_tensor_meta.json"))
_EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_BOUNDING_BOX_TENSOR_JSON = test_utils.get_test_data_path(
"bounding_box_tensor_meta.json") os.path.join(_TEST_DATA_DIR, "bounding_box_tensor_meta.json"))
@parameterized.named_parameters( @parameterized.named_parameters(
{ {
@ -142,11 +145,11 @@ class InputImageTensorMdTest(parameterized.TestCase):
_NORM_STD = (127.5, 127.5, 127.5) _NORM_STD = (127.5, 127.5, 127.5)
_COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB _COLOR_SPACE_TYPE = _metadata_fb.ColorSpaceType.RGB
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
"input_image_tensor_float_meta.json") os.path.join(_TEST_DATA_DIR, "input_image_tensor_float_meta.json"))
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
"input_image_tensor_uint8_meta.json") os.path.join(_TEST_DATA_DIR, "input_image_tensor_uint8_meta.json"))
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
"input_image_tensor_unsupported_meta.json") os.path.join(_TEST_DATA_DIR, "input_image_tensor_unsupported_meta.json"))
@parameterized.named_parameters( @parameterized.named_parameters(
{ {
@ -196,11 +199,12 @@ class ClassificationTensorMdTest(parameterized.TestCase):
_LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese. _LABEL_FILE_CN = "labels_cn.txt" # Locale label file in Chinese.
_CALIBRATION_DEFAULT_SCORE = 0.2 _CALIBRATION_DEFAULT_SCORE = 0.2
_EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_FLOAT_TENSOR_JSON = test_utils.get_test_data_path(
"classification_tensor_float_meta.json") os.path.join(_TEST_DATA_DIR, "classification_tensor_float_meta.json"))
_EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_UINT8_TENSOR_JSON = test_utils.get_test_data_path(
"classification_tensor_uint8_meta.json") os.path.join(_TEST_DATA_DIR, "classification_tensor_uint8_meta.json"))
_EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_UNSUPPORTED_TENSOR_JSON = test_utils.get_test_data_path(
"classification_tensor_unsupported_meta.json") os.path.join(_TEST_DATA_DIR,
"classification_tensor_unsupported_meta.json"))
@parameterized.named_parameters( @parameterized.named_parameters(
{ {
@ -243,9 +247,9 @@ class ClassificationTensorMdTest(parameterized.TestCase):
class ScoreCalibrationMdTest(absltest.TestCase): class ScoreCalibrationMdTest(absltest.TestCase):
_DEFAULT_VALUE = 0.2 _DEFAULT_VALUE = 0.2
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
"score_calibration_tensor_meta.json") os.path.join(_TEST_DATA_DIR, "score_calibration_tensor_meta.json"))
_EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path( _EXPECTED_MODEL_META_JSON = test_utils.get_test_data_path(
"score_calibration_file_meta.json") os.path.join(_TEST_DATA_DIR, "score_calibration_file_meta.json"))
def test_create_metadata_should_succeed(self): def test_create_metadata_should_succeed(self):
score_calibration_md = metadata_info.ScoreCalibrationMd( score_calibration_md = metadata_info.ScoreCalibrationMd(
@ -310,7 +314,7 @@ class ScoreCalibrationMdTest(absltest.TestCase):
class ScoreThresholdingMdTest(absltest.TestCase): class ScoreThresholdingMdTest(absltest.TestCase):
_DEFAULT_GLOBAL_THRESHOLD = 0.5 _DEFAULT_GLOBAL_THRESHOLD = 0.5
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path( _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
"score_thresholding_meta.json") os.path.join(_TEST_DATA_DIR, "score_thresholding_meta.json"))
def test_create_metadata_should_succeed(self): def test_create_metadata_should_succeed(self):
score_thresholding_md = metadata_info.ScoreThresholdingMd( score_thresholding_md = metadata_info.ScoreThresholdingMd(

View File

@ -21,9 +21,12 @@ from absl.testing import absltest
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/metadata'
_IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path( _IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path(
'mobilenet_v1_0.25_224_1_default_1.tflite') os.path.join(_TEST_DATA_DIR, 'mobilenet_v1_0.25_224_1_default_1.tflite'))
_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt') _SCORE_CALIBRATION_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'score_calibration.txt'))
class LabelsTest(absltest.TestCase): class LabelsTest(absltest.TestCase):
@ -85,8 +88,7 @@ class ScoreCalibrationTest(absltest.TestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
'Expected empty lines or 3 or 4 parameters per line in score ' 'Expected empty lines or 3 or 4 parameters per line in score '
'calibration file, but got 2.' 'calibration file, but got 2.'):
):
metadata_writer.ScoreCalibration.create_from_file( metadata_writer.ScoreCalibration.create_from_file(
metadata_writer.ScoreCalibration.transformation_types.LOG, metadata_writer.ScoreCalibration.transformation_types.LOG,
test_file) test_file)

View File

@ -42,13 +42,15 @@ def test_srcdir():
raise RuntimeError("Missing TEST_SRCDIR environment.") raise RuntimeError("Missing TEST_SRCDIR environment.")
def get_test_data_path(file_or_dirname: str) -> str: def get_test_data_path(file_or_dirname_path: str) -> str:
"""Returns full test data path.""" """Returns full test data path."""
for (directory, subdirs, files) in os.walk(test_srcdir()): for (directory, subdirs, files) in os.walk(test_srcdir()):
for f in subdirs + files: for f in subdirs + files:
if f.endswith(file_or_dirname): path = os.path.join(directory, f)
return os.path.join(directory, f) if path.endswith(file_or_dirname_path):
raise ValueError("No %s in test directory" % file_or_dirname) return path
raise ValueError("No %s in test directory: %s." %
(file_or_dirname_path, test_srcdir()))
def create_calibration_file(file_dir: str, def create_calibration_file(file_dir: str,

View File

@ -14,10 +14,12 @@
"""Tests for image classifier.""" """Tests for image classifier."""
import enum import enum
import os
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image
@ -48,6 +50,7 @@ _ALLOW_LIST = ['cheeseburger', 'guacamole']
_DENY_LIST = ['cheeseburger'] _DENY_LIST = ['cheeseburger']
_SCORE_THRESHOLD = 0.5 _SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3 _MAX_RESULTS = 3
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
@ -124,8 +127,10 @@ class ImageClassifierTest(parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE)) test_utils.get_test_data_path(
self.model_path = test_utils.get_test_data_path(_MODEL_FILE) os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -220,7 +225,8 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg')) test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
# NormalizedRect around the soccer ball. # NormalizedRect around the soccer ball.
roi = _NormalizedRect( roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427) x_center=0.532, y_center=0.521, width=0.164, height=0.427)
@ -409,7 +415,8 @@ class ImageClassifierTest(parameterized.TestCase):
with _ImageClassifier.create_from_options(options) as classifier: with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg')) test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
# NormalizedRect around the soccer ball. # NormalizedRect around the soccer ball.
roi = _NormalizedRect( roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427) x_center=0.532, y_center=0.521, width=0.164, height=0.427)
@ -482,7 +489,8 @@ class ImageClassifierTest(parameterized.TestCase):
def test_classify_async_succeeds_with_region_of_interest(self): def test_classify_async_succeeds_with_region_of_interest(self):
# Load the test image. # Load the test image.
test_image = _Image.create_from_file( test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg')) test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')))
# NormalizedRect around the soccer ball. # NormalizedRect around the soccer ball.
roi = _NormalizedRect( roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427) x_center=0.532, y_center=0.521, width=0.164, height=0.427)

View File

@ -14,6 +14,7 @@
"""Tests for image segmenter.""" """Tests for image segmenter."""
import enum import enum
import os
from typing import List from typing import List
from unittest import mock from unittest import mock
@ -43,6 +44,7 @@ _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_MASK_MAGNIFICATION_FACTOR = 10 _MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98 _MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _similar_to_uint8_mask(actual_mask, expected_mask): def _similar_to_uint8_mask(actual_mask, expected_mask):
@ -71,12 +73,16 @@ class ImageSegmenterTest(parameterized.TestCase):
super().setUp() super().setUp()
# Load the test input image. # Load the test input image.
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE)) test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
# Loads ground truth segmentation file. # Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread( gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(_SEGMENTATION_FILE), cv2.IMREAD_GRAYSCALE) test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)),
cv2.IMREAD_GRAYSCALE)
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data) self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
self.model_path = test_utils.get_test_data_path(_MODEL_FILE) self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.

View File

@ -14,6 +14,7 @@
"""Tests for object detector.""" """Tests for object detector."""
import enum import enum
import os
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
@ -87,6 +88,7 @@ _ALLOW_LIST = ['cat', 'dog']
_DENY_LIST = ['cat'] _DENY_LIST = ['cat']
_SCORE_THRESHOLD = 0.3 _SCORE_THRESHOLD = 0.3
_MAX_RESULTS = 3 _MAX_RESULTS = 3
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
class ModelFileType(enum.Enum): class ModelFileType(enum.Enum):
@ -99,8 +101,10 @@ class ObjectDetectorTest(parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE)) test_utils.get_test_data_path(
self.model_path = test_utils.get_test_data_path(_MODEL_FILE) os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -395,5 +399,6 @@ class ObjectDetectorTest(parameterized.TestCase):
detector.detect_async(self.test_image, timestamp) detector.detect_async(self.test_image, timestamp)
detector.close() detector.close()
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()