Check that Java buffer supports direct access before using it

If the buffer is not created with allocateDirect, JNI APIs will return a data pointer of nullptr and a capacity of -1. This can cause a crash when we access it.

Also clean up the code to raise exceptions instead of just logging errors and returning nullptr.

PiperOrigin-RevId: 489751312
This commit is contained in:
Camillo Lugaresi 2022-11-19 21:03:29 -08:00 committed by Copybara-Service
parent 977ee4272e
commit a33cb1e05e
2 changed files with 133 additions and 80 deletions

View File

@ -17,6 +17,8 @@
#include <cstring>
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/camera_intrinsics.h"
#include "mediapipe/framework/formats/image.h"
@ -107,17 +109,18 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer(
// Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java
// ByteBuffer.
std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
JNIEnv* env, jobject byte_buffer, jint width, jint height,
mediapipe::ImageFormat::Format format) {
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
jint height,
mediapipe::ImageFormat::Format format) {
switch (format) {
case mediapipe::ImageFormat::SRGBA:
case mediapipe::ImageFormat::SRGB:
case mediapipe::ImageFormat::GRAY8:
break;
default:
LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8.";
return nullptr;
return absl::InvalidArgumentError(
"Format must be either SRGBA, SRGB, or GRAY8.");
}
auto image_frame = std::make_unique<mediapipe::ImageFrame>(
@ -125,25 +128,30 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"Cannot get direct access to the input buffer. It should be created "
"using allocateDirect.");
}
const int num_channels = image_frame->NumberOfChannels();
const int expected_buffer_size =
num_channels == 1 ? width * height : image_frame->PixelDataSize();
if (buffer_size != expected_buffer_size) {
if (num_channels != 1)
LOG(ERROR) << "The input image buffer should have 4 bytes alignment.";
LOG(ERROR) << "Please check the input buffer size.";
LOG(ERROR) << "Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
return nullptr;
}
RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< (num_channels != 1
? "The input image buffer should have 4 bytes alignment. "
: "")
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
// Copy buffer data to image frame's pixel_data_.
if (num_channels == 1) {
const int width_step = image_frame->WidthStep();
const char* src_row =
reinterpret_cast<const char*>(env->GetDirectBufferAddress(byte_buffer));
const char* src_row = reinterpret_cast<const char*>(buffer_data);
char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData());
for (int i = height; i > 0; --i) {
std::memcpy(dst_row, src_row, width);
@ -152,7 +160,6 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
}
} else {
// 3 and 4 channels.
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
std::memcpy(image_frame->MutablePixelData(), buffer_data,
image_frame->PixelDataSize());
}
@ -176,77 +183,100 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
auto image_frame = CreateImageFrameFromByteBuffer(
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB);
if (nullptr == image_frame) return 0L;
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
}
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba(
JNIEnv* env, jobject byte_buffer, jint width, jint height) {
const uint8_t* rgba_data =
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (rgba_data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"Cannot get direct access to the input buffer. It should be created "
"using allocateDirect.");
}
const int expected_buffer_size = width * height * 4;
RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height,
image_frame->MutablePixelData(),
image_frame->WidthStep());
return image_frame;
}
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
const uint8_t* rgba_data =
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (buffer_size != width * height * 4) {
LOG(ERROR) << "Please check the input buffer size.";
LOG(ERROR) << "Buffer size: " << buffer_size
<< ", Buffer size needed: " << width * height * 4
<< ", Image width: " << width;
return 0L;
}
mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height,
image_frame->MutablePixelData(),
image_frame->WidthStep());
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height);
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
}
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
auto image_frame = CreateImageFrameFromByteBuffer(
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8);
if (nullptr == image_frame) return 0L;
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
}
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
const void* data = env->GetDirectBufferAddress(byte_buffer);
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::VEC32F1, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (buffer_size != image_frame->PixelDataSize()) {
LOG(ERROR) << "Please check the input buffer size.";
LOG(ERROR) << "Buffer size: " << buffer_size
<< ", Buffer size needed: " << image_frame->PixelDataSize()
<< ", Image width: " << width;
return 0L;
}
std::memcpy(image_frame->MutablePixelData(), data,
image_frame->PixelDataSize());
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
// TODO: merge this case with CreateImageFrameFromByteBuffer.
auto image_frame_or =
[&]() -> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> {
const void* data = env->GetDirectBufferAddress(byte_buffer);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"input buffer does not support direct access");
}
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::VEC32F1, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize())
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << image_frame->PixelDataSize()
<< ", Image width: " << width;
std::memcpy(image_frame->MutablePixelData(), data,
image_frame->PixelDataSize());
return image_frame;
}();
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
}
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
auto image_frame = CreateImageFrameFromByteBuffer(
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA);
if (nullptr == image_frame) return 0L;
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
}
@ -291,6 +321,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)(
jint num_samples) {
const uint8_t* audio_sample =
reinterpret_cast<uint8_t*>(env->GetDirectBufferAddress(data));
if (!audio_sample) {
ThrowIfError(env, absl::InvalidArgumentError(
"Cannot get direct access to the input buffer. It "
"should be created using allocateDirect."));
return 0L;
}
mediapipe::Packet packet =
createAudioPacket(audio_sample, num_samples, num_channels);
return CreatePacketWithContext(context, packet);
@ -360,8 +396,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols,
jfloatArray data) {
if (env->GetArrayLength(data) != rows * cols) {
LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = "
<< rows * cols;
ThrowIfError(
env, absl::InvalidArgumentError(absl::StrCat(
"Please check the matrix data size, has to be rows * cols = ",
rows * cols)));
return 0L;
}
std::unique_ptr<mediapipe::Matrix> matrix(new mediapipe::Matrix(rows, cols));
@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
format = mediapipe::ImageFormat::GRAY8;
break;
default:
LOG(ERROR) << "Channels must be either 1, 3, or 4.";
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
"Channels must be either 1, 3, or 4, but are ",
num_channels)));
return 0L;
}
auto image_frame =
auto image_frame_or =
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format);
if (nullptr == image_frame) return 0L;
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet =
mediapipe::MakePacket<mediapipe::Image>(std::move(image_frame));
mediapipe::MakePacket<mediapipe::Image>(*std::move(image_frame_or));
return CreatePacketWithContext(context, packet);
}
@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)(
jbyte* data_ref = env->GetByteArrayElements(data, nullptr);
auto options = absl::make_unique<mediapipe::CalculatorOptions>();
if (!options->ParseFromArray(data_ref, count)) {
LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed.";
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
"Parsing binary-encoded CalculatorOptions failed.")));
return 0L;
}
mediapipe::Packet packet = mediapipe::Adopt(options.release());

View File

@ -14,6 +14,7 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
: GetFromNativeHandle<mediapipe::ImageFrame>(packet);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
ThrowIfError(env, absl::InvalidArgumentError(
"input buffer does not support direct access"));
return false;
}
// Assume byte buffer stores pixel data contiguously.
const int expected_buffer_size = image.Width() * image.Height() *
image.ByteDepth() * image.NumberOfChannels();
if (buffer_size != expected_buffer_size) {
LOG(ERROR) << "Expected buffer size " << expected_buffer_size
<< " got: " << buffer_size << ", width " << image.Width()
<< ", height " << image.Height() << ", channels "
<< image.NumberOfChannels();
ThrowIfError(
env, absl::InvalidArgumentError(absl::StrCat(
"Expected buffer size ", expected_buffer_size,
" got: ", buffer_size, ", width ", image.Width(), ", height ",
image.Height(), ", channels ", image.NumberOfChannels())));
return false;
}
switch (image.ByteDepth()) {
case 1: {
uint8* data =
static_cast<uint8*>(env->GetDirectBufferAddress(byte_buffer));
uint8* data = static_cast<uint8*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
case 2: {
uint16* data =
static_cast<uint16*>(env->GetDirectBufferAddress(byte_buffer));
uint16* data = static_cast<uint16*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
case 4: {
float* data =
static_cast<float*>(env->GetDirectBufferAddress(byte_buffer));
float* data = static_cast<float*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
uint8_t* rgba_data =
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (rgba_data == nullptr || buffer_size < 0) {
ThrowIfError(env, absl::InvalidArgumentError(
"input buffer does not support direct access"));
return false;
}
if (buffer_size != image.Width() * image.Height() * 4) {
LOG(ERROR) << "Buffer size has to be width*height*4\n"
<< "Image width: " << image.Width()
<< ", Image height: " << image.Height()
<< ", Buffer size: " << buffer_size << ", Buffer size needed: "
<< image.Width() * image.Height() * 4;
ThrowIfError(env,
absl::InvalidArgumentError(absl::StrCat(
"Buffer size has to be width*height*4\n"
"Image width: ",
image.Width(), ", Image height: ", image.Height(),
", Buffer size: ", buffer_size, ", Buffer size needed: ",
image.Width() * image.Height() * 4)));
return false;
}
mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(),