Check that Java buffer supports direct access before using it

If the buffer is not created with allocateDirect, JNI APIs will return a data pointer of nullptr and a capacity of -1. This can cause a crash when we access it.

Also clean up the code to raise exceptions instead of just logging errors and returning nullptr.

PiperOrigin-RevId: 489751312
This commit is contained in:
Camillo Lugaresi 2022-11-19 21:03:29 -08:00 committed by Copybara-Service
parent 977ee4272e
commit a33cb1e05e
2 changed files with 133 additions and 80 deletions

View File

@ -17,6 +17,8 @@
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/camera_intrinsics.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
@ -107,17 +109,18 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer(
// Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java // Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java
// ByteBuffer. // ByteBuffer.
std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer( absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
JNIEnv* env, jobject byte_buffer, jint width, jint height, CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
mediapipe::ImageFormat::Format format) { jint height,
mediapipe::ImageFormat::Format format) {
switch (format) { switch (format) {
case mediapipe::ImageFormat::SRGBA: case mediapipe::ImageFormat::SRGBA:
case mediapipe::ImageFormat::SRGB: case mediapipe::ImageFormat::SRGB:
case mediapipe::ImageFormat::GRAY8: case mediapipe::ImageFormat::GRAY8:
break; break;
default: default:
LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8."; return absl::InvalidArgumentError(
return nullptr; "Format must be either SRGBA, SRGB, or GRAY8.");
} }
auto image_frame = std::make_unique<mediapipe::ImageFrame>( auto image_frame = std::make_unique<mediapipe::ImageFrame>(
@ -125,25 +128,30 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"Cannot get direct access to the input buffer. It should be created "
"using allocateDirect.");
}
const int num_channels = image_frame->NumberOfChannels(); const int num_channels = image_frame->NumberOfChannels();
const int expected_buffer_size = const int expected_buffer_size =
num_channels == 1 ? width * height : image_frame->PixelDataSize(); num_channels == 1 ? width * height : image_frame->PixelDataSize();
if (buffer_size != expected_buffer_size) { RET_CHECK_EQ(buffer_size, expected_buffer_size)
if (num_channels != 1) << (num_channels != 1
LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; ? "The input image buffer should have 4 bytes alignment. "
LOG(ERROR) << "Please check the input buffer size."; : "")
LOG(ERROR) << "Buffer size: " << buffer_size << "Please check the input buffer size."
<< ", Buffer size needed: " << expected_buffer_size << " Buffer size: " << buffer_size
<< ", Image width: " << width; << ", Buffer size needed: " << expected_buffer_size
return nullptr; << ", Image width: " << width;
}
// Copy buffer data to image frame's pixel_data_. // Copy buffer data to image frame's pixel_data_.
if (num_channels == 1) { if (num_channels == 1) {
const int width_step = image_frame->WidthStep(); const int width_step = image_frame->WidthStep();
const char* src_row = const char* src_row = reinterpret_cast<const char*>(buffer_data);
reinterpret_cast<const char*>(env->GetDirectBufferAddress(byte_buffer));
char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData()); char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData());
for (int i = height; i > 0; --i) { for (int i = height; i > 0; --i) {
std::memcpy(dst_row, src_row, width); std::memcpy(dst_row, src_row, width);
@ -152,7 +160,6 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer(
} }
} else { } else {
// 3 and 4 channels. // 3 and 4 channels.
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
std::memcpy(image_frame->MutablePixelData(), buffer_data, std::memcpy(image_frame->MutablePixelData(), buffer_data,
image_frame->PixelDataSize()); image_frame->PixelDataSize());
} }
@ -176,77 +183,100 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
auto image_frame = CreateImageFrameFromByteBuffer( auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB);
if (nullptr == image_frame) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba(
JNIEnv* env, jobject byte_buffer, jint width, jint height) {
const uint8_t* rgba_data =
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (rgba_data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"Cannot get direct access to the input buffer. It should be created "
"using allocateDirect.");
}
const int expected_buffer_size = width * height * 4;
RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height,
image_frame->MutablePixelData(),
image_frame->WidthStep());
return image_frame;
}
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
const uint8_t* rgba_data = auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height);
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); if (ThrowIfError(env, image_frame_or.status())) return 0L;
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, width, height, mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (buffer_size != width * height * 4) {
LOG(ERROR) << "Please check the input buffer size.";
LOG(ERROR) << "Buffer size: " << buffer_size
<< ", Buffer size needed: " << width * height * 4
<< ", Image width: " << width;
return 0L;
}
mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height,
image_frame->MutablePixelData(),
image_frame->WidthStep());
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
auto image_frame = CreateImageFrameFromByteBuffer( auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8);
if (nullptr == image_frame) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
const void* data = env->GetDirectBufferAddress(byte_buffer); // TODO: merge this case with CreateImageFrameFromByteBuffer.
auto image_frame = absl::make_unique<mediapipe::ImageFrame>( auto image_frame_or =
mediapipe::ImageFormat::VEC32F1, width, height, [&]() -> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> {
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); const void* data = env->GetDirectBufferAddress(byte_buffer);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (buffer_size != image_frame->PixelDataSize()) { if (data == nullptr || buffer_size < 0) {
LOG(ERROR) << "Please check the input buffer size."; return absl::InvalidArgumentError(
LOG(ERROR) << "Buffer size: " << buffer_size "input buffer does not support direct access");
<< ", Buffer size needed: " << image_frame->PixelDataSize() }
<< ", Image width: " << width;
return 0L; auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
} mediapipe::ImageFormat::VEC32F1, width, height,
std::memcpy(image_frame->MutablePixelData(), data, mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
image_frame->PixelDataSize()); RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize())
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); << "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << image_frame->PixelDataSize()
<< ", Image width: " << width;
std::memcpy(image_frame->MutablePixelData(), data,
image_frame->PixelDataSize());
return image_frame;
}();
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
auto image_frame = CreateImageFrameFromByteBuffer( auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA);
if (nullptr == image_frame) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
@ -291,6 +321,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateAudioPacketDirect)(
jint num_samples) { jint num_samples) {
const uint8_t* audio_sample = const uint8_t* audio_sample =
reinterpret_cast<uint8_t*>(env->GetDirectBufferAddress(data)); reinterpret_cast<uint8_t*>(env->GetDirectBufferAddress(data));
if (!audio_sample) {
ThrowIfError(env, absl::InvalidArgumentError(
"Cannot get direct access to the input buffer. It "
"should be created using allocateDirect."));
return 0L;
}
mediapipe::Packet packet = mediapipe::Packet packet =
createAudioPacket(audio_sample, num_samples, num_channels); createAudioPacket(audio_sample, num_samples, num_channels);
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
@ -360,8 +396,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols,
jfloatArray data) { jfloatArray data) {
if (env->GetArrayLength(data) != rows * cols) { if (env->GetArrayLength(data) != rows * cols) {
LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " ThrowIfError(
<< rows * cols; env, absl::InvalidArgumentError(absl::StrCat(
"Please check the matrix data size, has to be rows * cols = ",
rows * cols)));
return 0L; return 0L;
} }
std::unique_ptr<mediapipe::Matrix> matrix(new mediapipe::Matrix(rows, cols)); std::unique_ptr<mediapipe::Matrix> matrix(new mediapipe::Matrix(rows, cols));
@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
format = mediapipe::ImageFormat::GRAY8; format = mediapipe::ImageFormat::GRAY8;
break; break;
default: default:
LOG(ERROR) << "Channels must be either 1, 3, or 4."; ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
"Channels must be either 1, 3, or 4, but are ",
num_channels)));
return 0L; return 0L;
} }
auto image_frame = auto image_frame_or =
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format);
if (nullptr == image_frame) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Packet packet =
mediapipe::MakePacket<mediapipe::Image>(std::move(image_frame)); mediapipe::MakePacket<mediapipe::Image>(*std::move(image_frame_or));
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)(
jbyte* data_ref = env->GetByteArrayElements(data, nullptr); jbyte* data_ref = env->GetByteArrayElements(data, nullptr);
auto options = absl::make_unique<mediapipe::CalculatorOptions>(); auto options = absl::make_unique<mediapipe::CalculatorOptions>();
if (!options->ParseFromArray(data_ref, count)) { if (!options->ParseFromArray(data_ref, count)) {
LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed."; ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
"Parsing binary-encoded CalculatorOptions failed.")));
return 0L; return 0L;
} }
mediapipe::Packet packet = mediapipe::Adopt(options.release()); mediapipe::Packet packet = mediapipe::Adopt(options.release());

View File

@ -14,6 +14,7 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
: GetFromNativeHandle<mediapipe::ImageFrame>(packet); : GetFromNativeHandle<mediapipe::ImageFrame>(packet);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
ThrowIfError(env, absl::InvalidArgumentError(
"input buffer does not support direct access"));
return false;
}
// Assume byte buffer stores pixel data contiguously. // Assume byte buffer stores pixel data contiguously.
const int expected_buffer_size = image.Width() * image.Height() * const int expected_buffer_size = image.Width() * image.Height() *
image.ByteDepth() * image.NumberOfChannels(); image.ByteDepth() * image.NumberOfChannels();
if (buffer_size != expected_buffer_size) { if (buffer_size != expected_buffer_size) {
LOG(ERROR) << "Expected buffer size " << expected_buffer_size ThrowIfError(
<< " got: " << buffer_size << ", width " << image.Width() env, absl::InvalidArgumentError(absl::StrCat(
<< ", height " << image.Height() << ", channels " "Expected buffer size ", expected_buffer_size,
<< image.NumberOfChannels(); " got: ", buffer_size, ", width ", image.Width(), ", height ",
image.Height(), ", channels ", image.NumberOfChannels())));
return false; return false;
} }
switch (image.ByteDepth()) { switch (image.ByteDepth()) {
case 1: { case 1: {
uint8* data = uint8* data = static_cast<uint8*>(buffer_data);
static_cast<uint8*>(env->GetDirectBufferAddress(byte_buffer));
image.CopyToBuffer(data, expected_buffer_size); image.CopyToBuffer(data, expected_buffer_size);
break; break;
} }
case 2: { case 2: {
uint16* data = uint16* data = static_cast<uint16*>(buffer_data);
static_cast<uint16*>(env->GetDirectBufferAddress(byte_buffer));
image.CopyToBuffer(data, expected_buffer_size); image.CopyToBuffer(data, expected_buffer_size);
break; break;
} }
case 4: { case 4: {
float* data = float* data = static_cast<float*>(buffer_data);
static_cast<float*>(env->GetDirectBufferAddress(byte_buffer));
image.CopyToBuffer(data, expected_buffer_size); image.CopyToBuffer(data, expected_buffer_size);
break; break;
} }
@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
uint8_t* rgba_data = uint8_t* rgba_data =
static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (rgba_data == nullptr || buffer_size < 0) {
ThrowIfError(env, absl::InvalidArgumentError(
"input buffer does not support direct access"));
return false;
}
if (buffer_size != image.Width() * image.Height() * 4) { if (buffer_size != image.Width() * image.Height() * 4) {
LOG(ERROR) << "Buffer size has to be width*height*4\n" ThrowIfError(env,
<< "Image width: " << image.Width() absl::InvalidArgumentError(absl::StrCat(
<< ", Image height: " << image.Height() "Buffer size has to be width*height*4\n"
<< ", Buffer size: " << buffer_size << ", Buffer size needed: " "Image width: ",
<< image.Width() * image.Height() * 4; image.Width(), ", Image height: ", image.Height(),
", Buffer size: ", buffer_size, ", Buffer size needed: ",
image.Width() * image.Height() * 4)));
return false; return false;
} }
mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(), mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(),