Implement Image.create_from_file and update the object_detector_test.py file accordingly.

PiperOrigin-RevId: 477682930
This commit is contained in:
Jiuqiang Tang 2022-09-29 03:42:07 -07:00 committed by Copybara-Service
parent 80fd47b820
commit 554e2a9d69
7 changed files with 96 additions and 28 deletions

View File

@ -48,6 +48,8 @@ pybind_extension(
"//mediapipe/python/pybind:timestamp",
"//mediapipe/python/pybind:validated_graph_config",
"//mediapipe/tasks/python/core/pybind:task_runner",
"@com_google_absl//absl/strings:str_format",
"@stblib//:stb_image",
# Type registration.
"//mediapipe/framework:basic_types_registration",
"//mediapipe/framework/formats:classification_registration",

View File

@ -15,6 +15,7 @@
"""Tests for mediapipe.python._framework_bindings.image."""
import gc
import os
import random
import sys
@ -23,6 +24,7 @@ import cv2
import numpy as np
import PIL.Image
# resources dependency
from mediapipe.python._framework_bindings import image
from mediapipe.python._framework_bindings import image_frame
@ -185,6 +187,16 @@ class ImageTest(absltest.TestCase):
gc.collect()
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
def test_image_create_from_file(self):
  """Checks that `Image.create_from_file` decodes the cat.jpg test asset.

  The asset is looked up under the runfiles directory and the decoded image
  is expected to report a 600x400, three-channel SRGB layout.
  """
  # Resolve the test asset relative to the runfiles root.
  runfiles_root = resources.GetRunfilesDir()
  cat_jpg_path = os.path.join(
      runfiles_root, 'mediapipe/tasks/testdata/vision/cat.jpg')

  decoded = Image.create_from_file(cat_jpg_path)

  # Geometry of the decoded image.
  self.assertEqual(decoded.width, 600)
  self.assertEqual(decoded.height, 400)
  # Pixel layout of the decoded image.
  self.assertEqual(decoded.channels, 3)
  self.assertEqual(decoded.image_format, ImageFormat.SRGB)
if __name__ == '__main__':
absltest.main()

View File

@ -45,6 +45,8 @@ pybind_library(
":util",
"//mediapipe/framework:type_map",
"//mediapipe/framework/formats:image",
"@com_google_absl//absl/strings:str_format",
"@stblib//:stb_image",
],
)

View File

@ -16,9 +16,11 @@
#include <memory>
#include "absl/strings/str_format.h"
#include "mediapipe/python/pybind/image_frame_util.h"
#include "mediapipe/python/pybind/util.h"
#include "pybind11/stl.h"
#include "stb_image.h"
namespace mediapipe {
namespace python {
@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) {
image.is_aligned(16)
)doc");
// Python static method `Image.create_from_file(file_name)`: decodes an image
// file with stb_image and wraps the decoded pixels in an `Image`.
image.def_static(
    "create_from_file",
    [](const std::string& file_name) {
      // Filled in by stbi_load on success; zero-initialised so they never
      // hold indeterminate values.
      int width = 0;
      int height = 0;
      int channels = 0;
      // desired_channels=0 asks stb_image to keep the channel count that is
      // stored in the file instead of converting it.
      auto* image_data =
          stbi_load(file_name.c_str(), &width, &height, &channels,
                    /*desired_channels=*/0);
      if (image_data == nullptr) {
        throw RaisePyError(PyExc_RuntimeError,
                           absl::StrFormat("Image decoding failed (%s): %s",
                                           stbi_failure_reason(), file_name)
                               .c_str());
      }
      // In each supported case the pixel buffer is handed to ImageFrame
      // together with stbi_image_free as its deleter argument.
      ImageFrameSharedPtr image_frame;
      switch (channels) {
        case 1:
          image_frame = std::make_shared<ImageFrame>(
              ImageFormat::GRAY8, width, height, width, image_data,
              stbi_image_free);
          break;
        case 3:
          image_frame = std::make_shared<ImageFrame>(
              ImageFormat::SRGB, width, height, 3 * width, image_data,
              stbi_image_free);
          break;
        case 4:
          image_frame = std::make_shared<ImageFrame>(
              ImageFormat::SRGBA, width, height, 4 * width, image_data,
              stbi_image_free);
          break;
        default:
          // No ImageFrame was created on this path, so nothing else will
          // release the decoded buffer: free it before throwing to avoid
          // leaking it (e.g. for 2-channel gray+alpha files).
          stbi_image_free(image_data);
          throw RaisePyError(
              PyExc_RuntimeError,
              absl::StrFormat(
                  "Expected image with 1 (grayscale), 3 (RGB) or 4 "
                  "(RGBA) channels, found %d channels.",
                  channels)
                  .c_str());
      }
      return Image(std::move(image_frame));
    },
    R"doc(Creates `Image` object from the image file.
Args:
file_name: Image file name.
Returns:
`Image` object.
Raises:
RuntimeError if the image file can't be decoded.
)doc",
    py::arg("file_name"));
image.def_property_readonly("width", &Image::width)
.def_property_readonly("height", &Image::height)
.def_property_readonly("channels", &Image::channels)

View File

@ -16,7 +16,6 @@
import os
from absl import flags
import cv2
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
@ -44,12 +43,3 @@ def get_test_data_path(file_or_dirname: str) -> str:
if f.endswith(file_or_dirname):
return os.path.join(directory, f)
raise ValueError("No %s in test directory" % file_or_dirname)
# TODO: Implement image util module to read image data from file.
def read_test_image(image_file: str) -> _Image:
  """Reads a MediaPipe Image from the image file.

  Args:
    image_file: Path of the image file to load.

  Returns:
    An SRGB `_Image` built from the decoded pixels after converting them from
    OpenCV's BGR channel order to RGB.

  Raises:
    ValueError: If the file cannot be read or decoded, or if the decoded image
      does not contain three channels.
  """
  image_data = cv2.imread(image_file)
  # cv2.imread signals a missing or undecodable file by returning None rather
  # than raising, so fail here with a message that names the file.
  if image_data is None:
    raise ValueError("Unable to read image file: %s" % image_file)
  # Guard the rank before indexing shape[2] so a 2-D (single channel) array
  # produces the same ValueError instead of an IndexError.
  if image_data.ndim != 3 or image_data.shape[2] != _RGB_CHANNELS:
    raise ValueError("Input image must contain three channel rgb data.")
  return _Image(_ImageFormat.SRGB, cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB))

View File

@ -44,7 +44,7 @@ _IMAGE_FILE = 'cats_and_dogs.jpg'
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
_Detection(
bounding_box=_BoundingBox(
origin_x=608, origin_y=164, width=381, height=432),
origin_x=608, origin_y=161, width=381, height=439),
categories=[
_Category(
index=None,
@ -64,7 +64,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
]),
_Detection(
bounding_box=_BoundingBox(
origin_x=257, origin_y=394, width=173, height=202),
origin_x=256, origin_y=395, width=173, height=202),
categories=[
_Category(
index=None,
@ -74,7 +74,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
]),
_Detection(
bounding_box=_BoundingBox(
origin_x=362, origin_y=195, width=325, height=412),
origin_x=362, origin_y=191, width=325, height=419),
categories=[
_Category(
index=None,
@ -98,7 +98,7 @@ class ObjectDetectorTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = test_util.read_test_image(
self.test_image = _Image.create_from_file(
test_util.get_test_data_path(_IMAGE_FILE))
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
@ -153,9 +153,9 @@ class ObjectDetectorTest(parameterized.TestCase):
detector = _ObjectDetector.create_from_options(options)
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
detection_result = detector.detect(self.test_image)
# Comparing results.
self.assertEqual(image_result, expected_detection_result)
self.assertEqual(detection_result, expected_detection_result)
# Closes the detector explicitly when the detector is not used in
# a context.
detector.close()
@ -179,9 +179,9 @@ class ObjectDetectorTest(parameterized.TestCase):
base_options=base_options, max_results=max_results)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
detection_result = detector.detect(self.test_image)
# Comparing results.
self.assertEqual(image_result, expected_detection_result)
self.assertEqual(detection_result, expected_detection_result)
def test_score_threshold_option(self):
options = _ObjectDetectorOptions(
@ -189,8 +189,8 @@ class ObjectDetectorTest(parameterized.TestCase):
score_threshold=_SCORE_THRESHOLD)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
detections = image_result.detections
detection_result = detector.detect(self.test_image)
detections = detection_result.detections
for detection in detections:
score = detection.categories[0].score
@ -204,8 +204,8 @@ class ObjectDetectorTest(parameterized.TestCase):
max_results=_MAX_RESULTS)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
detections = image_result.detections
detection_result = detector.detect(self.test_image)
detections = detection_result.detections
self.assertLessEqual(
len(detections), _MAX_RESULTS, 'Too many results returned.')
@ -216,8 +216,8 @@ class ObjectDetectorTest(parameterized.TestCase):
category_allowlist=_ALLOW_LIST)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
detections = image_result.detections
detection_result = detector.detect(self.test_image)
detections = detection_result.detections
for detection in detections:
label = detection.categories[0].category_name
@ -230,8 +230,8 @@ class ObjectDetectorTest(parameterized.TestCase):
category_denylist=_DENY_LIST)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
detections = image_result.detections
detection_result = detector.detect(self.test_image)
detections = detection_result.detections
for detection in detections:
label = detection.categories[0].category_name
@ -257,8 +257,8 @@ class ObjectDetectorTest(parameterized.TestCase):
score_threshold=1)
with _ObjectDetector.create_from_options(options) as detector:
# Performs object detection on the input.
image_result = detector.detect(self.test_image)
self.assertEmpty(image_result.detections)
detection_result = detector.detect(self.test_image)
self.assertEmpty(detection_result.detections)
def test_missing_result_callback(self):
options = _ObjectDetectorOptions(

View File

@ -85,6 +85,10 @@ filegroup(
"selfie_segm_128_128_3_expected_mask.jpg",
"selfie_segm_144_256_3_expected_mask.jpg",
],
visibility = [
"//mediapipe/python:__subpackages__",
"//mediapipe/tasks:internal",
],
)
# TODO: Create an individual filegroup for the models required by each Task.