Check that Java buffer supports direct access before using it
If the buffer is not created with allocateDirect, JNI APIs will return a data pointer of nullptr and a capacity of -1. This can cause a crash when we access it. Also clean up the code to raise exceptions instead of just logging errors and returning nullptr. PiperOrigin-RevId: 489751312
This commit is contained in:
parent
977ee4272e
commit
a33cb1e05e
|
@ -17,6 +17,8 @@
|
|||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/camera_intrinsics.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
@ -107,17 +109,18 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer(
|
|||
|
||||
// Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java
|
||||
// ByteBuffer.
|
||||
std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
|
||||
JNIEnv* env, jobject byte_buffer, jint width, jint height,
|
||||
mediapipe::ImageFormat::Format format) {
|
||||
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
|
||||
CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
|
||||
jint height,
|
||||
mediapipe::ImageFormat::Format format) {
|
||||
switch (format) {
|
||||
case mediapipe::ImageFormat::SRGBA:
|
||||
case mediapipe::ImageFormat::SRGB:
|
||||
case mediapipe::ImageFormat::GRAY8:
|
||||
break;
|
||||
default:
|
||||
LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8.";
|
||||
return nullptr;
|
||||
return absl::InvalidArgumentError(
|
||||
"Format must be either SRGBA, SRGB, or GRAY8.");
|
||||
}
|
||||
|
||||
auto image_frame = std::make_unique<mediapipe::ImageFrame>(
|
||||
|
@ -125,25 +128,30 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
|
|||
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
|
||||
const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
if (buffer_data == nullptr || buffer_size < 0) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Cannot get direct access to the input buffer. It should be created "
|
||||
"using allocateDirect.");
|
||||
}
|
||||
|
||||
const int num_channels = image_frame->NumberOfChannels();
|
||||
const int expected_buffer_size =
|
||||
num_channels == 1 ? width * height : image_frame->PixelDataSize();
|
||||
|
||||
if (buffer_size != expected_buffer_size) {
|
||||
if (num_channels != 1)
|
||||
LOG(ERROR) << "The input image buffer should have 4 bytes alignment.";
|
||||
LOG(ERROR) << "Please check the input buffer size.";
|
||||
LOG(ERROR) << "Buffer size: " << buffer_size
|
||||
<< ", Buffer size needed: " << expected_buffer_size
|
||||
<< ", Image width: " << width;
|
||||
return nullptr;
|
||||
}
|
||||
RET_CHECK_EQ(buffer_size, expected_buffer_size)
|
||||
<< (num_channels != 1
|
||||
? "The input image buffer should have 4 bytes alignment. "
|
||||
: "")
|
||||
<< "Please check the input buffer size."
|
||||
<< " Buffer size: " << buffer_size
|
||||
<< ", Buffer size needed: " << expected_buffer_size
|
||||
<< ", Image width: " << width;
|
||||
|
||||
// Copy buffer data to image frame's pixel_data_.
|
||||
if (num_channels == 1) {
|
||||
const int width_step = image_frame->WidthStep();
|
||||
const char* src_row =
|
||||
reinterpret_cast<const char*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
const char* src_row = reinterpret_cast<const char*>(buffer_data);
|
||||
char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData());
|
||||
for (int i = height; i > 0; --i) {
|
||||
std::memcpy(dst_row, src_row, width);
|
||||
|
@ -152,7 +160,6 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
|
|||
}
|
||||
} else {
|
||||
// 3 and 4 channels.
|
||||
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
std::memcpy(image_frame->MutablePixelData(), buffer_data,
|
||||
image_frame->PixelDataSize());
|
||||
}
|
||||
|
@ -176,77 +183,100 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)(
|
|||
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)(
|
||||
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
|
||||
jint height) {
|
||||
auto image_frame = CreateImageFrameFromByteBuffer(
|
||||
auto image_frame_or = CreateImageFrameFromByteBuffer(
|
||||
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB);
|
||||
if (nullptr == image_frame) return 0L;
|
||||
if (ThrowIfError(env, image_frame_or.status())) return 0L;
|
||||
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
|
||||
return CreatePacketWithContext(context, packet);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba(
|
||||
JNIEnv* env, jobject byte_buffer, jint width, jint height) {
|
||||
const uint8_t* rgba_data =
|
||||
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
if (rgba_data == nullptr || buffer_size < 0) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Cannot get direct access to the input buffer. It should be created "
|
||||
"using allocateDirect.");
|
||||
}
|
||||
|
||||
const int expected_buffer_size = width * height * 4;
|
||||
RET_CHECK_EQ(buffer_size, expected_buffer_size)
|
||||
<< "Please check the input buffer size."
|
||||
<< " Buffer size: " << buffer_size
|
||||
<< ", Buffer size needed: " << expected_buffer_size
|
||||
<< ", Image width: " << width;
|
||||
|
||||
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
|
||||
mediapipe::ImageFormat::SRGB, width, height,
|
||||
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height,
|
||||
image_frame->MutablePixelData(),
|
||||
image_frame->WidthStep());
|
||||
return image_frame;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)(
|
||||
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
|
||||
jint height) {
|
||||
const uint8_t* rgba_data =
|
||||
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
|
||||
mediapipe::ImageFormat::SRGB, width, height,
|
||||
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
if (buffer_size != width * height * 4) {
|
||||
LOG(ERROR) << "Please check the input buffer size.";
|
||||
LOG(ERROR) << "Buffer size: " << buffer_size
|
||||
<< ", Buffer size needed: " << width * height * 4
|
||||
<< ", Image width: " << width;
|
||||
return 0L;
|
||||
}
|
||||
mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height,
|
||||
image_frame->MutablePixelData(),
|
||||
image_frame->WidthStep());
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
|
||||
auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height);
|
||||
if (ThrowIfError(env, image_frame_or.status())) return 0L;
|
||||
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
|
||||
return CreatePacketWithContext(context, packet);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
|
||||
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
|
||||
jint height) {
|
||||
auto image_frame = CreateImageFrameFromByteBuffer(
|
||||
auto image_frame_or = CreateImageFrameFromByteBuffer(
|
||||
env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8);
|
||||
if (nullptr == image_frame) return 0L;
|
||||
if (ThrowIfError(env, image_frame_or.status())) return 0L;
|
||||
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
|
||||
return CreatePacketWithContext(context, packet);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
|
||||
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
|
||||
jint height) {
|
||||
const void* data = env->GetDirectBufferAddress(byte_buffer);
|
||||
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
|
||||
mediapipe::ImageFormat::VEC32F1, width, height,
|
||||
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
if (buffer_size != image_frame->PixelDataSize()) {
|
||||
LOG(ERROR) << "Please check the input buffer size.";
|
||||
LOG(ERROR) << "Buffer size: " << buffer_size
|
||||
<< ", Buffer size needed: " << image_frame->PixelDataSize()
|
||||
<< ", Image width: " << width;
|
||||
return 0L;
|
||||
}
|
||||
std::memcpy(image_frame->MutablePixelData(), data,
|
||||
image_frame->PixelDataSize());
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
|
||||
// TODO: merge this case with CreateImageFrameFromByteBuffer.
|
||||
auto image_frame_or =
|
||||
[&]() -> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> {
|
||||
const void* data = env->GetDirectBufferAddress(byte_buffer);
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
if (data == nullptr || buffer_size < 0) {
|
||||
return absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access");
|
||||
}
|
||||
|
||||
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
|
||||
mediapipe::ImageFormat::VEC32F1, width, height,
|
||||
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
|
||||
RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize())
|
||||
<< "Please check the input buffer size."
|
||||
<< " Buffer size: " << buffer_size
|
||||
<< ", Buffer size needed: " << image_frame->PixelDataSize()
|
||||
<< ", Image width: " << width;
|
||||
std::memcpy(image_frame->MutablePixelData(), data,
|
||||
image_frame->PixelDataSize());
|
||||
return image_frame;
|
||||
}();
|
||||
if (ThrowIfError(env, image_frame_or.status())) return 0L;
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
|
||||
return CreatePacketWithContext(context, packet);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)(
|
||||
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
|
||||
jint height) {
|
||||
auto image_frame = CreateImageFrameFromByteBuffer(
|
||||
auto image_frame_or = CreateImageFrameFromByteBuffer(
|
||||
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA);
|
||||
if (nullptr == image_frame) return 0L;
|
||||
if (ThrowIfError(env, image_frame_or.status())) return 0L;
|
||||
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
|
||||
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
|
||||
return CreatePacketWithContext(context, packet);
|
||||
}
|
||||
|
||||
|
@ -291,6 +321,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)(
|
|||
jint num_samples) {
|
||||
const uint8_t* audio_sample =
|
||||
reinterpret_cast<uint8_t*>(env->GetDirectBufferAddress(data));
|
||||
if (!audio_sample) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"Cannot get direct access to the input buffer. It "
|
||||
"should be created using allocateDirect."));
|
||||
return 0L;
|
||||
}
|
||||
mediapipe::Packet packet =
|
||||
createAudioPacket(audio_sample, num_samples, num_channels);
|
||||
return CreatePacketWithContext(context, packet);
|
||||
|
@ -360,8 +396,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
|
|||
JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols,
|
||||
jfloatArray data) {
|
||||
if (env->GetArrayLength(data) != rows * cols) {
|
||||
LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = "
|
||||
<< rows * cols;
|
||||
ThrowIfError(
|
||||
env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Please check the matrix data size, has to be rows * cols = ",
|
||||
rows * cols)));
|
||||
return 0L;
|
||||
}
|
||||
std::unique_ptr<mediapipe::Matrix> matrix(new mediapipe::Matrix(rows, cols));
|
||||
|
@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
|
|||
format = mediapipe::ImageFormat::GRAY8;
|
||||
break;
|
||||
default:
|
||||
LOG(ERROR) << "Channels must be either 1, 3, or 4.";
|
||||
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Channels must be either 1, 3, or 4, but are ",
|
||||
num_channels)));
|
||||
return 0L;
|
||||
}
|
||||
|
||||
auto image_frame =
|
||||
auto image_frame_or =
|
||||
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format);
|
||||
if (nullptr == image_frame) return 0L;
|
||||
if (ThrowIfError(env, image_frame_or.status())) return 0L;
|
||||
|
||||
mediapipe::Packet packet =
|
||||
mediapipe::MakePacket<mediapipe::Image>(std::move(image_frame));
|
||||
mediapipe::MakePacket<mediapipe::Image>(*std::move(image_frame_or));
|
||||
return CreatePacketWithContext(context, packet);
|
||||
}
|
||||
|
||||
|
@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)(
|
|||
jbyte* data_ref = env->GetByteArrayElements(data, nullptr);
|
||||
auto options = absl::make_unique<mediapipe::CalculatorOptions>();
|
||||
if (!options->ParseFromArray(data_ref, count)) {
|
||||
LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed.";
|
||||
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Parsing binary-encoded CalculatorOptions failed.")));
|
||||
return 0L;
|
||||
}
|
||||
mediapipe::Packet packet = mediapipe::Adopt(options.release());
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
|
@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
|
|||
: GetFromNativeHandle<mediapipe::ImageFrame>(packet);
|
||||
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
if (buffer_data == nullptr || buffer_size < 0) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access"));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() * image.NumberOfChannels();
|
||||
if (buffer_size != expected_buffer_size) {
|
||||
LOG(ERROR) << "Expected buffer size " << expected_buffer_size
|
||||
<< " got: " << buffer_size << ", width " << image.Width()
|
||||
<< ", height " << image.Height() << ", channels "
|
||||
<< image.NumberOfChannels();
|
||||
ThrowIfError(
|
||||
env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected buffer size ", expected_buffer_size,
|
||||
" got: ", buffer_size, ", width ", image.Width(), ", height ",
|
||||
image.Height(), ", channels ", image.NumberOfChannels())));
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (image.ByteDepth()) {
|
||||
case 1: {
|
||||
uint8* data =
|
||||
static_cast<uint8*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
uint8* data = static_cast<uint8*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
uint16* data =
|
||||
static_cast<uint16*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
uint16* data = static_cast<uint16*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
float* data =
|
||||
static_cast<float*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
float* data = static_cast<float*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
|
@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
|
|||
uint8_t* rgba_data =
|
||||
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
if (rgba_data == nullptr || buffer_size < 0) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access"));
|
||||
return false;
|
||||
}
|
||||
if (buffer_size != image.Width() * image.Height() * 4) {
|
||||
LOG(ERROR) << "Buffer size has to be width*height*4\n"
|
||||
<< "Image width: " << image.Width()
|
||||
<< ", Image height: " << image.Height()
|
||||
<< ", Buffer size: " << buffer_size << ", Buffer size needed: "
|
||||
<< image.Width() * image.Height() * 4;
|
||||
ThrowIfError(env,
|
||||
absl::InvalidArgumentError(absl::StrCat(
|
||||
"Buffer size has to be width*height*4\n"
|
||||
"Image width: ",
|
||||
image.Width(), ", Image height: ", image.Height(),
|
||||
", Buffer size: ", buffer_size, ", Buffer size needed: ",
|
||||
image.Width() * image.Height() * 4)));
|
||||
return false;
|
||||
}
|
||||
mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(),
|
||||
|
|
Loading…
Reference in New Issue
Block a user