Implement Image.create_from_file and update the object_detector_test.py file accordingly.
PiperOrigin-RevId: 477682930
This commit is contained in:
parent
80fd47b820
commit
554e2a9d69
|
@ -48,6 +48,8 @@ pybind_extension(
|
||||||
"//mediapipe/python/pybind:timestamp",
|
"//mediapipe/python/pybind:timestamp",
|
||||||
"//mediapipe/python/pybind:validated_graph_config",
|
"//mediapipe/python/pybind:validated_graph_config",
|
||||||
"//mediapipe/tasks/python/core/pybind:task_runner",
|
"//mediapipe/tasks/python/core/pybind:task_runner",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@stblib//:stb_image",
|
||||||
# Type registration.
|
# Type registration.
|
||||||
"//mediapipe/framework:basic_types_registration",
|
"//mediapipe/framework:basic_types_registration",
|
||||||
"//mediapipe/framework/formats:classification_registration",
|
"//mediapipe/framework/formats:classification_registration",
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
"""Tests for mediapipe.python._framework_bindings.image."""
|
"""Tests for mediapipe.python._framework_bindings.image."""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -23,6 +24,7 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
|
|
||||||
|
# resources dependency
|
||||||
from mediapipe.python._framework_bindings import image
|
from mediapipe.python._framework_bindings import image
|
||||||
from mediapipe.python._framework_bindings import image_frame
|
from mediapipe.python._framework_bindings import image_frame
|
||||||
|
|
||||||
|
@ -185,6 +187,16 @@ class ImageTest(absltest.TestCase):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
|
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
|
||||||
|
|
||||||
|
def test_image_create_from_file(self):
  """Decodes the cat.jpg test asset and checks the resulting Image metadata."""
  # The asset is looked up relative to the runfiles directory.
  cat_jpg_path = os.path.join(resources.GetRunfilesDir(),
                              'mediapipe/tasks/testdata/vision/cat.jpg')
  decoded = Image.create_from_file(cat_jpg_path)
  # Expected geometry and pixel format of the decoded test asset.
  self.assertEqual(decoded.width, 600)
  self.assertEqual(decoded.height, 400)
  self.assertEqual(decoded.channels, 3)
  self.assertEqual(decoded.image_format, ImageFormat.SRGB)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -45,6 +45,8 @@ pybind_library(
|
||||||
":util",
|
":util",
|
||||||
"//mediapipe/framework:type_map",
|
"//mediapipe/framework:type_map",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@stblib//:stb_image",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/python/pybind/image_frame_util.h"
|
#include "mediapipe/python/pybind/image_frame_util.h"
|
||||||
#include "mediapipe/python/pybind/util.h"
|
#include "mediapipe/python/pybind/util.h"
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
|
#include "stb_image.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace python {
|
namespace python {
|
||||||
|
@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) {
|
||||||
image.is_aligned(16)
|
image.is_aligned(16)
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
image.def_static(
|
||||||
|
"create_from_file",
|
||||||
|
[](const std::string& file_name) {
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
int channels;
|
||||||
|
auto* image_data =
|
||||||
|
stbi_load(file_name.c_str(), &width, &height, &channels,
|
||||||
|
/*desired_channels=*/0);
|
||||||
|
if (image_data == nullptr) {
|
||||||
|
throw RaisePyError(PyExc_RuntimeError,
|
||||||
|
absl::StrFormat("Image decoding failed (%s): %s",
|
||||||
|
stbi_failure_reason(), file_name)
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
ImageFrameSharedPtr image_frame;
|
||||||
|
switch (channels) {
|
||||||
|
case 1:
|
||||||
|
image_frame = std::make_shared<ImageFrame>(
|
||||||
|
ImageFormat::GRAY8, width, height, width, image_data,
|
||||||
|
stbi_image_free);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
image_frame = std::make_shared<ImageFrame>(
|
||||||
|
ImageFormat::SRGB, width, height, 3 * width, image_data,
|
||||||
|
stbi_image_free);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
image_frame = std::make_shared<ImageFrame>(
|
||||||
|
ImageFormat::SRGBA, width, height, 4 * width, image_data,
|
||||||
|
stbi_image_free);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw RaisePyError(
|
||||||
|
PyExc_RuntimeError,
|
||||||
|
absl::StrFormat(
|
||||||
|
"Expected image with 1 (grayscale), 3 (RGB) or 4 "
|
||||||
|
"(RGBA) channels, found %d channels.",
|
||||||
|
channels)
|
||||||
|
.c_str());
|
||||||
|
}
|
||||||
|
return Image(std::move(image_frame));
|
||||||
|
},
|
||||||
|
R"doc(Creates `Image` object from the image file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: Image file name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Image` object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError if the image file can't be decoded.
|
||||||
|
)doc",
|
||||||
|
py::arg("file_name"));
|
||||||
|
|
||||||
image.def_property_readonly("width", &Image::width)
|
image.def_property_readonly("width", &Image::width)
|
||||||
.def_property_readonly("height", &Image::height)
|
.def_property_readonly("height", &Image::height)
|
||||||
.def_property_readonly("channels", &Image::channels)
|
.def_property_readonly("channels", &Image::channels)
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from absl import flags
|
from absl import flags
|
||||||
import cv2
|
|
||||||
|
|
||||||
from mediapipe.python._framework_bindings import image as image_module
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
||||||
|
@ -44,12 +43,3 @@ def get_test_data_path(file_or_dirname: str) -> str:
|
||||||
if f.endswith(file_or_dirname):
|
if f.endswith(file_or_dirname):
|
||||||
return os.path.join(directory, f)
|
return os.path.join(directory, f)
|
||||||
raise ValueError("No %s in test directory" % file_or_dirname)
|
raise ValueError("No %s in test directory" % file_or_dirname)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Implement image util module to read image data from file.
|
|
||||||
def read_test_image(image_file: str) -> _Image:
|
|
||||||
"""Reads a MediaPipe Image from the image file."""
|
|
||||||
image_data = cv2.imread(image_file)
|
|
||||||
if image_data.shape[2] != _RGB_CHANNELS:
|
|
||||||
raise ValueError("Input image must contain three channel rgb data.")
|
|
||||||
return _Image(_ImageFormat.SRGB, cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB))
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ _IMAGE_FILE = 'cats_and_dogs.jpg'
|
||||||
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
||||||
_Detection(
|
_Detection(
|
||||||
bounding_box=_BoundingBox(
|
bounding_box=_BoundingBox(
|
||||||
origin_x=608, origin_y=164, width=381, height=432),
|
origin_x=608, origin_y=161, width=381, height=439),
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=None,
|
index=None,
|
||||||
|
@ -64,7 +64,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
||||||
]),
|
]),
|
||||||
_Detection(
|
_Detection(
|
||||||
bounding_box=_BoundingBox(
|
bounding_box=_BoundingBox(
|
||||||
origin_x=257, origin_y=394, width=173, height=202),
|
origin_x=256, origin_y=395, width=173, height=202),
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=None,
|
index=None,
|
||||||
|
@ -74,7 +74,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
||||||
]),
|
]),
|
||||||
_Detection(
|
_Detection(
|
||||||
bounding_box=_BoundingBox(
|
bounding_box=_BoundingBox(
|
||||||
origin_x=362, origin_y=195, width=325, height=412),
|
origin_x=362, origin_y=191, width=325, height=419),
|
||||||
categories=[
|
categories=[
|
||||||
_Category(
|
_Category(
|
||||||
index=None,
|
index=None,
|
||||||
|
@ -98,7 +98,7 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.test_image = test_util.read_test_image(
|
self.test_image = _Image.create_from_file(
|
||||||
test_util.get_test_data_path(_IMAGE_FILE))
|
test_util.get_test_data_path(_IMAGE_FILE))
|
||||||
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
|
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
|
||||||
|
|
||||||
|
@ -153,9 +153,9 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
detector = _ObjectDetector.create_from_options(options)
|
detector = _ObjectDetector.create_from_options(options)
|
||||||
|
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self.assertEqual(image_result, expected_detection_result)
|
self.assertEqual(detection_result, expected_detection_result)
|
||||||
# Closes the detector explicitly when the detector is not used in
|
# Closes the detector explicitly when the detector is not used in
|
||||||
# a context.
|
# a context.
|
||||||
detector.close()
|
detector.close()
|
||||||
|
@ -179,9 +179,9 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
base_options=base_options, max_results=max_results)
|
base_options=base_options, max_results=max_results)
|
||||||
with _ObjectDetector.create_from_options(options) as detector:
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
# Comparing results.
|
# Comparing results.
|
||||||
self.assertEqual(image_result, expected_detection_result)
|
self.assertEqual(detection_result, expected_detection_result)
|
||||||
|
|
||||||
def test_score_threshold_option(self):
|
def test_score_threshold_option(self):
|
||||||
options = _ObjectDetectorOptions(
|
options = _ObjectDetectorOptions(
|
||||||
|
@ -189,8 +189,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
score_threshold=_SCORE_THRESHOLD)
|
score_threshold=_SCORE_THRESHOLD)
|
||||||
with _ObjectDetector.create_from_options(options) as detector:
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
detections = image_result.detections
|
detections = detection_result.detections
|
||||||
|
|
||||||
for detection in detections:
|
for detection in detections:
|
||||||
score = detection.categories[0].score
|
score = detection.categories[0].score
|
||||||
|
@ -204,8 +204,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
max_results=_MAX_RESULTS)
|
max_results=_MAX_RESULTS)
|
||||||
with _ObjectDetector.create_from_options(options) as detector:
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
detections = image_result.detections
|
detections = detection_result.detections
|
||||||
|
|
||||||
self.assertLessEqual(
|
self.assertLessEqual(
|
||||||
len(detections), _MAX_RESULTS, 'Too many results returned.')
|
len(detections), _MAX_RESULTS, 'Too many results returned.')
|
||||||
|
@ -216,8 +216,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
category_allowlist=_ALLOW_LIST)
|
category_allowlist=_ALLOW_LIST)
|
||||||
with _ObjectDetector.create_from_options(options) as detector:
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
detections = image_result.detections
|
detections = detection_result.detections
|
||||||
|
|
||||||
for detection in detections:
|
for detection in detections:
|
||||||
label = detection.categories[0].category_name
|
label = detection.categories[0].category_name
|
||||||
|
@ -230,8 +230,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
category_denylist=_DENY_LIST)
|
category_denylist=_DENY_LIST)
|
||||||
with _ObjectDetector.create_from_options(options) as detector:
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
detections = image_result.detections
|
detections = detection_result.detections
|
||||||
|
|
||||||
for detection in detections:
|
for detection in detections:
|
||||||
label = detection.categories[0].category_name
|
label = detection.categories[0].category_name
|
||||||
|
@ -257,8 +257,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
||||||
score_threshold=1)
|
score_threshold=1)
|
||||||
with _ObjectDetector.create_from_options(options) as detector:
|
with _ObjectDetector.create_from_options(options) as detector:
|
||||||
# Performs object detection on the input.
|
# Performs object detection on the input.
|
||||||
image_result = detector.detect(self.test_image)
|
detection_result = detector.detect(self.test_image)
|
||||||
self.assertEmpty(image_result.detections)
|
self.assertEmpty(detection_result.detections)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
def test_missing_result_callback(self):
|
||||||
options = _ObjectDetectorOptions(
|
options = _ObjectDetectorOptions(
|
||||||
|
|
4
mediapipe/tasks/testdata/vision/BUILD
vendored
4
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -85,6 +85,10 @@ filegroup(
|
||||||
"selfie_segm_128_128_3_expected_mask.jpg",
|
"selfie_segm_128_128_3_expected_mask.jpg",
|
||||||
"selfie_segm_144_256_3_expected_mask.jpg",
|
"selfie_segm_144_256_3_expected_mask.jpg",
|
||||||
],
|
],
|
||||||
|
visibility = [
|
||||||
|
"//mediapipe/python:__subpackages__",
|
||||||
|
"//mediapipe/tasks:internal",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO Create individual filegroup for models required for each Tasks.
|
# TODO Create individual filegroup for models required for each Tasks.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user