174 lines
7.1 KiB
C++
174 lines
7.1 KiB
C++
// Copyright 2020-2021 The MediaPipe Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef MEDIAPIPE_PYTHON_PYBIND_IMAGE_FRAME_UTIL_H_
|
|
#define MEDIAPIPE_PYTHON_PYBIND_IMAGE_FRAME_UTIL_H_
|
|
|
|
#include "absl/memory/memory.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "mediapipe/framework/formats/image_format.pb.h"
|
|
#include "mediapipe/framework/formats/image_frame.h"
|
|
#include "mediapipe/framework/port/logging.h"
|
|
#include "mediapipe/python/pybind/util.h"
|
|
#include "pybind11/numpy.h"
|
|
#include "pybind11/pybind11.h"
|
|
|
|
namespace mediapipe {
|
|
namespace python {
|
|
|
|
namespace py = pybind11;
|
|
|
|
template <typename T>
|
|
std::unique_ptr<ImageFrame> CreateImageFrame(
|
|
mediapipe::ImageFormat::Format format,
|
|
const py::array_t<T, py::array::c_style>& data, bool copy = true) {
|
|
int rows = data.shape()[0];
|
|
int cols = data.shape()[1];
|
|
int width_step = ImageFrame::NumberOfChannelsForFormat(format) *
|
|
ImageFrame::ByteDepthForFormat(format) * cols;
|
|
if (copy) {
|
|
auto image_frame = absl::make_unique<ImageFrame>(
|
|
format, /*width=*/cols, /*height=*/rows, width_step,
|
|
static_cast<uint8*>(data.request().ptr),
|
|
ImageFrame::PixelDataDeleter::kNone);
|
|
auto image_frame_copy = absl::make_unique<ImageFrame>();
|
|
// Set alignment_boundary to kGlDefaultAlignmentBoundary so that both
|
|
// GPU and CPU can process it.
|
|
image_frame_copy->CopyFrom(*image_frame,
|
|
ImageFrame::kGlDefaultAlignmentBoundary);
|
|
return image_frame_copy;
|
|
}
|
|
PyObject* data_pyobject = data.ptr();
|
|
auto image_frame = absl::make_unique<ImageFrame>(
|
|
format, /*width=*/cols, /*height=*/rows, width_step,
|
|
static_cast<uint8*>(data.request().ptr),
|
|
/*deleter=*/[data_pyobject](uint8*) { Py_XDECREF(data_pyobject); });
|
|
Py_XINCREF(data_pyobject);
|
|
return image_frame;
|
|
}
|
|
|
|
template <typename T>
|
|
py::array GenerateContiguousDataArrayHelper(const ImageFrame& image_frame,
|
|
const py::object& py_object) {
|
|
std::vector<int> shape{image_frame.Height(), image_frame.Width()};
|
|
if (image_frame.NumberOfChannels() > 1) {
|
|
shape.push_back(image_frame.NumberOfChannels());
|
|
}
|
|
py::array_t<T, py::array::c_style> contiguous_data;
|
|
if (image_frame.IsContiguous()) {
|
|
contiguous_data = py::array_t<T, py::array::c_style>(
|
|
shape, reinterpret_cast<const T*>(image_frame.PixelData()), py_object);
|
|
} else {
|
|
auto contiguous_data_copy =
|
|
absl::make_unique<T[]>(image_frame.Width() * image_frame.Height() *
|
|
image_frame.NumberOfChannels());
|
|
image_frame.CopyToBuffer(contiguous_data_copy.get(),
|
|
image_frame.PixelDataSizeStoredContiguously());
|
|
auto capsule = py::capsule(contiguous_data_copy.get(), [](void* data) {
|
|
if (data) {
|
|
delete[] reinterpret_cast<T*>(data);
|
|
}
|
|
});
|
|
contiguous_data = py::array_t<T, py::array::c_style>(
|
|
shape, contiguous_data_copy.release(), capsule);
|
|
}
|
|
|
|
// In both cases, the underlying data is not writable in Python.
|
|
py::detail::array_proxy(contiguous_data.ptr())->flags &=
|
|
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
|
return contiguous_data;
|
|
}
|
|
|
|
inline py::array GenerateContiguousDataArray(const ImageFrame& image_frame,
|
|
const py::object& py_object) {
|
|
switch (image_frame.ChannelSize()) {
|
|
case sizeof(uint8):
|
|
return GenerateContiguousDataArrayHelper<uint8>(image_frame, py_object)
|
|
.cast<py::array>();
|
|
case sizeof(uint16):
|
|
return GenerateContiguousDataArrayHelper<uint16>(image_frame, py_object)
|
|
.cast<py::array>();
|
|
case sizeof(float):
|
|
return GenerateContiguousDataArrayHelper<float>(image_frame, py_object)
|
|
.cast<py::array>();
|
|
break;
|
|
default:
|
|
throw RaisePyError(PyExc_RuntimeError,
|
|
"Unsupported image frame channel size. Data is not "
|
|
"uint8, uint16, or float?");
|
|
}
|
|
}
|
|
|
|
// Generates a contiguous data pyarray object on demand.
|
|
// This function only accepts an image frame object that already stores
|
|
// contiguous data. The output py::array points to the raw pixel data array of
|
|
// the image frame object directly.
|
|
inline py::array GenerateDataPyArrayOnDemand(const ImageFrame& image_frame,
|
|
const py::object& py_object) {
|
|
if (!image_frame.IsContiguous()) {
|
|
throw RaisePyError(PyExc_RuntimeError,
|
|
"GenerateDataPyArrayOnDemand must take an ImageFrame "
|
|
"object that stores contiguous data.");
|
|
}
|
|
return GenerateContiguousDataArray(image_frame, py_object);
|
|
}
|
|
|
|
// Gets the cached contiguous data array from the "__contiguous_data" attribute.
|
|
// If the attribute doesn't exist, the function calls
|
|
// GenerateContiguousDataArray() to generate the contiguous data pyarray object,
|
|
// which realigns and copies the data from the original image frame object.
|
|
// Then, the data array object is cached in the "__contiguous_data" attribute.
|
|
// This function only accepts an image frame object that stores non-contiguous
|
|
// data.
|
|
inline py::array GetCachedContiguousDataAttr(const ImageFrame& image_frame,
|
|
const py::object& py_object) {
|
|
if (image_frame.IsContiguous()) {
|
|
throw RaisePyError(PyExc_RuntimeError,
|
|
"GetCachedContiguousDataAttr must take an ImageFrame "
|
|
"object that stores non-contiguous data.");
|
|
}
|
|
py::object get_data_attr =
|
|
py::getattr(py_object, "__contiguous_data", py::none());
|
|
if (image_frame.IsEmpty()) {
|
|
throw RaisePyError(PyExc_RuntimeError, "ImageFrame is unallocated.");
|
|
}
|
|
// If __contiguous_data attr doesn't store data yet, generates the contiguous
|
|
// data array object and caches the result.
|
|
if (get_data_attr.is_none()) {
|
|
py_object.attr("__contiguous_data") =
|
|
GenerateContiguousDataArray(image_frame, py_object);
|
|
}
|
|
return py_object.attr("__contiguous_data").cast<py::array>();
|
|
}
|
|
|
|
template <typename T>
|
|
py::object GetValue(const ImageFrame& image_frame, const std::vector<int>& pos,
|
|
const py::object& py_object) {
|
|
py::array_t<T, py::array::c_style> output_array =
|
|
image_frame.IsContiguous()
|
|
? GenerateDataPyArrayOnDemand(image_frame, py_object)
|
|
: GetCachedContiguousDataAttr(image_frame, py_object);
|
|
if (pos.size() == 2) {
|
|
return py::cast(static_cast<T>(output_array.at(pos[0], pos[1])));
|
|
} else if (pos.size() == 3) {
|
|
return py::cast(static_cast<T>(output_array.at(pos[0], pos[1], pos[2])));
|
|
}
|
|
return py::none();
|
|
}
|
|
|
|
} // namespace python
|
|
} // namespace mediapipe
|
|
|
|
#endif // MEDIAPIPE_PYTHON_PYBIND_IMAGE_FRAME_UTIL_H_
|