Implement Image.create_from_file and update the object_detector_test.py file accordingly.
PiperOrigin-RevId: 477682930
This commit is contained in:
parent
80fd47b820
commit
554e2a9d69
|
@ -48,6 +48,8 @@ pybind_extension(
|
|||
"//mediapipe/python/pybind:timestamp",
|
||||
"//mediapipe/python/pybind:validated_graph_config",
|
||||
"//mediapipe/tasks/python/core/pybind:task_runner",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@stblib//:stb_image",
|
||||
# Type registration.
|
||||
"//mediapipe/framework:basic_types_registration",
|
||||
"//mediapipe/framework/formats:classification_registration",
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Tests for mediapipe.python._framework_bindings.image."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
|
@ -23,6 +24,7 @@ import cv2
|
|||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
# resources dependency
|
||||
from mediapipe.python._framework_bindings import image
|
||||
from mediapipe.python._framework_bindings import image_frame
|
||||
|
||||
|
@ -185,6 +187,16 @@ class ImageTest(absltest.TestCase):
|
|||
gc.collect()
|
||||
self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)
|
||||
|
||||
def test_image_create_from_file(self):
|
||||
image_path = os.path.join(
|
||||
resources.GetRunfilesDir(),
|
||||
'mediapipe/tasks/testdata/vision/cat.jpg')
|
||||
loaded_image = Image.create_from_file(image_path)
|
||||
self.assertEqual(loaded_image.width, 600)
|
||||
self.assertEqual(loaded_image.height, 400)
|
||||
self.assertEqual(loaded_image.channels, 3)
|
||||
self.assertEqual(loaded_image.image_format, ImageFormat.SRGB)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -45,6 +45,8 @@ pybind_library(
|
|||
":util",
|
||||
"//mediapipe/framework:type_map",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@stblib//:stb_image",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/python/pybind/image_frame_util.h"
|
||||
#include "mediapipe/python/pybind/util.h"
|
||||
#include "pybind11/stl.h"
|
||||
#include "stb_image.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace python {
|
||||
|
@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) {
|
|||
image.is_aligned(16)
|
||||
)doc");
|
||||
|
||||
image.def_static(
|
||||
"create_from_file",
|
||||
[](const std::string& file_name) {
|
||||
int width;
|
||||
int height;
|
||||
int channels;
|
||||
auto* image_data =
|
||||
stbi_load(file_name.c_str(), &width, &height, &channels,
|
||||
/*desired_channels=*/0);
|
||||
if (image_data == nullptr) {
|
||||
throw RaisePyError(PyExc_RuntimeError,
|
||||
absl::StrFormat("Image decoding failed (%s): %s",
|
||||
stbi_failure_reason(), file_name)
|
||||
.c_str());
|
||||
}
|
||||
ImageFrameSharedPtr image_frame;
|
||||
switch (channels) {
|
||||
case 1:
|
||||
image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::GRAY8, width, height, width, image_data,
|
||||
stbi_image_free);
|
||||
break;
|
||||
case 3:
|
||||
image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::SRGB, width, height, 3 * width, image_data,
|
||||
stbi_image_free);
|
||||
break;
|
||||
case 4:
|
||||
image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::SRGBA, width, height, 4 * width, image_data,
|
||||
stbi_image_free);
|
||||
break;
|
||||
default:
|
||||
throw RaisePyError(
|
||||
PyExc_RuntimeError,
|
||||
absl::StrFormat(
|
||||
"Expected image with 1 (grayscale), 3 (RGB) or 4 "
|
||||
"(RGBA) channels, found %d channels.",
|
||||
channels)
|
||||
.c_str());
|
||||
}
|
||||
return Image(std::move(image_frame));
|
||||
},
|
||||
R"doc(Creates `Image` object from the image file.
|
||||
|
||||
Args:
|
||||
file_name: Image file name.
|
||||
|
||||
Returns:
|
||||
`Image` object.
|
||||
|
||||
Raises:
|
||||
RuntimeError if the image file can't be decoded.
|
||||
)doc",
|
||||
py::arg("file_name"));
|
||||
|
||||
image.def_property_readonly("width", &Image::width)
|
||||
.def_property_readonly("height", &Image::height)
|
||||
.def_property_readonly("channels", &Image::channels)
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
import os
|
||||
|
||||
from absl import flags
|
||||
import cv2
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame as image_frame_module
|
||||
|
@ -44,12 +43,3 @@ def get_test_data_path(file_or_dirname: str) -> str:
|
|||
if f.endswith(file_or_dirname):
|
||||
return os.path.join(directory, f)
|
||||
raise ValueError("No %s in test directory" % file_or_dirname)
|
||||
|
||||
|
||||
# TODO: Implement image util module to read image data from file.
|
||||
def read_test_image(image_file: str) -> _Image:
|
||||
"""Reads a MediaPipe Image from the image file."""
|
||||
image_data = cv2.imread(image_file)
|
||||
if image_data.shape[2] != _RGB_CHANNELS:
|
||||
raise ValueError("Input image must contain three channel rgb data.")
|
||||
return _Image(_ImageFormat.SRGB, cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB))
|
||||
|
|
|
@ -44,7 +44,7 @@ _IMAGE_FILE = 'cats_and_dogs.jpg'
|
|||
_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=608, origin_y=164, width=381, height=432),
|
||||
origin_x=608, origin_y=161, width=381, height=439),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
|
@ -64,7 +64,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
|||
]),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=257, origin_y=394, width=173, height=202),
|
||||
origin_x=256, origin_y=395, width=173, height=202),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
|
@ -74,7 +74,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[
|
|||
]),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=362, origin_y=195, width=325, height=412),
|
||||
origin_x=362, origin_y=191, width=325, height=419),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
|
@ -98,7 +98,7 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = test_util.read_test_image(
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_util.get_test_data_path(_IMAGE_FILE))
|
||||
self.model_path = test_util.get_test_data_path(_MODEL_FILE)
|
||||
|
||||
|
@ -153,9 +153,9 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
detector = _ObjectDetector.create_from_options(options)
|
||||
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
detection_result = detector.detect(self.test_image)
|
||||
# Comparing results.
|
||||
self.assertEqual(image_result, expected_detection_result)
|
||||
self.assertEqual(detection_result, expected_detection_result)
|
||||
# Closes the detector explicitly when the detector is not used in
|
||||
# a context.
|
||||
detector.close()
|
||||
|
@ -179,9 +179,9 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
base_options=base_options, max_results=max_results)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
detection_result = detector.detect(self.test_image)
|
||||
# Comparing results.
|
||||
self.assertEqual(image_result, expected_detection_result)
|
||||
self.assertEqual(detection_result, expected_detection_result)
|
||||
|
||||
def test_score_threshold_option(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
|
@ -189,8 +189,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
score_threshold=_SCORE_THRESHOLD)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
detections = image_result.detections
|
||||
detection_result = detector.detect(self.test_image)
|
||||
detections = detection_result.detections
|
||||
|
||||
for detection in detections:
|
||||
score = detection.categories[0].score
|
||||
|
@ -204,8 +204,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
max_results=_MAX_RESULTS)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
detections = image_result.detections
|
||||
detection_result = detector.detect(self.test_image)
|
||||
detections = detection_result.detections
|
||||
|
||||
self.assertLessEqual(
|
||||
len(detections), _MAX_RESULTS, 'Too many results returned.')
|
||||
|
@ -216,8 +216,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
category_allowlist=_ALLOW_LIST)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
detections = image_result.detections
|
||||
detection_result = detector.detect(self.test_image)
|
||||
detections = detection_result.detections
|
||||
|
||||
for detection in detections:
|
||||
label = detection.categories[0].category_name
|
||||
|
@ -230,8 +230,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
category_denylist=_DENY_LIST)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
detections = image_result.detections
|
||||
detection_result = detector.detect(self.test_image)
|
||||
detections = detection_result.detections
|
||||
|
||||
for detection in detections:
|
||||
label = detection.categories[0].category_name
|
||||
|
@ -257,8 +257,8 @@ class ObjectDetectorTest(parameterized.TestCase):
|
|||
score_threshold=1)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
# Performs object detection on the input.
|
||||
image_result = detector.detect(self.test_image)
|
||||
self.assertEmpty(image_result.detections)
|
||||
detection_result = detector.detect(self.test_image)
|
||||
self.assertEmpty(detection_result.detections)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
|
|
4
mediapipe/tasks/testdata/vision/BUILD
vendored
4
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -85,6 +85,10 @@ filegroup(
|
|||
"selfie_segm_128_128_3_expected_mask.jpg",
|
||||
"selfie_segm_144_256_3_expected_mask.jpg",
|
||||
],
|
||||
visibility = [
|
||||
"//mediapipe/python:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO Create individual filegroup for models required for each Tasks.
|
||||
|
|
Loading…
Reference in New Issue
Block a user