From 4c2132177559ab7a9636138a2190e2e57177b6c3 Mon Sep 17 00:00:00 2001
From: definability <5962274+definability@users.noreply.github.com>
Date: Sun, 10 Dec 2023 16:13:19 +0000
Subject: [PATCH] Android. Process ARGB and Grayscale input packets
Issue https://github.com/google/mediapipe/issues/5017
reports incorrect bitmap conversion on
Android when RGBA or Grayscale images are used.
Because the library allows the creation of ALPHA8 bitmaps,
`getBitmapFromRgb` and `getBitmapFromRgba` alone are not enough,
so a new `getBitmapFromAlpha` method is needed.
Otherwise, existing users who rely on 8-bit input but do not use the
input packet would hit errors even though their code operates as
intended.
Add correct processing of images with 1, 3, and 4 channels
to `FaceDetector`, `FaceLandmarker`, `FaceStylizer`,
`GestureRecognizer`, `HandLandmarker`, `HolisticLandmarker`,
`ImageClassifier`, `ImageEmbedder`, `ImageSegmenter`,
`InteractiveSegmenter`, `ObjectDetector`, and `PoseLandmarker`.
If the number of channels is not 1, 3, or 4,
throw MediaPipeException with `INVALID_ARGUMENT` status code.
Implement methods and functions needed to convert single-channel images
to RGBA bitmaps.
The native function `AlphaToRgba` uses `memset`
to set all bytes of the resulting buffer to zero
and then copies the alpha values of the input image into the alpha channel
of the corresponding output pixels.
Perform a slight refactoring:
use `operator[i]` instead of `*(array + i)` for clarity.
---
.../framework/AndroidPacketGetter.java | 38 +++++++++++++++++++
.../mediapipe/framework/PacketGetter.java | 10 +++++
.../mediapipe/framework/jni/colorspace.h | 31 +++++++++++----
.../framework/jni/packet_getter_jni.cc | 35 +++++++++++++++++
.../framework/jni/packet_getter_jni.h | 5 +++
.../framework/jni/register_natives.cc | 3 ++
.../vision/facedetector/FaceDetector.java | 18 +++++++--
.../vision/facelandmarker/FaceLandmarker.java | 18 +++++++--
.../vision/facestylizer/FaceStylizer.java | 17 +++++++--
.../gesturerecognizer/GestureRecognizer.java | 18 +++++++--
.../vision/handlandmarker/HandLandmarker.java | 18 +++++++--
.../HolisticLandmarker.java | 17 +++++++--
.../imageclassifier/ImageClassifier.java | 17 +++++++--
.../vision/imageembedder/ImageEmbedder.java | 17 +++++++--
.../vision/imagesegmenter/ImageSegmenter.java | 17 +++++++--
.../InteractiveSegmenter.java | 17 +++++++--
.../vision/objectdetector/ObjectDetector.java | 18 +++++++--
.../vision/poselandmarker/PoseLandmarker.java | 17 +++++++--
18 files changed, 287 insertions(+), 44 deletions(-)
diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java
index 53cf480eb..86fb24d99 100644
--- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java
+++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java
@@ -116,5 +116,43 @@ public final class AndroidPacketGetter {
mutableBitmap.copyPixelsFromBuffer(buffer);
}
+ /**
+ * Creates a new {@code ARGB_8888} bitmap from an 8-bit alpha mediapipe image frame packet.
+ *
+ * @param packet mediapipe packet holding a single-channel (alpha) image
+ * @return newly allocated {@link Bitmap} sized to the packet's width and height
+ */
+ public static Bitmap getBitmapFromAlpha(Packet packet) {
+ int width = PacketGetter.getImageWidth(packet);
+ int height = PacketGetter.getImageHeight(packet);
+ Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
+ copyAlphaToBitmap(packet, bitmap, width, height);
+ return bitmap;
+ }
+
+ /**
+ * Copies an 8-bit alpha mediapipe image frame packet into an existing {@code ARGB_8888} bitmap.
+ *
+ * @param packet mediapipe packet
+ * @param inBitmap mutable {@code ARGB_8888} {@link Bitmap} whose byte count equals
+ * {@code width * height * 4} of the packet image; the image is copied into it
+ */
+ public static void copyAlphaToBitmap(Packet packet, Bitmap inBitmap) {
+ checkArgument(inBitmap.isMutable(), "Input bitmap should be mutable.");
+ checkArgument(
+ inBitmap.getConfig() == Config.ARGB_8888, "Input bitmap should be of type ARGB_8888.");
+ int width = PacketGetter.getImageWidth(packet);
+ int height = PacketGetter.getImageHeight(packet);
+ checkArgument(inBitmap.getByteCount() == width * height * 4, "Input bitmap size mismatch.");
+ copyAlphaToBitmap(packet, inBitmap, width, height);
+ }
+
+ private static void copyAlphaToBitmap(Packet packet, Bitmap mutableBitmap, int width, int height) {
+ // TODO: use NDK Bitmap access instead of copyPixelsFromBuffer.
+ ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4); // 4 bytes per RGBA pixel
+ PacketGetter.getRgbaFromAlpha(packet, buffer);
+ mutableBitmap.copyPixelsFromBuffer(buffer);
+ }
+
private AndroidPacketGetter() {}
}
diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
index 3d6b16ce6..379826392 100644
--- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
+++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
@@ -273,6 +273,15 @@ public final class PacketGetter {
return nativeGetRgbaFromRgb(packet.getNativeHandle(), buffer);
}
+ /**
+ * Converts an 8-bit alpha mediapipe image frame packet to an RGBA Byte buffer.
+ *
+ * <p>Use {@link ByteBuffer#allocateDirect} when allocating the buffer.
+ */
+ public static boolean getRgbaFromAlpha(final Packet packet, ByteBuffer buffer) {
+ return nativeGetRgbaFromAlpha(packet.getNativeHandle(), buffer);
+ }
+
/**
* Converts the audio matrix data back into byte data.
*
@@ -443,6 +452,7 @@ public final class PacketGetter {
long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy);
private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer);
+ private static native boolean nativeGetRgbaFromAlpha(long nativePacketHandle, ByteBuffer buffer);
// Retrieves the values that are in the VideoHeader.
private static native int nativeGetVideoHeaderWidth(long nativepackethandle);
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h b/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h
index f5ad09acd..3f811f41e 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h
@@ -28,16 +28,16 @@ inline void RgbaToRgb(const uint8_t* rgba_img, int rgba_width_step, int width,
const auto* rgba = rgba_img + y * rgba_width_step;
auto* rgb = rgb_img + y * rgb_width_step;
for (int x = 0; x < width; ++x) {
- *rgb = *rgba;
- *(rgb + 1) = *(rgba + 1);
- *(rgb + 2) = *(rgba + 2);
+ rgb[0] = rgba[0];
+ rgb[1] = rgba[1];
+ rgb[2] = rgba[2];
rgb += 3;
rgba += 4;
}
}
}
-// Converts a RGB image to RGBA
+// Converts an RGB image to RGBA
inline void RgbToRgba(const uint8_t* rgb_img, int rgb_width_step, int width,
int height, uint8_t* rgba_img, int rgba_width_step,
uint8_t alpha) {
@@ -45,16 +45,31 @@ inline void RgbToRgba(const uint8_t* rgb_img, int rgb_width_step, int width,
const auto* rgb = rgb_img + y * rgb_width_step;
auto* rgba = rgba_img + y * rgba_width_step;
for (int x = 0; x < width; ++x) {
- *rgba = *rgb;
- *(rgba + 1) = *(rgb + 1);
- *(rgba + 2) = *(rgb + 2);
- *(rgba + 3) = alpha;
+ rgba[0] = rgb[0];
+ rgba[1] = rgb[1];
+ rgba[2] = rgb[2];
+ rgba[3] = alpha;
rgb += 3;
rgba += 4;
}
}
}
+// Expands an 8-bit alpha image into RGBA: the color channels are zero and the
+// alpha channel carries the source values.
+inline void AlphaToRgba(const uint8_t* alpha_img, int alpha_width_step, int width,
+ int height, uint8_t* rgba_img, int rgba_width_step) {
+ // Zero the whole destination first so only the alpha bytes need writing.
+ memset(rgba_img, 0, rgba_width_step * height);
+ for (int row = 0; row < height; ++row) {
+ const uint8_t* src = alpha_img + row * alpha_width_step;
+ uint8_t* dst = rgba_img + row * rgba_width_step;
+ for (int col = 0; col < width; ++col) {
+ dst[4 * col + 3] = src[col];
+ }
+ }
+}
+
} // namespace android
} // namespace mediapipe
#endif // JAVA_COM_GOOGLE_MEDIAPIPE_FRAMEWORK_JNI_COLORSPACE_H_
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
index 093b147a2..e01de7487 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
@@ -511,6 +511,41 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
return true;
}
+JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)(
+ JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer) {
+ mediapipe::Packet mediapipe_packet =
+ mediapipe::android::Graph::GetPacketFromHandle(packet);
+ const bool is_image =  // the packet may hold an Image or an ImageFrame
+ mediapipe_packet.ValidateAsType<mediapipe::Image>().ok();
+ const mediapipe::ImageFrame& image =
+ is_image ? *GetFromNativeHandle<mediapipe::Image>(packet)
+ .GetImageFrameSharedPtr()
+ .get()
+ : GetFromNativeHandle<mediapipe::ImageFrame>(packet);
+ uint8_t* rgba_data =
+ static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
+ int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
+ if (rgba_data == nullptr || buffer_size < 0) {
+ ThrowIfError(env, absl::InvalidArgumentError(
+ "input buffer does not support direct access"));
+ return false;
+ }
+ if (buffer_size != image.Width() * image.Height() * 4) {
+ ThrowIfError(env,
+ absl::InvalidArgumentError(absl::StrCat(
+ "Buffer size has to be width*height*4\n"
+ "Image width: ",
+ image.Width(), ", Image height: ", image.Height(),
+ ", Buffer size: ", buffer_size, ", Buffer size needed: ",
+ image.Width() * image.Height() * 4)));
+ return false;
+ }
+ mediapipe::android::AlphaToRgba(image.PixelData(), image.WidthStep(),
+ image.Width(), image.Height(), rgba_data,
+ image.Width() * 4);  // output rows are tightly packed
+ return true;
+}
+
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetVideoHeaderWidth)(
JNIEnv* env, jobject thiz, jlong packet) {
return GetFromNativeHandle(packet).width;
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h
index 202795307..c56cb6e72 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h
@@ -137,6 +137,11 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
+// Before calling this, the byte_buffer needs to be a direct buffer of exactly
+// width * height * 4 bytes.
+JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)(
+ JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
+
// Returns the width in VideoHeader packet.
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetVideoHeaderWidth)(
JNIEnv* env, jobject thiz, jlong packet);
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc
index 3f96a404d..b0084266c 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc
@@ -284,6 +284,9 @@ void RegisterPacketGetterNatives(JNIEnv *env) {
AddJNINativeMethod(&packet_getter_methods, packet_getter,
"nativeGetRgbaFromRgb", "(JLjava/nio/ByteBuffer;)Z",
(void *)&PACKET_GETTER_METHOD(nativeGetRgbaFromRgb));
+ AddJNINativeMethod(&packet_getter_methods, packet_getter,
+ "nativeGetRgbaFromAlpha", "(JLjava/nio/ByteBuffer;)Z",
+ (void *)&PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha));
RegisterNativesVector(env, packet_getter_class, packet_getter_methods);
env->DeleteLocalRef(packet_getter_class);
}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
index 4152e5d4d..cc74176d8 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
@@ -19,6 +19,7 @@ import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -164,9 +165,20 @@ public final class FaceDetector extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
detectorOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java
index 2f57aacb1..fcf488122 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java
@@ -21,6 +21,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -206,9 +207,20 @@ public final class FaceLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
index 305f22301..c94f94138 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
@@ -151,9 +151,20 @@ public final class FaceStylizer extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
// Empty output image packets indicates that no face stylization is applied.
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
index 42e46c6d0..1bd67dd08 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
@@ -22,6 +22,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -187,9 +188,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
recognizerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java
index f259f24f2..94338c84b 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java
@@ -22,6 +22,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -178,9 +179,20 @@ public final class HandLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java
index e80da4fca..2bb07859e 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java
@@ -257,9 +257,20 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
index 6f5790398..e882909d3 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
@@ -187,9 +187,20 @@ public final class ImageClassifier extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
options.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
index 7e4c7c229..f1eb04bd5 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
@@ -170,9 +170,20 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
options.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
index 813dba93c..7e5c28dfb 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
@@ -211,9 +211,20 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
- .build();
+ Packet currentPacket = packets.get(imageOutStreamIndex);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
index a53f76f9d..699fdf8c5 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
@@ -222,9 +222,20 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
- .build();
+ Packet currentPacket = packets.get(imageOutStreamIndex);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java
index 3a0343424..7a8e0b1e4 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java
@@ -19,6 +19,7 @@ import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -196,9 +197,20 @@ public final class ObjectDetector extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
detectorOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java
index fba2c714e..a274407df 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java
@@ -189,9 +189,20 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) { // 1: alpha, 3: RGB, 4: RGBA
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);