Simplify image creation in PacketCreator

Use more existing functions, remove redundant code, remove direct use of RuntimeException.

PiperOrigin-RevId: 489868983
This commit is contained in:
Camillo Lugaresi 2022-11-20 19:30:05 -08:00 committed by Copybara-Service
parent 6cf464636b
commit 3ac7f6a216
3 changed files with 64 additions and 95 deletions

View File

@ -55,7 +55,11 @@ public class PacketCreator {
public Packet createRgbImage(ByteBuffer buffer, int width, int height) {
int widthStep = (((width * 3) + 3) / 4) * 4;
if (widthStep * height != buffer.capacity()) {
throw new RuntimeException("The size of the buffer should be: " + widthStep * height);
throw new IllegalArgumentException(
"The size of the buffer should be: "
+ widthStep * height
+ " but is "
+ buffer.capacity());
}
return Packet.create(
nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -123,7 +127,11 @@ public class PacketCreator {
*/
public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) {
if (width * height * 4 != buffer.capacity()) {
throw new RuntimeException("The size of the buffer should be: " + width * height * 4);
throw new IllegalArgumentException(
"The size of the buffer should be: "
+ width * height * 4
+ " but is "
+ buffer.capacity());
}
return Packet.create(
nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -136,7 +144,7 @@ public class PacketCreator {
*/
public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) {
if (width * height != buffer.capacity()) {
throw new RuntimeException(
throw new IllegalArgumentException(
"The size of the buffer should be: " + width * height + " but is " + buffer.capacity());
}
return Packet.create(
@ -150,7 +158,11 @@ public class PacketCreator {
*/
public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) {
if (buffer.capacity() != width * height * 4) {
throw new RuntimeException("buffer doesn't have the correct size.");
throw new IllegalArgumentException(
"The size of the buffer should be: "
+ width * height * 4
+ " but is "
+ buffer.capacity());
}
return Packet.create(
nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -163,7 +175,11 @@ public class PacketCreator {
*/
public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) {
if (buffer.capacity() != width * height * 4) {
throw new RuntimeException("buffer doesn't have the correct size.");
throw new IllegalArgumentException(
"The size of the buffer should be: "
+ width * height * 4
+ " but is "
+ buffer.capacity());
}
return Packet.create(
nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height));
@ -354,25 +370,24 @@ public class PacketCreator {
* <p>For 3 and 4 channel images, the pixel rows should have 4-byte alignment.
*/
public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) {
int widthStep;
if (numChannels == 4) {
if (buffer.capacity() != width * height * 4) {
throw new RuntimeException("buffer doesn't have the correct size.");
}
widthStep = width * 4;
} else if (numChannels == 3) {
int widthStep = (((width * 3) + 3) / 4) * 4;
if (widthStep * height != buffer.capacity()) {
throw new RuntimeException("The size of the buffer should be: " + widthStep * height);
}
widthStep = (((width * 3) + 3) / 4) * 4;
} else if (numChannels == 1) {
if (width * height != buffer.capacity()) {
throw new RuntimeException(
"The size of the buffer should be: " + width * height + " but is " + buffer.capacity());
}
widthStep = width;
} else {
throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels);
throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels);
}
int expectedSize = widthStep * height;
if (buffer.capacity() != expectedSize) {
throw new IllegalArgumentException(
"The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity());
}
return Packet.create(
nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels));
nativeCreateCpuImage(
mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels));
}
/** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */
@ -430,7 +445,7 @@ public class PacketCreator {
long context, int name, int width, int height, TextureReleaseCallback releaseCallback);
private native long nativeCreateCpuImage(
long context, ByteBuffer buffer, int width, int height, int numChannels);
long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels);
private native long nativeCreateInt32Array(long context, int[] data);

View File

@ -111,22 +111,8 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer(
// ByteBuffer.
absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
jint height,
jint height, jint width_step,
mediapipe::ImageFormat::Format format) {
switch (format) {
case mediapipe::ImageFormat::SRGBA:
case mediapipe::ImageFormat::SRGB:
case mediapipe::ImageFormat::GRAY8:
break;
default:
return absl::InvalidArgumentError(
"Format must be either SRGBA, SRGB, or GRAY8.");
}
auto image_frame = std::make_unique<mediapipe::ImageFrame>(
format, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
@ -135,34 +121,19 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
"using allocateDirect.");
}
const int num_channels = image_frame->NumberOfChannels();
const int expected_buffer_size =
num_channels == 1 ? width * height : image_frame->PixelDataSize();
const int expected_buffer_size = height * width_step;
RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< (num_channels != 1
? "The input image buffer should have 4 bytes alignment. "
: "")
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
<< "Input buffer size should be " << expected_buffer_size
<< " but is: " << buffer_size;
// Copy buffer data to image frame's pixel_data_.
if (num_channels == 1) {
const int width_step = image_frame->WidthStep();
const char* src_row = reinterpret_cast<const char*>(buffer_data);
char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData());
for (int i = height; i > 0; --i) {
std::memcpy(dst_row, src_row, width);
src_row += width;
dst_row += width_step;
}
} else {
// 3 and 4 channels.
std::memcpy(image_frame->MutablePixelData(), buffer_data,
image_frame->PixelDataSize());
}
auto image_frame = std::make_unique<mediapipe::ImageFrame>();
// TODO: we could retain the buffer with a special deleter and use
// the data directly without a copy. May need a new Java API since existing
// code might expect to be able to overwrite the buffer after creating an
// ImageFrame from it.
image_frame->CopyPixelData(
format, width, height, width_step, static_cast<const uint8*>(buffer_data),
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
return image_frame;
}
@ -183,8 +154,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB);
// We require 4-byte alignment. See Java method.
constexpr int kAlignment = 4;
int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1;
auto image_frame_or =
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height,
width_step, mediapipe::ImageFormat::SRGB);
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
@ -204,10 +179,8 @@ absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba(
const int expected_buffer_size = width * height * 4;
RET_CHECK_EQ(buffer_size, expected_buffer_size)
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << expected_buffer_size
<< ", Image width: " << width;
<< "Input buffer size should be " << expected_buffer_size
<< " but is: " << buffer_size;
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, width, height,
@ -232,7 +205,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8);
env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8);
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
@ -242,28 +215,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
// TODO: merge this case with CreateImageFrameFromByteBuffer.
auto image_frame_or =
[&]() -> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> {
const void* data = env->GetDirectBufferAddress(byte_buffer);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
if (data == nullptr || buffer_size < 0) {
return absl::InvalidArgumentError(
"input buffer does not support direct access");
}
auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::VEC32F1, width, height,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize())
<< "Please check the input buffer size."
<< " Buffer size: " << buffer_size
<< ", Buffer size needed: " << image_frame->PixelDataSize()
<< ", Image width: " << width;
std::memcpy(image_frame->MutablePixelData(), data,
image_frame->PixelDataSize());
return image_frame;
}();
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4,
mediapipe::ImageFormat::VEC32F1);
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
@ -272,10 +226,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height) {
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA);
auto image_frame_or =
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4,
mediapipe::ImageFormat::SRGBA);
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
return CreatePacketWithContext(context, packet);
}
@ -417,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height, jint num_channels) {
jint height, jint width_step, jint num_channels) {
mediapipe::ImageFormat::Format format;
switch (num_channels) {
case 4:
@ -436,8 +390,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
return 0L;
}
auto image_frame_or =
CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format);
auto image_frame_or = CreateImageFrameFromByteBuffer(
env, byte_buffer, width, height, width_step, format);
if (ThrowIfError(env, image_frame_or.status())) return 0L;
mediapipe::Packet packet =

View File

@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)(
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
jint height, jint num_channels);
jint height, jint width_step, jint num_channels);
JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)(
JNIEnv* env, jobject thiz, jlong context, jint name, jint width,