diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java index 53cf480eb..86fb24d99 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java @@ -116,5 +116,43 @@ public final class AndroidPacketGetter { mutableBitmap.copyPixelsFromBuffer(buffer); } + /** + * Gets an {@code ARGB_8888} bitmap from an 8-bit alpha mediapipe image frame packet. + * + * @param packet mediapipe packet + * @return {@link Bitmap} with pixels copied from the packet + */ + public static Bitmap getBitmapFromAlpha(Packet packet) { + int width = PacketGetter.getImageWidth(packet); + int height = PacketGetter.getImageHeight(packet); + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888); + copyAlphaToBitmap(packet, bitmap, width, height); + return bitmap; + } + + /** + * Copies data from an 8-bit alpha mediapipe image frame packet to {@code ARGB_8888} bitmap. + * + * @param packet mediapipe packet + * @param inBitmap mutable {@link Bitmap} of same dimension and config as the expected output, the + * image would be copied to this {@link Bitmap} + */ + public static void copyAlphaToBitmap(Packet packet, Bitmap inBitmap) { + checkArgument(inBitmap.isMutable(), "Input bitmap should be mutable."); + checkArgument( + inBitmap.getConfig() == Config.ARGB_8888, "Input bitmap should be of type ARGB_8888."); + int width = PacketGetter.getImageWidth(packet); + int height = PacketGetter.getImageHeight(packet); + checkArgument(inBitmap.getByteCount() == width * height, "Input bitmap size mismatch."); + copyAlphaToBitmap(packet, inBitmap, width, height); + } + + private static void copyAlphaToBitmap(Packet packet, Bitmap mutableBitmap, int width, int height) { + // TODO: use NDK Bitmap access instead of copyPixelsToBuffer. + ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4); + PacketGetter.getRgbaFromAlpha(packet, buffer); + mutableBitmap.copyPixelsFromBuffer(buffer); + } + private AndroidPacketGetter() {} } diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 3d6b16ce6..379826392 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -273,6 +273,15 @@ public final class PacketGetter { return nativeGetRgbaFromRgb(packet.getNativeHandle(), buffer); } + /** + * Converts an 8-bit alpha mediapipe image frame packet to an RGBA Byte buffer. + * + *

Use {@link ByteBuffer#allocateDirect} when allocating the buffer. + */ + public static boolean getRgbaFromAlpha(final Packet packet, ByteBuffer buffer) { + return nativeGetRgbaFromAlpha(packet.getNativeHandle(), buffer); + } + /** * Converts the audio matrix data back into byte data. * @@ -443,6 +452,7 @@ public final class PacketGetter { long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy); private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer); + private static native boolean nativeGetRgbaFromAlpha(long nativePacketHandle, ByteBuffer buffer); // Retrieves the values that are in the VideoHeader. private static native int nativeGetVideoHeaderWidth(long nativepackethandle); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h b/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h index f5ad09acd..3f811f41e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h @@ -28,16 +28,16 @@ inline void RgbaToRgb(const uint8_t* rgba_img, int rgba_width_step, int width, const auto* rgba = rgba_img + y * rgba_width_step; auto* rgb = rgb_img + y * rgb_width_step; for (int x = 0; x < width; ++x) { - *rgb = *rgba; - *(rgb + 1) = *(rgba + 1); - *(rgb + 2) = *(rgba + 2); + rgb[0] = rgba[0]; + rgb[1] = rgba[1]; + rgb[2] = rgba[2]; rgb += 3; rgba += 4; } } } -// Converts a RGB image to RGBA +// Converts an RGB image to RGBA inline void RgbToRgba(const uint8_t* rgb_img, int rgb_width_step, int width, int height, uint8_t* rgba_img, int rgba_width_step, uint8_t alpha) { @@ -45,16 +45,31 @@ inline void RgbToRgba(const uint8_t* rgb_img, int rgb_width_step, int width, const auto* rgb = rgb_img + y * rgb_width_step; auto* rgba = rgba_img + y * rgba_width_step; for (int x = 0; x < width; ++x) { - *rgba = *rgb; - *(rgba + 1) = *(rgb + 1); - *(rgba + 2) = *(rgb + 2); - *(rgba + 3) = alpha; + rgba[0] = rgb[0]; + rgba[1] = rgb[1]; + rgba[2] = rgb[2]; + rgba[3] = alpha; rgb += 3; rgba += 4; } } } +// Converts an 8-bit alpha image to RGBA +inline void AlphaToRgba(const uint8_t* alpha_img, int alpha_width_step, int width, + int height, uint8_t* rgba_img, int rgba_width_step) { + memset(rgba_img, 0, rgba_width_step * height); + for (int y = 0; y < height; ++y) { + const auto* alpha = alpha_img + y * alpha_width_step; + auto* rgba = rgba_img + y * rgba_width_step; + for (int x = 0; x < width; ++x) { + rgba[3] = *alpha; + ++alpha; + rgba += 4; + } + } +} + } // namespace android } // namespace mediapipe #endif // JAVA_COM_GOOGLE_MEDIAPIPE_FRAMEWORK_JNI_COLORSPACE_H_ diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index 093b147a2..e01de7487 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -511,6 +511,41 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( return true; } +JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)( + JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer) { + mediapipe::Packet mediapipe_packet = + mediapipe::android::Graph::GetPacketFromHandle(packet); + const bool is_image = + mediapipe_packet.ValidateAsType().ok(); + const mediapipe::ImageFrame& image = + is_image ? *GetFromNativeHandle(packet) + .GetImageFrameSharedPtr() + .get() + : GetFromNativeHandle(packet); + uint8_t* rgba_data = + static_cast(env->GetDirectBufferAddress(byte_buffer)); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } + if (buffer_size != image.Width() * image.Height() * 4) { + ThrowIfError(env, + absl::InvalidArgumentError(absl::StrCat( + "Buffer size has to be width*height*4\n" + "Image width: ", + image.Width(), ", Image height: ", image.Height(), + ", Buffer size: ", buffer_size, ", Buffer size needed: ", + image.Width() * image.Height() * 4))); + return false; + } + mediapipe::android::AlphaToRgba(image.PixelData(), image.WidthStep(), + image.Width(), image.Height(), rgba_data, + image.Width() * 4); + return true; +} + JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetVideoHeaderWidth)( JNIEnv* env, jobject thiz, jlong packet) { return GetFromNativeHandle(packet).width; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h index 202795307..c56cb6e72 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h @@ -137,6 +137,11 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)( JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)( JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer); +// Before calling this, the byte_buffer needs to have the correct allocated +// size. +JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)( + JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer); + // Returns the width in VideoHeader packet. JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetVideoHeaderWidth)( JNIEnv* env, jobject thiz, jlong packet); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index 3f96a404d..b0084266c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -284,6 +284,9 @@ void RegisterPacketGetterNatives(JNIEnv *env) { AddJNINativeMethod(&packet_getter_methods, packet_getter, "nativeGetRgbaFromRgb", "(JLjava/nio/ByteBuffer;)Z", (void *)&PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)); + AddJNINativeMethod(&packet_getter_methods, packet_getter, + "nativeGetRgbaFromA", "(JLjava/nio/ByteBuffer;)Z", + (void *)&PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)); RegisterNativesVector(env, packet_getter_class, packet_getter_methods); env->DeleteLocalRef(packet_getter_class); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java index 4152e5d4d..cc74176d8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java @@ -19,6 +19,7 @@ import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; @@ -164,9 +165,20 @@ public final class FaceDetector extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); detectorOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java index 2f57aacb1..fcf488122 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java @@ -21,6 +21,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; @@ -206,9 +207,20 @@ public final class FaceLandmarker extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); landmarkerOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java index 305f22301..c94f94138 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java @@ -151,9 +151,20 @@ public final class FaceStylizer extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); // Empty output image packets indicates that no face stylization is applied. diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 42e46c6d0..1bd67dd08 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -22,6 +22,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; @@ -187,9 +188,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); recognizerOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java index f259f24f2..94338c84b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java @@ -22,6 +22,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; @@ -178,9 +179,20 @@ public final class HandLandmarker extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); landmarkerOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java index d8b6e7a1e..464dec7cf 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java @@ -257,9 +257,20 @@ public final class HolisticLandmarker extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); landmarkerOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 6f5790398..e882909d3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -187,9 +187,20 @@ public final class ImageClassifier extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); options.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index 7e4c7c229..f1eb04bd5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -170,9 +170,20 @@ public final class ImageEmbedder extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); options.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index 813dba93c..7e5c28dfb 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -211,9 +211,20 @@ public final class ImageSegmenter extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); segmenterOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index a53f76f9d..699fdf8c5 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -222,9 +222,20 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); segmenterOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 3a0343424..7a8e0b1e4 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -19,6 +19,7 @@ import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; @@ -196,9 +197,20 @@ public final class ObjectDetector extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); detectorOptions.resultListener().ifPresent(handler::setResultListener); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java index fba2c714e..a274407df 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java @@ -189,9 +189,20 @@ public final class PoseLandmarker extends BaseVisionTaskApi { @Override public MPImage convertToTaskInput(List packets) { - return new BitmapImageBuilder( - AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) - .build(); + Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX); + int numChannels = PacketGetter.getImageNumChannels(currentPacket); + switch (numChannels) { + case 1: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build(); + case 3: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build(); + case 4: + return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build(); + default: + throw new MediaPipeException( + MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), + "Channels should be: 1, 3, or 4, but is " + numChannels); + } } }); landmarkerOptions.resultListener().ifPresent(handler::setResultListener);