diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java
index 53cf480eb..86fb24d99 100644
--- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java
+++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketGetter.java
@@ -116,5 +116,43 @@ public final class AndroidPacketGetter {
mutableBitmap.copyPixelsFromBuffer(buffer);
}
+ /**
+ * Gets an {@code ARGB_8888} bitmap from an 8-bit alpha mediapipe image frame packet.
+ *
+ * @param packet mediapipe packet
+ * @return {@link Bitmap} with pixels copied from the packet
+ */
+ public static Bitmap getBitmapFromAlpha(Packet packet) {
+ int width = PacketGetter.getImageWidth(packet);
+ int height = PacketGetter.getImageHeight(packet);
+ Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
+ copyAlphaToBitmap(packet, bitmap, width, height);
+ return bitmap;
+ }
+
+ /**
+ * Copies data from an 8-bit alpha mediapipe image frame packet to {@code ARGB_8888} bitmap.
+ *
+ * @param packet mediapipe packet
+ * @param inBitmap mutable {@link Bitmap} of same dimension and config as the expected output, the
+ * image would be copied to this {@link Bitmap}
+ */
+ public static void copyAlphaToBitmap(Packet packet, Bitmap inBitmap) {
+ checkArgument(inBitmap.isMutable(), "Input bitmap should be mutable.");
+ checkArgument(
+ inBitmap.getConfig() == Config.ARGB_8888, "Input bitmap should be of type ARGB_8888.");
+ int width = PacketGetter.getImageWidth(packet);
+ int height = PacketGetter.getImageHeight(packet);
+ checkArgument(inBitmap.getByteCount() == width * height, "Input bitmap size mismatch.");
+ copyAlphaToBitmap(packet, inBitmap, width, height);
+ }
+
+ private static void copyAlphaToBitmap(Packet packet, Bitmap mutableBitmap, int width, int height) {
+ // TODO: use NDK Bitmap access instead of copyPixelsToBuffer.
+ ByteBuffer buffer = ByteBuffer.allocateDirect(width * height * 4);
+ PacketGetter.getRgbaFromAlpha(packet, buffer);
+ mutableBitmap.copyPixelsFromBuffer(buffer);
+ }
+
private AndroidPacketGetter() {}
}
diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
index 3d6b16ce6..379826392 100644
--- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
+++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java
@@ -273,6 +273,15 @@ public final class PacketGetter {
return nativeGetRgbaFromRgb(packet.getNativeHandle(), buffer);
}
+ /**
+ * Converts an 8-bit alpha mediapipe image frame packet to an RGBA Byte buffer.
+ *
+ *
Use {@link ByteBuffer#allocateDirect} when allocating the buffer.
+ */
+ public static boolean getRgbaFromAlpha(final Packet packet, ByteBuffer buffer) {
+ return nativeGetRgbaFromAlpha(packet.getNativeHandle(), buffer);
+ }
+
/**
* Converts the audio matrix data back into byte data.
*
@@ -443,6 +452,7 @@ public final class PacketGetter {
long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy);
private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer);
+ private static native boolean nativeGetRgbaFromAlpha(long nativePacketHandle, ByteBuffer buffer);
// Retrieves the values that are in the VideoHeader.
private static native int nativeGetVideoHeaderWidth(long nativepackethandle);
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h b/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h
index f5ad09acd..3f811f41e 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h
@@ -28,16 +28,16 @@ inline void RgbaToRgb(const uint8_t* rgba_img, int rgba_width_step, int width,
const auto* rgba = rgba_img + y * rgba_width_step;
auto* rgb = rgb_img + y * rgb_width_step;
for (int x = 0; x < width; ++x) {
- *rgb = *rgba;
- *(rgb + 1) = *(rgba + 1);
- *(rgb + 2) = *(rgba + 2);
+ rgb[0] = rgba[0];
+ rgb[1] = rgba[1];
+ rgb[2] = rgba[2];
rgb += 3;
rgba += 4;
}
}
}
-// Converts a RGB image to RGBA
+// Converts an RGB image to RGBA
inline void RgbToRgba(const uint8_t* rgb_img, int rgb_width_step, int width,
int height, uint8_t* rgba_img, int rgba_width_step,
uint8_t alpha) {
@@ -45,16 +45,31 @@ inline void RgbToRgba(const uint8_t* rgb_img, int rgb_width_step, int width,
const auto* rgb = rgb_img + y * rgb_width_step;
auto* rgba = rgba_img + y * rgba_width_step;
for (int x = 0; x < width; ++x) {
- *rgba = *rgb;
- *(rgba + 1) = *(rgb + 1);
- *(rgba + 2) = *(rgb + 2);
- *(rgba + 3) = alpha;
+ rgba[0] = rgb[0];
+ rgba[1] = rgb[1];
+ rgba[2] = rgb[2];
+ rgba[3] = alpha;
rgb += 3;
rgba += 4;
}
}
}
+// Converts an 8-bit alpha image to RGBA
+inline void AlphaToRgba(const uint8_t* alpha_img, int alpha_width_step, int width,
+ int height, uint8_t* rgba_img, int rgba_width_step) {
+ memset(rgba_img, 0, rgba_width_step * height);
+ for (int y = 0; y < height; ++y) {
+ const auto* alpha = alpha_img + y * alpha_width_step;
+ auto* rgba = rgba_img + y * rgba_width_step;
+ for (int x = 0; x < width; ++x) {
+ rgba[3] = *alpha;
+ ++alpha;
+ rgba += 4;
+ }
+ }
+}
+
} // namespace android
} // namespace mediapipe
#endif // JAVA_COM_GOOGLE_MEDIAPIPE_FRAMEWORK_JNI_COLORSPACE_H_
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
index 093b147a2..e01de7487 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
@@ -511,6 +511,41 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
return true;
}
+JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)(
+ JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer) {
+ mediapipe::Packet mediapipe_packet =
+ mediapipe::android::Graph::GetPacketFromHandle(packet);
+ const bool is_image =
+ mediapipe_packet.ValidateAsType().ok();
+ const mediapipe::ImageFrame& image =
+ is_image ? *GetFromNativeHandle(packet)
+ .GetImageFrameSharedPtr()
+ .get()
+ : GetFromNativeHandle(packet);
+ uint8_t* rgba_data =
+ static_cast(env->GetDirectBufferAddress(byte_buffer));
+ int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
+ if (rgba_data == nullptr || buffer_size < 0) {
+ ThrowIfError(env, absl::InvalidArgumentError(
+ "input buffer does not support direct access"));
+ return false;
+ }
+ if (buffer_size != image.Width() * image.Height() * 4) {
+ ThrowIfError(env,
+ absl::InvalidArgumentError(absl::StrCat(
+ "Buffer size has to be width*height*4\n"
+ "Image width: ",
+ image.Width(), ", Image height: ", image.Height(),
+ ", Buffer size: ", buffer_size, ", Buffer size needed: ",
+ image.Width() * image.Height() * 4)));
+ return false;
+ }
+ mediapipe::android::AlphaToRgba(image.PixelData(), image.WidthStep(),
+ image.Width(), image.Height(), rgba_data,
+ image.Width() * 4);
+ return true;
+}
+
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetVideoHeaderWidth)(
JNIEnv* env, jobject thiz, jlong packet) {
return GetFromNativeHandle(packet).width;
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h
index 202795307..c56cb6e72 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h
@@ -137,6 +137,11 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
+// Before calling this, the byte_buffer needs to have the correct allocated
+// size.
+JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha)(
+ JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
+
// Returns the width in VideoHeader packet.
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetVideoHeaderWidth)(
JNIEnv* env, jobject thiz, jlong packet);
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc
index 3f96a404d..b0084266c 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc
@@ -284,6 +284,9 @@ void RegisterPacketGetterNatives(JNIEnv *env) {
AddJNINativeMethod(&packet_getter_methods, packet_getter,
"nativeGetRgbaFromRgb", "(JLjava/nio/ByteBuffer;)Z",
(void *)&PACKET_GETTER_METHOD(nativeGetRgbaFromRgb));
+ AddJNINativeMethod(&packet_getter_methods, packet_getter,
+ "nativeGetRgbaFromA", "(JLjava/nio/ByteBuffer;)Z",
+ (void *)&PACKET_GETTER_METHOD(nativeGetRgbaFromAlpha));
RegisterNativesVector(env, packet_getter_class, packet_getter_methods);
env->DeleteLocalRef(packet_getter_class);
}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
index 4152e5d4d..cc74176d8 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java
@@ -19,6 +19,7 @@ import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -164,9 +165,20 @@ public final class FaceDetector extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
detectorOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java
index 2f57aacb1..fcf488122 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facelandmarker/FaceLandmarker.java
@@ -21,6 +21,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -206,9 +207,20 @@ public final class FaceLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
index 305f22301..c94f94138 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java
@@ -151,9 +151,20 @@ public final class FaceStylizer extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
// Empty output image packets indicates that no face stylization is applied.
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
index 42e46c6d0..1bd67dd08 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java
@@ -22,6 +22,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -187,9 +188,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
recognizerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java
index f259f24f2..94338c84b 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarker.java
@@ -22,6 +22,7 @@ import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -178,9 +179,20 @@ public final class HandLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java
index d8b6e7a1e..464dec7cf 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/holisticlandmarker/HolisticLandmarker.java
@@ -257,9 +257,20 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
index 6f5790398..e882909d3 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java
@@ -187,9 +187,20 @@ public final class ImageClassifier extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
options.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
index 7e4c7c229..f1eb04bd5 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java
@@ -170,9 +170,20 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
options.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
index 813dba93c..7e5c28dfb 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java
@@ -211,9 +211,20 @@ public final class ImageSegmenter extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
index a53f76f9d..699fdf8c5 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java
@@ -222,9 +222,20 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(imageOutStreamIndex)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
segmenterOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java
index 3a0343424..7a8e0b1e4 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java
@@ -19,6 +19,7 @@ import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
@@ -196,9 +197,20 @@ public final class ObjectDetector extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
detectorOptions.resultListener().ifPresent(handler::setResultListener);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java
index fba2c714e..a274407df 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java
@@ -189,9 +189,20 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
@Override
public MPImage convertToTaskInput(List packets) {
- return new BitmapImageBuilder(
- AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
- .build();
+ Packet currentPacket = packets.get(IMAGE_OUT_STREAM_INDEX);
+ int numChannels = PacketGetter.getImageNumChannels(currentPacket);
+ switch (numChannels) {
+ case 1:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromAlpha(currentPacket)).build();
+ case 3:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgb(currentPacket)).build();
+ case 4:
+ return new BitmapImageBuilder(AndroidPacketGetter.getBitmapFromRgba(currentPacket)).build();
+ default:
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "Channels should be: 1, 3, or 4, but is " + numChannels);
+ }
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);