mediapipe/mediapipe/python/pybind/image_frame_util.h
MediaPipe Team a9b643e0f5 Project import generated by Copybara.
GitOrigin-RevId: ff83882955f1a1e2a043ff4e71278be9d7217bbe
2021-05-05 14:56:16 -04:00

174 lines
7.1 KiB
C++

// Copyright 2020-2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_PYTHON_PYBIND_IMAGE_FRAME_UTIL_H_
#define MEDIAPIPE_PYTHON_PYBIND_IMAGE_FRAME_UTIL_H_
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/python/pybind/util.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace mediapipe {
namespace python {
namespace py = pybind11;
// Creates an ImageFrame of the given `format` from a C-contiguous numpy array
// whose first two dimensions are (rows, cols).
//
// Row stride is computed as NumberOfChannelsForFormat(format) *
// ByteDepthForFormat(format) * cols, i.e. the array is assumed to be tightly
// packed in the layout implied by `format`.
//
// If `copy` is true, the pixels are copied into a newly allocated ImageFrame
// aligned to ImageFrame::kGlDefaultAlignmentBoundary (so both GPU and CPU can
// process it); the result does not reference `data` afterwards.
// If `copy` is false, the returned ImageFrame wraps the numpy buffer directly
// and holds a strong reference to the numpy array, which is dropped by the
// frame's pixel-data deleter.
//
// Throws a Python RuntimeError if `data` has fewer than two dimensions or is
// too small to hold rows * width_step bytes (which would otherwise make the
// ImageFrame read past the end of the numpy buffer).
template <typename T>
std::unique_ptr<ImageFrame> CreateImageFrame(
    mediapipe::ImageFormat::Format format,
    const py::array_t<T, py::array::c_style>& data, bool copy = true) {
  if (data.ndim() < 2) {
    throw RaisePyError(PyExc_RuntimeError,
                       "The input numpy array must have at least two "
                       "dimensions (rows, cols).");
  }
  const int rows = data.shape()[0];
  const int cols = data.shape()[1];
  const int width_step = ImageFrame::NumberOfChannelsForFormat(format) *
                         ImageFrame::ByteDepthForFormat(format) * cols;
  // Guard against a format whose channel count / byte depth exceeds what the
  // array actually stores; ImageFrame would otherwise over-read the buffer.
  if (static_cast<int64_t>(data.nbytes()) <
      static_cast<int64_t>(rows) * width_step) {
    throw RaisePyError(PyExc_RuntimeError,
                       "The input numpy array is too small for the requested "
                       "image format.");
  }
  if (copy) {
    // Non-owning temporary view over the numpy buffer, used only as the
    // source of the aligned copy below.
    auto image_frame = absl::make_unique<ImageFrame>(
        format, /*width=*/cols, /*height=*/rows, width_step,
        static_cast<uint8*>(data.request().ptr),
        ImageFrame::PixelDataDeleter::kNone);
    auto image_frame_copy = absl::make_unique<ImageFrame>();
    // Set alignment_boundary to kGlDefaultAlignmentBoundary so that both
    // GPU and CPU can process it.
    image_frame_copy->CopyFrom(*image_frame,
                               ImageFrame::kGlDefaultAlignmentBoundary);
    return image_frame_copy;
  }
  PyObject* data_pyobject = data.ptr();
  auto image_frame = absl::make_unique<ImageFrame>(
      format, /*width=*/cols, /*height=*/rows, width_step,
      static_cast<uint8*>(data.request().ptr),
      /*deleter=*/[data_pyobject](uint8*) {
        // The ImageFrame may be released from C++ code that does not hold
        // the GIL; Python reference counts must only be touched with the GIL
        // held. gil_scoped_acquire is a no-op-safe re-entry if it is held.
        py::gil_scoped_acquire gil_acquire;
        Py_XDECREF(data_pyobject);
      });
  // Keep the numpy array alive for as long as the frame references its
  // buffer; balanced by the deleter above. Done after construction so a
  // throwing constructor cannot leak the reference.
  Py_XINCREF(data_pyobject);
  return image_frame;
}
// Builds a read-only numpy array of element type T exposing the pixels of
// `image_frame` in packed, row-major order.
//
// The array shape is (height, width) for single-channel frames and
// (height, width, channels) otherwise.
//  - Contiguous frame: no copy is made. The array views
//    image_frame.PixelData() directly, with `py_object` installed as its base
//    object.
//  - Non-contiguous frame: the pixels are copied into a freshly allocated
//    buffer whose lifetime is tied to a py::capsule that frees it.
template <typename T>
py::array GenerateContiguousDataArrayHelper(const ImageFrame& image_frame,
                                            const py::object& py_object) {
  using PackedArray = py::array_t<T, py::array::c_style>;

  const int channels = image_frame.NumberOfChannels();
  std::vector<int> dims = {image_frame.Height(), image_frame.Width()};
  if (channels > 1) {
    dims.push_back(channels);
  }

  PackedArray packed_array;
  if (!image_frame.IsContiguous()) {
    // The frame's storage is not packed (e.g. padded rows), so copy it into a
    // tightly packed buffer first.
    auto buffer = absl::make_unique<T[]>(image_frame.Width() *
                                         image_frame.Height() * channels);
    image_frame.CopyToBuffer(buffer.get(),
                             image_frame.PixelDataSizeStoredContiguously());
    // The capsule becomes the array's base and frees the buffer when the
    // array is garbage-collected; ownership leaves the unique_ptr only once
    // the capsule exists.
    py::capsule buffer_owner(buffer.get(), [](void* ptr) {
      if (ptr != nullptr) {
        delete[] static_cast<T*>(ptr);
      }
    });
    packed_array = PackedArray(dims, buffer.release(), buffer_owner);
  } else {
    packed_array = PackedArray(
        dims, reinterpret_cast<const T*>(image_frame.PixelData()), py_object);
  }

  // Neither variant may be mutated from Python.
  py::detail::array_proxy(packed_array.ptr())->flags &=
      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
  return packed_array;
}
// Returns a read-only, C-contiguous numpy array holding the pixel data of
// `image_frame`. The numpy element type (uint8, uint16 or float) is chosen
// from the frame's per-channel byte size.
//
// `py_object` is forwarded to GenerateContiguousDataArrayHelper, which uses it
// as the array's base object when the frame's data is already contiguous.
//
// Throws a Python RuntimeError for any other channel size.
inline py::array GenerateContiguousDataArray(const ImageFrame& image_frame,
                                             const py::object& py_object) {
  // The helper already returns a py::array, so no further cast is needed.
  switch (image_frame.ChannelSize()) {
    case sizeof(uint8):
      return GenerateContiguousDataArrayHelper<uint8>(image_frame, py_object);
    case sizeof(uint16):
      return GenerateContiguousDataArrayHelper<uint16>(image_frame, py_object);
    case sizeof(float):
      return GenerateContiguousDataArrayHelper<float>(image_frame, py_object);
    default:
      throw RaisePyError(PyExc_RuntimeError,
                         "Unsupported image frame channel size. Data is not "
                         "uint8, uint16, or float?");
  }
}
// Builds a fresh (uncached) numpy array for `image_frame` each time it is
// called.
// Only frames whose pixel data is already stored contiguously are accepted.
// For those, the returned py::array is a zero-copy view of the frame's raw
// pixel buffer, with `py_object` as its base object.
// Throws a Python RuntimeError when the frame's data is not contiguous.
inline py::array GenerateDataPyArrayOnDemand(const ImageFrame& image_frame,
                                             const py::object& py_object) {
  if (image_frame.IsContiguous()) {
    return GenerateContiguousDataArray(image_frame, py_object);
  }
  throw RaisePyError(PyExc_RuntimeError,
                     "GenerateDataPyArrayOnDemand must take an ImageFrame "
                     "object that stores contiguous data.");
}
// Returns the packed numpy copy of a non-contiguous `image_frame`, memoized on
// `py_object` under the "__contiguous_data" attribute.
//
// On the first call (attribute missing or None), GenerateContiguousDataArray()
// realigns and copies the frame's pixels. The result is stored on `py_object`;
// later calls reuse it.
//
// Throws a Python RuntimeError if the frame stores contiguous data (callers
// should use GenerateDataPyArrayOnDemand() instead) or if the frame is empty.
inline py::array GetCachedContiguousDataAttr(const ImageFrame& image_frame,
                                             const py::object& py_object) {
  if (image_frame.IsContiguous()) {
    throw RaisePyError(PyExc_RuntimeError,
                       "GetCachedContiguousDataAttr must take an ImageFrame "
                       "object that stores non-contiguous data.");
  }
  static constexpr char kCacheAttr[] = "__contiguous_data";

  const bool cache_missing =
      py::getattr(py_object, kCacheAttr, py::none()).is_none();
  if (image_frame.IsEmpty()) {
    throw RaisePyError(PyExc_RuntimeError, "ImageFrame is unallocated.");
  }
  if (cache_missing) {
    py_object.attr(kCacheAttr) =
        GenerateContiguousDataArray(image_frame, py_object);
  }
  return py_object.attr(kCacheAttr).cast<py::array>();
}
// Reads one pixel component of element type T from `image_frame`.
//
// `pos` is interpreted as {row, col} (2 entries) or {row, col, channel}
// (3 entries); any other length returns Python None. Lookups go through a
// C-contiguous numpy array of the frame: for contiguous frames this is a fresh
// zero-copy view, otherwise it is the array cached on `py_object` by
// GetCachedContiguousDataAttr(). Indexing uses py::array_t::at().
template <typename T>
py::object GetValue(const ImageFrame& image_frame, const std::vector<int>& pos,
                    const py::object& py_object) {
  py::array_t<T, py::array::c_style> pixels =
      image_frame.IsContiguous()
          ? GenerateDataPyArrayOnDemand(image_frame, py_object)
          : GetCachedContiguousDataAttr(image_frame, py_object);

  T value;
  switch (pos.size()) {
    case 2:
      value = pixels.at(pos[0], pos[1]);
      break;
    case 3:
      value = pixels.at(pos[0], pos[1], pos[2]);
      break;
    default:
      return py::none();
  }
  return py::cast(value);
}
} // namespace python
} // namespace mediapipe
#endif // MEDIAPIPE_PYTHON_PYBIND_IMAGE_FRAME_UTIL_H_