Implement Image.create_from_file and update the object_detector_test.py file accordingly.

PiperOrigin-RevId: 477682930
This commit is contained in:
Jiuqiang Tang 2022-09-29 03:42:07 -07:00 committed by Copybara-Service
parent 80fd47b820
commit 554e2a9d69
7 changed files with 96 additions and 28 deletions

View File

@ -48,6 +48,8 @@ pybind_extension(
"//mediapipe/python/pybind:timestamp", "//mediapipe/python/pybind:timestamp",
"//mediapipe/python/pybind:validated_graph_config", "//mediapipe/python/pybind:validated_graph_config",
"//mediapipe/tasks/python/core/pybind:task_runner", "//mediapipe/tasks/python/core/pybind:task_runner",
"@com_google_absl//absl/strings:str_format",
"@stblib//:stb_image",
# Type registration. # Type registration.
"//mediapipe/framework:basic_types_registration", "//mediapipe/framework:basic_types_registration",
"//mediapipe/framework/formats:classification_registration", "//mediapipe/framework/formats:classification_registration",

View File

@ -15,6 +15,7 @@
"""Tests for mediapipe.python._framework_bindings.image.""" """Tests for mediapipe.python._framework_bindings.image."""
import gc import gc
import os
import random import random
import sys import sys
@ -23,6 +24,7 @@ import cv2
import numpy as np import numpy as np
import PIL.Image import PIL.Image
# resources dependency
from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image
from mediapipe.python._framework_bindings import image_frame from mediapipe.python._framework_bindings import image_frame
@ -185,6 +187,16 @@ class ImageTest(absltest.TestCase):
gc.collect() gc.collect()
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
def test_image_create_from_file(self):
  """Checks the metadata of an Image loaded via Image.create_from_file."""
  cat_jpg_path = os.path.join(resources.GetRunfilesDir(),
                              'mediapipe/tasks/testdata/vision/cat.jpg')
  cat_image = Image.create_from_file(cat_jpg_path)
  # The cat.jpg fixture is expected to load as a 600x400 image with three
  # channels in SRGB format.
  self.assertEqual(cat_image.width, 600)
  self.assertEqual(cat_image.height, 400)
  self.assertEqual(cat_image.channels, 3)
  self.assertEqual(cat_image.image_format, ImageFormat.SRGB)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -45,6 +45,8 @@ pybind_library(
":util", ":util",
"//mediapipe/framework:type_map", "//mediapipe/framework:type_map",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"@com_google_absl//absl/strings:str_format",
"@stblib//:stb_image",
], ],
) )

View File

@ -16,9 +16,11 @@
#include <memory> #include <memory>
#include "absl/strings/str_format.h"
#include "mediapipe/python/pybind/image_frame_util.h" #include "mediapipe/python/pybind/image_frame_util.h"
#include "mediapipe/python/pybind/util.h" #include "mediapipe/python/pybind/util.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "stb_image.h"
namespace mediapipe { namespace mediapipe {
namespace python { namespace python {
@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) {
image.is_aligned(16) image.is_aligned(16)
)doc"); )doc");
// Registers the static `Image.create_from_file(file_name)` factory.
//
// The file is decoded with stb_image. The decoded buffer is handed to an
// ImageFrame together with `stbi_image_free` as its deleter, so the frame
// releases the pixels once it is destroyed. A Python RuntimeError is raised
// when decoding fails or when the decoded channel count is not 1, 3 or 4.
image.def_static(
    "create_from_file",
    [](const std::string& file_name) {
      int width;
      int height;
      int channels;
      // desired_channels=0 asks stb to keep the channel count stored in the
      // file; that count is reported back through `channels`.
      auto* image_data =
          stbi_load(file_name.c_str(), &width, &height, &channels,
                    /*desired_channels=*/0);
      if (image_data == nullptr) {
        throw RaisePyError(PyExc_RuntimeError,
                           absl::StrFormat("Image decoding failed (%s): %s",
                                           stbi_failure_reason(), file_name)
                               .c_str());
      }
      ImageFrameSharedPtr image_frame;
      // The fourth ImageFrame argument is the row stride in bytes: one byte
      // per channel per pixel, with no padding between rows.
      switch (channels) {
        case 1:
          image_frame = std::make_shared<ImageFrame>(
              ImageFormat::GRAY8, width, height, width, image_data,
              stbi_image_free);
          break;
        case 3:
          image_frame = std::make_shared<ImageFrame>(
              ImageFormat::SRGB, width, height, 3 * width, image_data,
              stbi_image_free);
          break;
        case 4:
          image_frame = std::make_shared<ImageFrame>(
              ImageFormat::SRGBA, width, height, 4 * width, image_data,
              stbi_image_free);
          break;
        default:
          // No ImageFrame took ownership of the decoded buffer on this path
          // (e.g. a 2-channel gray+alpha file), so it must be released here
          // before throwing; otherwise the pixels would leak.
          stbi_image_free(image_data);
          throw RaisePyError(
              PyExc_RuntimeError,
              absl::StrFormat(
                  "Expected image with 1 (grayscale), 3 (RGB) or 4 "
                  "(RGBA) channels, found %d channels.",
                  channels)
                  .c_str());
      }
      return Image(std::move(image_frame));
    },
    R"doc(Creates `Image` object from the image file.
Args:
file_name: Image file name.
Returns:
`Image` object.
Raises:
RuntimeError if the image file can't be decoded.
)doc",
    py::arg("file_name"));
image.def_property_readonly("width", &Image::width) image.def_property_readonly("width", &Image::width)
.def_property_readonly("height", &Image::height) .def_property_readonly("height", &Image::height)
.def_property_readonly("channels", &Image::channels) .def_property_readonly("channels", &Image::channels)

View File

@ -16,7 +16,6 @@
import os import os
from absl import flags from absl import flags
import cv2
from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module from mediapipe.python._framework_bindings import image_frame as image_frame_module
@ -44,12 +43,3 @@ def get_test_data_path(file_or_dirname: str) -> str:
if f.endswith(file_or_dirname): if f.endswith(file_or_dirname):
return os.path.join(directory, f) return os.path.join(directory, f)
raise ValueError("No %s in test directory" % file_or_dirname) raise ValueError("No %s in test directory" % file_or_dirname)
# TODO: Implement image util module to read image data from file.
def read_test_image(image_file: str) -> _Image:
  """Reads a MediaPipe Image from the image file.

  Args:
    image_file: Path of the image file to load with OpenCV.

  Returns:
    An SRGB `_Image` built from the decoded pixels after converting them from
    OpenCV's BGR channel order to RGB.

  Raises:
    ValueError: If the file cannot be read or decoded, or if the decoded data
      does not have three channels.
  """
  image_data = cv2.imread(image_file)
  # cv2.imread reports an unreadable or undecodable file by returning None
  # instead of raising, so fail with a clear message before touching `.shape`.
  if image_data is None:
    raise ValueError("Unable to read image file: %s" % image_file)
  if image_data.shape[2] != _RGB_CHANNELS:
    raise ValueError("Input image must contain three channel rgb data.")
  return _Image(_ImageFormat.SRGB, cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB))

View File

@ -44,7 +44,7 @@ _IMAGE_FILE = 'cats_and_dogs.jpg'
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=608, origin_y=164, width=381, height=432), origin_x=608, origin_y=161, width=381, height=439),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
@ -64,7 +64,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
]), ]),
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=257, origin_y=394, width=173, height=202), origin_x=256, origin_y=395, width=173, height=202),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
@ -74,7 +74,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
]), ]),
_Detection( _Detection(
bounding_box=_BoundingBox( bounding_box=_BoundingBox(
origin_x=362, origin_y=195, width=325, height=412), origin_x=362, origin_y=191, width=325, height=419),
categories=[ categories=[
_Category( _Category(
index=None, index=None,
@ -98,7 +98,7 @@ class ObjectDetectorTest(parameterized.TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.test_image = test_util.read_test_image( self.test_image = _Image.create_from_file(
test_util.get_test_data_path(_IMAGE_FILE)) test_util.get_test_data_path(_IMAGE_FILE))
self.model_path = test_util.get_test_data_path(_MODEL_FILE) self.model_path = test_util.get_test_data_path(_MODEL_FILE)
@ -153,9 +153,9 @@ class ObjectDetectorTest(parameterized.TestCase):
detector = _ObjectDetector.create_from_options(options) detector = _ObjectDetector.create_from_options(options)
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
# Comparing results. # Comparing results.
self.assertEqual(image_result, expected_detection_result) self.assertEqual(detection_result, expected_detection_result)
# Closes the detector explicitly when the detector is not used in # Closes the detector explicitly when the detector is not used in
# a context. # a context.
detector.close() detector.close()
@ -179,9 +179,9 @@ class ObjectDetectorTest(parameterized.TestCase):
base_options=base_options, max_results=max_results) base_options=base_options, max_results=max_results)
with _ObjectDetector.create_from_options(options) as detector: with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
# Comparing results. # Comparing results.
self.assertEqual(image_result, expected_detection_result) self.assertEqual(detection_result, expected_detection_result)
def test_score_threshold_option(self): def test_score_threshold_option(self):
options = _ObjectDetectorOptions( options = _ObjectDetectorOptions(
@ -189,8 +189,8 @@ class ObjectDetectorTest(parameterized.TestCase):
score_threshold=_SCORE_THRESHOLD) score_threshold=_SCORE_THRESHOLD)
with _ObjectDetector.create_from_options(options) as detector: with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
detections = image_result.detections detections = detection_result.detections
for detection in detections: for detection in detections:
score = detection.categories[0].score score = detection.categories[0].score
@ -204,8 +204,8 @@ class ObjectDetectorTest(parameterized.TestCase):
max_results=_MAX_RESULTS) max_results=_MAX_RESULTS)
with _ObjectDetector.create_from_options(options) as detector: with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
detections = image_result.detections detections = detection_result.detections
self.assertLessEqual( self.assertLessEqual(
len(detections), _MAX_RESULTS, 'Too many results returned.') len(detections), _MAX_RESULTS, 'Too many results returned.')
@ -216,8 +216,8 @@ class ObjectDetectorTest(parameterized.TestCase):
category_allowlist=_ALLOW_LIST) category_allowlist=_ALLOW_LIST)
with _ObjectDetector.create_from_options(options) as detector: with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
detections = image_result.detections detections = detection_result.detections
for detection in detections: for detection in detections:
label = detection.categories[0].category_name label = detection.categories[0].category_name
@ -230,8 +230,8 @@ class ObjectDetectorTest(parameterized.TestCase):
category_denylist=_DENY_LIST) category_denylist=_DENY_LIST)
with _ObjectDetector.create_from_options(options) as detector: with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
detections = image_result.detections detections = detection_result.detections
for detection in detections: for detection in detections:
label = detection.categories[0].category_name label = detection.categories[0].category_name
@ -257,8 +257,8 @@ class ObjectDetectorTest(parameterized.TestCase):
score_threshold=1) score_threshold=1)
with _ObjectDetector.create_from_options(options) as detector: with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input. # Performs object detection on the input.
image_result = detector.detect(self.test_image) detection_result = detector.detect(self.test_image)
self.assertEmpty(image_result.detections) self.assertEmpty(detection_result.detections)
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ObjectDetectorOptions( options = _ObjectDetectorOptions(

View File

@ -85,6 +85,10 @@ filegroup(
"selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_128_128_3_expected_mask.jpg",
"selfie_segm_144_256_3_expected_mask.jpg", "selfie_segm_144_256_3_expected_mask.jpg",
], ],
visibility = [
"//mediapipe/python:__subpackages__",
"//mediapipe/tasks:internal",
],
) )
# TODO Create individual filegroup for models required for each Tasks. # TODO Create individual filegroup for models required for each Tasks.