Simplify image creation in PacketCreator

Use more existing functions, remove redundant code, remove direct use of RuntimeException.

PiperOrigin-RevId: 489868983
This commit is contained in:
Camillo Lugaresi 2022-11-20 19:30:05 -08:00 committed by Copybara-Service
parent 6cf464636b
commit 3ac7f6a216
3 changed files with 64 additions and 95 deletions

View File

@ -55,7 +55,11 @@ public class PacketCreator {
public Packet createRgbImage(ByteBuffer buffer, int width, int height) { public Packet createRgbImage(ByteBuffer buffer, int width, int height) {
int widthStep = (((width * 3) + 3) / 4) * 4; int widthStep = (((width * 3) + 3) / 4) * 4;
if (widthStep * height != buffer.capacity()) { if (widthStep * height != buffer.capacity()) {
throw new RuntimeException("The size of the buffer should be: " + widthStep * height); throw new IllegalArgumentException(
"The size of the buffer should be: "
+ widthStep * height
+ " but is "
+ buffer.capacity());
} }
return Packet.create( return Packet.create(
nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height)); nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -123,7 +127,11 @@ public class PacketCreator {
*/ */
public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) { public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) {
if (width * height * 4 != buffer.capacity()) { if (width * height * 4 != buffer.capacity()) {
throw new RuntimeException("The size of the buffer should be: " + width * height * 4); throw new IllegalArgumentException(
"The size of the buffer should be: "
+ width * height * 4
+ " but is "
+ buffer.capacity());
} }
return Packet.create( return Packet.create(
nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height)); nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -136,7 +144,7 @@ public class PacketCreator {
*/ */
public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) { public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) {
if (width * height != buffer.capacity()) { if (width * height != buffer.capacity()) {
throw new RuntimeException( throw new IllegalArgumentException(
"The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); "The size of the buffer should be: " + width * height + " but is " + buffer.capacity());
} }
return Packet.create( return Packet.create(
@ -150,7 +158,11 @@ public class PacketCreator {
*/ */
public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) { public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) {
if (buffer.capacity() != width * height * 4) { if (buffer.capacity() != width * height * 4) {
throw new RuntimeException("buffer doesn't have the correct size."); throw new IllegalArgumentException(
"The size of the buffer should be: "
+ width * height * 4
+ " but is "
+ buffer.capacity());
} }
return Packet.create( return Packet.create(
nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -163,7 +175,11 @@ public class PacketCreator {
*/ */
public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) { public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) {
if (buffer.capacity() != width * height * 4) { if (buffer.capacity() != width * height * 4) {
throw new RuntimeException("buffer doesn't have the correct size."); throw new IllegalArgumentException(
"The size of the buffer should be: "
+ width * height * 4
+ " but is "
+ buffer.capacity());
} }
return Packet.create( return Packet.create(
nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -354,25 +370,24 @@ public class PacketCreator {
* <p>For 3 and 4 channel images, the pixel rows should have 4-byte alignment. * <p>For 3 and 4 channel images, the pixel rows should have 4-byte alignment.
*/ */
public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) { public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) {
int widthStep;
if (numChannels == 4) { if (numChannels == 4) {
if (buffer.capacity() != width * height * 4) { widthStep = width * 4;
throw new RuntimeException("buffer doesn't have the correct size.");
}
} else if (numChannels == 3) { } else if (numChannels == 3) {
int widthStep = (((width * 3) + 3) / 4) * 4; widthStep = (((width * 3) + 3) / 4) * 4;
if (widthStep * height != buffer.capacity()) {
throw new RuntimeException("The size of the buffer should be: " + widthStep * height);
}
} else if (numChannels == 1) { } else if (numChannels == 1) {
if (width * height != buffer.capacity()) { widthStep = width;
throw new RuntimeException(
"The size of the buffer should be: " + width * height + " but is " + buffer.capacity());
}
} else { } else {
throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels); throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels);
}
int expectedSize = widthStep * height;
if (buffer.capacity() != expectedSize) {
throw new IllegalArgumentException(
"The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity());
} }
return Packet.create( return Packet.create(
nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels)); nativeCreateCpuImage(
mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels));
} }
/** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */ /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */
@ -430,7 +445,7 @@ public class PacketCreator {
long context, int name, int width, int height, TextureReleaseCallback releaseCallback); long context, int name, int width, int height, TextureReleaseCallback releaseCallback);
private native long nativeCreateCpuImage( private native long nativeCreateCpuImage(
long context, ByteBuffer buffer, int width, int height, int numChannels); long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels);
private native long nativeCreateInt32Array(long context, int[] data); private native long nativeCreateInt32Array(long context, int[] data);

View File

@ -111,22 +111,8 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer(
// ByteBuffer. // ByteBuffer.
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
jint height, jint height, jint width_step,
mediapipe::ImageFormat::Format format) { mediapipe::ImageFormat::Format format) {
switch (format) {
case mediapipe::ImageFormat::SRGBA:
case mediapipe::ImageFormat::SRGB:
case mediapipe::ImageFormat::GRAY8:
break;
default:
return absl::InvalidArgumentError(
"Format must be either SRGBA, SRGB, or GRAY8.");
}
auto image_frame = std::make_unique<mediapipe::ImageFrame>(
format, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) { if (buffer_data == nullptr || buffer_size < 0) {
@ -135,34 +121,19 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
"using allocateDirect."); "using allocateDirect.");
} }
const int num_channels = image_frame->NumberOfChannels(); const int expected_buffer_size = height * width_step;
const int expected_buffer_size =
num_channels == 1 ? width * height : image_frame->PixelDataSize();
RET_CHECK_EQ(buffer_size, expected_buffer_size) RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< (num_channels != 1 << "Input buffer size should be " << expected_buffer_size
? "The input image buffer should have 4 bytes alignment. " << " but is: " << buffer_size;
: "")
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
// Copy buffer data to image frame's pixel_data_. auto image_frame = std::make_unique<mediapipe::ImageFrame>();
if (num_channels == 1) { // TODO: we could retain the buffer with a special deleter and use
const int width_step = image_frame->WidthStep(); // the data directly without a copy. May need a new Java API since existing
const char* src_row = reinterpret_cast<const char*>(buffer_data); // code might expect to be able to overwrite the buffer after creating an
char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData()); // ImageFrame from it.
for (int i = height; i > 0; --i) { image_frame->CopyPixelData(
std::memcpy(dst_row, src_row, width); format, width, height, width_step, static_cast<const uint8*>(buffer_data),
src_row += width; mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
dst_row += width_step;
}
} else {
// 3 and 4 channels.
std::memcpy(image_frame->MutablePixelData(), buffer_data,
image_frame->PixelDataSize());
}
return image_frame; return image_frame;
} }
@ -183,8 +154,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
auto image_frame_or = CreateImageFrameFromByteBuffer( // We require 4-byte alignment. See Java method.
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); constexpr int kAlignment = 4;
int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1;
auto image_frame_or =
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height,
width_step, mediapipe::ImageFormat::SRGB);
if (ThrowIfError(env, image_frame_or.status())) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
@ -204,10 +179,8 @@ absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba(
const int expected_buffer_size = width * height * 4; const int expected_buffer_size = width * height * 4;
RET_CHECK_EQ(buffer_size, expected_buffer_size) RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< "Please check the input buffer size." << "Input buffer size should be " << expected_buffer_size
<< " Buffer size: " << buffer_size << " but is: " << buffer_size;
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
auto image_frame = absl::make_unique<mediapipe::ImageFrame>( auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, width, height, mediapipe::ImageFormat::SRGB, width, height,
@ -232,7 +205,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
auto image_frame_or = CreateImageFrameFromByteBuffer( auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8);
if (ThrowIfError(env, image_frame_or.status())) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
@ -242,28 +215,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
// TODO: merge this case with CreateImageFrameFromByteBuffer.
auto image_frame_or = auto image_frame_or =
[&]() -> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> { CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4,
const void* data = env->GetDirectBufferAddress(byte_buffer); mediapipe::ImageFormat::VEC32F1);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"input buffer does not support direct access");
}
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::VEC32F1, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize())
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << image_frame->PixelDataSize()
<< ", Image width: " << width;
std::memcpy(image_frame->MutablePixelData(), data,
image_frame->PixelDataSize());
return image_frame;
}();
if (ThrowIfError(env, image_frame_or.status())) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
@ -272,10 +226,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) { jint height) {
auto image_frame_or = CreateImageFrameFromByteBuffer( auto image_frame_or =
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4,
mediapipe::ImageFormat::SRGBA);
if (ThrowIfError(env, image_frame_or.status())) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet); return CreatePacketWithContext(context, packet);
} }
@ -417,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height, jint num_channels) { jint height, jint width_step, jint num_channels) {
mediapipe::ImageFormat::Format format; mediapipe::ImageFormat::Format format;
switch (num_channels) { switch (num_channels) {
case 4: case 4:
@ -436,8 +390,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
return 0L; return 0L;
} }
auto image_frame_or = auto image_frame_or = CreateImageFrameFromByteBuffer(
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); env, byte_buffer, width, height, width_step, format);
if (ThrowIfError(env, image_frame_or.status())) return 0L; if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Packet packet =

View File

@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height, jint num_channels); jint height, jint width_step, jint num_channels);
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)(
JNIEnv* env, jobject thiz, jlong context, jint name, jint width, JNIEnv* env, jobject thiz, jlong context, jint name, jint width,