Internal change
PiperOrigin-RevId: 506053206
This commit is contained in:
parent
0863a8a1e7
commit
5730dec260
|
@ -1329,6 +1329,7 @@ cc_library(
|
|||
hdrs = ["merge_to_vector_calculator.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
|
|
|
@ -199,6 +199,28 @@ public final class PacketGetter {
|
|||
return nativeGetImageData(packet.getNativeHandle(), buffer);
|
||||
}
|
||||
|
||||
/** Returns the size of Image list. This helps to determine size of allocated ByteBuffer array. */
|
||||
public static int getImageListSize(final Packet packet) {
|
||||
return nativeGetImageListSize(packet.getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
* Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer
|
||||
* array has the the same size of image list packet, and assumes the output buffer stores pixels
|
||||
* contiguously. It returns false if this assumption does not hold.
|
||||
*
|
||||
* <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of
|
||||
* ByteBuffer to copy image data to. If false, the ByteBuffer will wrap the memory address of
|
||||
* MediaPipe ImageFrame of graph output, and the ByteBuffer data is available only when MediaPipe
|
||||
* graph is alive.
|
||||
*
|
||||
* <p>Note: this function does not assume the pixel format.
|
||||
*/
|
||||
public static boolean getImageList(
|
||||
final Packet packet, ByteBuffer[] buffersArray, boolean deepCopy) {
|
||||
return nativeGetImageList(packet.getNativeHandle(), buffersArray, deepCopy);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an RGB mediapipe image frame packet to an RGBA Byte buffer.
|
||||
*
|
||||
|
@ -316,7 +338,8 @@ public final class PacketGetter {
|
|||
public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) {
|
||||
return new GraphTextureFrame(
|
||||
nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false),
|
||||
packet.getTimestamp(), /* deferredSync= */true);
|
||||
packet.getTimestamp(),
|
||||
/* deferredSync= */ true);
|
||||
}
|
||||
|
||||
private static native long nativeGetPacketFromReference(long nativePacketHandle);
|
||||
|
@ -363,6 +386,11 @@ public final class PacketGetter {
|
|||
|
||||
private static native boolean nativeGetImageData(long nativePacketHandle, ByteBuffer buffer);
|
||||
|
||||
private static native int nativeGetImageListSize(long nativePacketHandle);
|
||||
|
||||
private static native boolean nativeGetImageList(
|
||||
long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy);
|
||||
|
||||
private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer);
|
||||
// Retrieves the values that are in the VideoHeader.
|
||||
private static native int nativeGetVideoHeaderWidth(long nativepackethandle);
|
||||
|
|
|
@ -50,7 +50,10 @@ public class ByteBufferExtractor {
|
|||
switch (container.getImageProperties().getStorageType()) {
|
||||
case MPImage.STORAGE_TYPE_BYTEBUFFER:
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||
return byteBufferImageContainer
|
||||
.getByteBuffer()
|
||||
.asReadOnlyBuffer()
|
||||
.order(ByteOrder.nativeOrder());
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
"Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
|
||||
|
@ -74,7 +77,7 @@ public class ByteBufferExtractor {
|
|||
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
|
||||
* conversions.
|
||||
*/
|
||||
static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
|
||||
public static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
|
||||
MPImageContainer container;
|
||||
MPImageProperties byteBufferProperties =
|
||||
MPImageProperties.builder()
|
||||
|
@ -83,12 +86,16 @@ public class ByteBufferExtractor {
|
|||
.build();
|
||||
if ((container = image.getContainer(byteBufferProperties)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||
return byteBufferImageContainer
|
||||
.getByteBuffer()
|
||||
.asReadOnlyBuffer()
|
||||
.order(ByteOrder.nativeOrder());
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
@MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
|
||||
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
|
||||
.asReadOnlyBuffer();
|
||||
.asReadOnlyBuffer()
|
||||
.order(ByteOrder.nativeOrder());
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
|
||||
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
|
||||
ByteBuffer byteBuffer =
|
||||
|
|
|
@ -67,6 +67,8 @@ public class MPImage implements Closeable {
|
|||
IMAGE_FORMAT_YUV_420_888,
|
||||
IMAGE_FORMAT_ALPHA,
|
||||
IMAGE_FORMAT_JPEG,
|
||||
IMAGE_FORMAT_VEC32F1,
|
||||
IMAGE_FORMAT_VEC32F2,
|
||||
})
|
||||
@Retention(RetentionPolicy.SOURCE)
|
||||
public @interface MPImageFormat {}
|
||||
|
@ -81,6 +83,8 @@ public class MPImage implements Closeable {
|
|||
public static final int IMAGE_FORMAT_YUV_420_888 = 7;
|
||||
public static final int IMAGE_FORMAT_ALPHA = 8;
|
||||
public static final int IMAGE_FORMAT_JPEG = 9;
|
||||
public static final int IMAGE_FORMAT_VEC32F1 = 10;
|
||||
public static final int IMAGE_FORMAT_VEC32F2 = 11;
|
||||
|
||||
/** Specifies the image container type. Would be useful for choosing extractors. */
|
||||
@IntDef({
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
@ -39,6 +40,52 @@ template <typename T>
|
|||
const T& GetFromNativeHandle(int64_t packet_handle) {
|
||||
return mediapipe::android::Graph::GetPacketFromHandle(packet_handle).Get<T>();
|
||||
}
|
||||
|
||||
bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image,
|
||||
jobject byte_buffer) {
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
if (buffer_data == nullptr || buffer_size < 0) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access"));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() * image.NumberOfChannels();
|
||||
if (buffer_size != expected_buffer_size) {
|
||||
ThrowIfError(
|
||||
env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected buffer size ", expected_buffer_size,
|
||||
" got: ", buffer_size, ", width ", image.Width(), ", height ",
|
||||
image.Height(), ", channels ", image.NumberOfChannels())));
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (image.ByteDepth()) {
|
||||
case 1: {
|
||||
uint8* data = static_cast<uint8*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
uint16* data = static_cast<uint16*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
float* data = static_cast<float*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetPacketFromReference)(
|
||||
|
@ -298,46 +345,51 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
|
|||
.GetImageFrameSharedPtr()
|
||||
.get()
|
||||
: GetFromNativeHandle<mediapipe::ImageFrame>(packet);
|
||||
return CopyImageDataToByteBuffer(env, image, byte_buffer);
|
||||
}
|
||||
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
if (buffer_data == nullptr || buffer_size < 0) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access"));
|
||||
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)(
|
||||
JNIEnv* env, jobject thiz, jlong packet) {
|
||||
const auto& image_list =
|
||||
GetFromNativeHandle<std::vector<mediapipe::Image>>(packet);
|
||||
return image_list.size();
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
|
||||
JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array,
|
||||
jboolean deep_copy) {
|
||||
const auto& image_list =
|
||||
GetFromNativeHandle<std::vector<mediapipe::Image>>(packet);
|
||||
if (env->GetArrayLength(byte_buffer_array) != image_list.size()) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected ByteBuffer array size: ", image_list.size(),
|
||||
" but get ByteBuffer array size: ",
|
||||
env->GetArrayLength(byte_buffer_array))));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() * image.NumberOfChannels();
|
||||
if (buffer_size != expected_buffer_size) {
|
||||
ThrowIfError(
|
||||
env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected buffer size ", expected_buffer_size,
|
||||
" got: ", buffer_size, ", width ", image.Width(), ", height ",
|
||||
image.Height(), ", channels ", image.NumberOfChannels())));
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (image.ByteDepth()) {
|
||||
case 1: {
|
||||
uint8* data = static_cast<uint8*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
uint16* data = static_cast<uint16*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
float* data = static_cast<float*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
for (int i = 0; i < image_list.size(); ++i) {
|
||||
auto& image = *image_list[i].GetImageFrameSharedPtr().get();
|
||||
if (!image.IsContiguous()) {
|
||||
ThrowIfError(
|
||||
env, absl::InternalError("ImageFrame must store data contiguously to "
|
||||
"be allocated as ByteBuffer."));
|
||||
return false;
|
||||
}
|
||||
if (deep_copy) {
|
||||
jobject byte_buffer = reinterpret_cast<jobject>(
|
||||
env->GetObjectArrayElement(byte_buffer_array, i));
|
||||
if (!CopyImageDataToByteBuffer(env, image, byte_buffer)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() *
|
||||
image.NumberOfChannels();
|
||||
jobject image_data_byte_buffer = env->NewDirectByteBuffer(
|
||||
image.MutablePixelData(), expected_buffer_size);
|
||||
env->SetObjectArrayElement(byte_buffer_array, i, image_data_byte_buffer);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -415,7 +467,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)(
|
|||
int16 value =
|
||||
static_cast<int16>(audio_mat(channel, sample) * kMultiplier);
|
||||
// The java and native has the same byte order, by default is little
|
||||
// Endian, we can safely copy data directly, we have tests to cover this.
|
||||
// Endian, we can safely copy data directly, we have tests to cover
|
||||
// this.
|
||||
env->SetByteArrayRegion(byte_data, offset, 2,
|
||||
reinterpret_cast<const jbyte*>(&value));
|
||||
offset += 2;
|
||||
|
|
|
@ -106,6 +106,17 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageHeight)(JNIEnv* env,
|
|||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
|
||||
JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
|
||||
|
||||
// Return the vector size of std::vector<Image>.
|
||||
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)(
|
||||
JNIEnv* env, jobject thiz, jlong packet);
|
||||
|
||||
// Fill ByteBuffer[] from the Packet of std::vector<Image>.
|
||||
// Before calling this, the byte_buffer_array needs to have the correct
|
||||
// allocated size.
|
||||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
|
||||
JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array,
|
||||
jboolean deep_copy);
|
||||
|
||||
// Before calling this, the byte_buffer needs to have the correct allocated
|
||||
// size.
|
||||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
|
||||
|
|
|
@ -257,10 +257,12 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
|||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||
// TODO: fix this unit test after image segmenter handled post
|
||||
// processing correctly with rotated image.
|
||||
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image, DecodeImageFromFile(
|
||||
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
|
||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||
|
@ -271,7 +273,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
|||
ImageSegmenter::Create(std::move(options)));
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = -90;
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
|
||||
segmenter->Segment(image, image_processing_options));
|
||||
EXPECT_EQ(confidence_masks.size(), 21);
|
||||
|
||||
cv::Mat expected_mask =
|
||||
|
|
|
@ -44,6 +44,7 @@ cc_binary(
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
|
@ -176,6 +177,30 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "imagesegmenter",
|
||||
srcs = [
|
||||
"imagesegmenter/ImageSegmenter.java",
|
||||
"imagesegmenter/ImageSegmenterResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
manifest = "imagesegmenter/AndroidManifest.xml",
|
||||
deps = [
|
||||
":core",
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "imageembedder",
|
||||
srcs = [
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.imagesegmenter">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,462 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import android.content.Context;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs image segmentation on images.
|
||||
*
|
||||
* <p>Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a
|
||||
* user-defined callback function even for the synchronous API. This makes it possible for
|
||||
* ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in
|
||||
* the {@link ImageSegmenterOptions} for all {@link RunningMode}.
|
||||
*
|
||||
* <p>The API expects a TFLite model with,<a
|
||||
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input image {@link MPImage}
|
||||
* <ul>
|
||||
* <li>The image that image segmenter runs on.
|
||||
* </ul>
|
||||
* <li>Output ImageSegmenterResult {@link ImageSgmenterResult}
|
||||
* <ul>
|
||||
* <li>An ImageSegmenterResult containing segmented masks.
|
||||
* </ul>
|
||||
* </ul>
|
||||
*/
|
||||
public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||
private static final String TAG = ImageSegmenter.class.getSimpleName();
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList(
|
||||
"GROUPED_SEGMENTATION:segmented_mask_out",
|
||||
"IMAGE:image_out",
|
||||
"SEGMENTATION:0:segmentation"));
|
||||
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param segmenterOptions an {@link ImageSegmenterOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageSegmenter} creation.
|
||||
*/
|
||||
public static ImageSegmenter createFromOptions(
|
||||
Context context, ImageSegmenterOptions segmenterOptions) {
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ImageSegmenterResult, MPImage>() {
|
||||
@Override
|
||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||
throws MediaPipeException {
|
||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return ImageSegmenterResult.create(
|
||||
new ArrayList<>(),
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
||||
}
|
||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
||||
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int imageFormat =
|
||||
segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK
|
||||
? MPImage.IMAGE_FORMAT_VEC32F1
|
||||
: MPImage.IMAGE_FORMAT_ALPHA;
|
||||
int imageListSize =
|
||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting segmented masks. It usually results from incorrect"
|
||||
+ " options of unsupported OutputType of given model.");
|
||||
}
|
||||
for (ByteBuffer buffer : buffersArray) {
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
|
||||
segmentedMasks.add(builder.build());
|
||||
}
|
||||
|
||||
return ImageSegmenterResult.create(
|
||||
segmentedMasks,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
segmenterOptions.runningMode(),
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
handler.setResultListener(segmenterOptions.resultListener());
|
||||
segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageSegmenterOptions>builder()
|
||||
.setTaskName(ImageSegmenter.class.getSimpleName())
|
||||
.setTaskRunningModeName(segmenterOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setTaskOptions(segmenterOptions)
|
||||
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
return new ImageSegmenter(runner, segmenterOptions.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize an {@link ImageSegmenter} from a {@link TaskRunner} and a {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image with default image processing options,
|
||||
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java
|
||||
* doc for input image format.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segment(MPImage image) {
|
||||
segment(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
|
||||
* when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO
|
||||
* update java doc for input image format.
|
||||
*
|
||||
* <p>{@link HandLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
ImageSegmenterResult unused =
|
||||
(ImageSegmenterResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame with default image processing options,
|
||||
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentForVideo(MPImage image, long timestampMs) {
|
||||
segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
|
||||
* when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link HandLandmarker} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
ImageSegmenterResult unused =
|
||||
(ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform hand landmarks detection with default image processing
|
||||
* options, i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link ImageSegmenter } is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the image segmenter. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentAsync(MPImage image, long timestampMs) {
|
||||
segmentAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform image segmentation, and the results will be available via the
|
||||
* {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when
|
||||
* the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the image segmenter. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentAsync(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link ImageSegmenter}. */
|
||||
@AutoValue
|
||||
public abstract static class ImageSegmenterOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link ImageSegmenterOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the base options for the image segmenter task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the running mode for the image segmenter task. Default to the image mode. Image
|
||||
* segmenter has three modes:
|
||||
*
|
||||
* <ul>
|
||||
* <li>IMAGE: The mode for segmenting image on single image inputs.
|
||||
* <li>VIDEO: The mode for segmenting image on the decoded frames of a video.
|
||||
* <li>LIVE_STREAM: The mode for for segmenting image on a live stream of input data, such
|
||||
* as from camera. In this mode, {@code setResultListener} must be called to set up a
|
||||
* listener to receive the recognition results asynchronously.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode value);
|
||||
|
||||
/**
|
||||
* The locale to use for display names specified through the TFLite Model Metadata, if any.
|
||||
* Defaults to English.
|
||||
*/
|
||||
public abstract Builder setDisplayNamesLocale(String value);
|
||||
|
||||
/** The output type from image segmenter. */
|
||||
public abstract Builder setOutputType(OutputType value);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline
|
||||
* is done processing an image.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageSegmenterResult, MPImage> value);
|
||||
|
||||
/** Sets an optional {@link ErrorListener}}. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
abstract ImageSegmenterOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link ImageSegmenterOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the image segmenter is
|
||||
* in the live stream mode.
|
||||
*/
|
||||
public final ImageSegmenterOptions build() {
|
||||
ImageSegmenterOptions options = autoBuild();
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract String displayNamesLocale();
|
||||
|
||||
abstract OutputType outputType();
|
||||
|
||||
abstract ResultListener<ImageSegmenterResult, MPImage> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
/** The output type of segmentation results. */
|
||||
public enum OutputType {
|
||||
// Gives a single output mask where each pixel represents the class which
|
||||
// the pixel in the original image was predicted to belong to.
|
||||
CATEGORY_MASK,
|
||||
// Gives a list of output masks where, for each mask, each pixel represents
|
||||
// the prediction confidence, usually in the [0, 1] range.
|
||||
CONFIDENCE_MASK
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setDisplayNamesLocale("en")
|
||||
.setOutputType(OutputType.CATEGORY_MASK)
|
||||
.setResultListener((result, image) -> {});
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an {@link ImageSegmenterOptions} to a {@link CalculatorOptions} protobuf message.
|
||||
*/
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
BaseOptionsProto.BaseOptions.newBuilder()
|
||||
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
|
||||
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
|
||||
.build())
|
||||
.setDisplayNamesLocale(displayNamesLocale());
|
||||
|
||||
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
||||
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
||||
if (outputType() == OutputType.CONFIDENCE_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
|
||||
} else if (outputType() == OutputType.CATEGORY_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
// TODO: remove this once activation is handled in metadata and grpah level.
|
||||
segmenterOptionsBuilder.setActivation(
|
||||
SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
|
||||
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext,
|
||||
taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
|
||||
* region-of-interest.
|
||||
*/
|
||||
private static void validateImageProcessingOptions(
|
||||
ImageProcessingOptions imageProcessingOptions) {
|
||||
if (imageProcessingOptions.regionOfInterest().isPresent()) {
|
||||
throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest.");
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the segmentation results generated by {@link ImageSegmenter}. */
|
||||
@AutoValue
|
||||
public abstract class ImageSegmenterResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
|
||||
*
|
||||
* @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType
|
||||
* is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is
|
||||
* CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_ALPHA format.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static ImageSegmenterResult create(List<MPImage> segmentations, long timestampMs) {
|
||||
return new AutoValue_ImageSegmenterResult(
|
||||
Collections.unmodifiableList(segmentations), timestampMs);
|
||||
}
|
||||
|
||||
public abstract List<MPImage> segmentations();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.imagesegmentertest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="imagesegmentertest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.vision.imagesegmentertest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
|
@ -0,0 +1,427 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.graphics.Color;
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapExtractor;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferExtractor;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link ImageSegmenter}. */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({ImageSegmenterTest.General.class, ImageSegmenterTest.RunningModeTest.class})
|
||||
public class ImageSegmenterTest {
|
||||
private static final String DEEPLAB_MODEL_FILE = "deeplabv3.tflite";
|
||||
private static final String SELFIE_128x128_MODEL_FILE = "selfie_segm_128_128_3.tflite";
|
||||
private static final String SELFIE_144x256_MODEL_FILE = "selfie_segm_144_256_3.tflite";
|
||||
private static final String CAT_IMAGE = "cat.jpg";
|
||||
private static final float GOLDEN_MASK_SIMILARITY = 0.96f;
|
||||
private static final int MAGNIFICATION_FACTOR = 10;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends ImageSegmenterTest {
|
||||
|
||||
@Test
|
||||
public void segment_successWithCategoryMask() throws Exception {
|
||||
final String inputImageName = "segmentation_input_rotation0.jpg";
|
||||
final String goldenImageName = "segmentation_golden_rotation0.png";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(1);
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
verifyCategoryMask(
|
||||
actualMaskBuffer,
|
||||
expectedMaskBuffer,
|
||||
GOLDEN_MASK_SIMILARITY,
|
||||
MAGNIFICATION_FACTOR);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithConfidenceMask() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWith128x128Segmentation() throws Exception {
|
||||
final String inputImageName = "mozart_square.jpg";
|
||||
final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(2);
|
||||
// Selfie category index 1.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
// TODO: enable this unit test once activation option is supported in metadata.
|
||||
// @Test
|
||||
// public void segment_successWith144x256Segmentation() throws Exception {
|
||||
// final String inputImageName = "mozart_square.jpg";
|
||||
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
|
||||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// ImageSegmenterOptions options =
|
||||
// ImageSegmenterOptions.builder()
|
||||
// .setBaseOptions(
|
||||
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
|
||||
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
// .setActivation(ImageSegmenterOptions.Activation.NONE)
|
||||
// .setResultListener(
|
||||
// (actualResult, inputImage) -> {
|
||||
// List<MPImage> segmentations = actualResult.segmentations();
|
||||
// assertThat(segmentations.size()).isEqualTo(1);
|
||||
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
// verifyConfidenceMask(
|
||||
// actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
// })
|
||||
// .build();
|
||||
// ImageSegmenter imageSegmenter =
|
||||
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(),
|
||||
// options);
|
||||
// imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
// }
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends ImageSegmenterTest {
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentForVideo(
|
||||
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((result, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentForVideo(
|
||||
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithImageMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithVideoMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithLiveStreamMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
MPImage expectedResult = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(segmenterResult, inputImage) -> {
|
||||
verifyConfidenceMask(
|
||||
segmenterResult.segmentations().get(8),
|
||||
expectedResult,
|
||||
GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
try (ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
imageSegmenter.segmentAsync(image, /* timestampsMs= */ i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
MPImage expectedResult = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(segmenterResult, inputImage) -> {
|
||||
verifyConfidenceMask(
|
||||
segmenterResult.segmentations().get(8),
|
||||
expectedResult,
|
||||
GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
try (ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
imageSegmenter.segmentAsync(image, /* timestampsMs= */ 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageSegmenter.segmentAsync(image, /* timestampsMs= */ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void verifyCategoryMask(
|
||||
MPImage actualMask, MPImage goldenMask, float similarityThreshold, int magnificationFactor) {
|
||||
assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth());
|
||||
assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight());
|
||||
ByteBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask);
|
||||
Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask);
|
||||
int consistentPixels = 0;
|
||||
final int numPixels = actualMask.getWidth() * actualMask.getHeight();
|
||||
actualMaskBuffer.rewind();
|
||||
for (int y = 0; y < actualMask.getHeight(); y++) {
|
||||
for (int x = 0; x < actualMask.getWidth(); x++) {
|
||||
// RGB values are the same in the golden mask image.
|
||||
consistentPixels +=
|
||||
actualMaskBuffer.get() * magnificationFactor
|
||||
== Color.red(goldenMaskBitmap.getPixel(x, y))
|
||||
? 1
|
||||
: 0;
|
||||
}
|
||||
}
|
||||
assertThat((float) consistentPixels / numPixels).isGreaterThan(similarityThreshold);
|
||||
}
|
||||
|
||||
private static void verifyConfidenceMask(
|
||||
MPImage actualMask, MPImage goldenMask, float similarityThreshold) {
|
||||
assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth());
|
||||
assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight());
|
||||
FloatBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask).asFloatBuffer();
|
||||
Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask);
|
||||
FloatBuffer goldenMaskBuffer = getByteBufferFromBitmap(goldenMaskBitmap).asFloatBuffer();
|
||||
assertThat(
|
||||
calculateSoftIOU(
|
||||
actualMaskBuffer, goldenMaskBuffer, actualMask.getWidth() * actualMask.getHeight()))
|
||||
.isGreaterThan((double) similarityThreshold);
|
||||
}
|
||||
|
||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||
}
|
||||
|
||||
private static ByteBuffer getByteBufferFromBitmap(Bitmap bitmap) {
|
||||
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(bitmap.getWidth() * bitmap.getHeight() * 4);
|
||||
for (int y = 0; y < bitmap.getHeight(); y++) {
|
||||
for (int x = 0; x < bitmap.getWidth(); x++) {
|
||||
byteBuffer.putFloat((float) Color.red(bitmap.getPixel(x, y)) / 255.f);
|
||||
}
|
||||
}
|
||||
byteBuffer.rewind();
|
||||
return byteBuffer;
|
||||
}
|
||||
|
||||
private static double calculateSum(FloatBuffer m) {
|
||||
m.rewind();
|
||||
double sum = 0;
|
||||
while (m.hasRemaining()) {
|
||||
sum += m.get();
|
||||
}
|
||||
m.rewind();
|
||||
return sum;
|
||||
}
|
||||
|
||||
private static FloatBuffer multiply(FloatBuffer m1, FloatBuffer m2, int bufferSize) {
|
||||
m1.rewind();
|
||||
m2.rewind();
|
||||
FloatBuffer buffer = FloatBuffer.allocate(bufferSize);
|
||||
while (m1.hasRemaining()) {
|
||||
buffer.put(m1.get() * m2.get());
|
||||
}
|
||||
m1.rewind();
|
||||
m2.rewind();
|
||||
buffer.rewind();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
private static double calculateSoftIOU(FloatBuffer m1, FloatBuffer m2, int bufferSize) {
|
||||
double intersectionSum = calculateSum(multiply(m1, m2, bufferSize));
|
||||
double m1m1 = calculateSum(multiply(m1, m1.duplicate(), bufferSize));
|
||||
double m2m2 = calculateSum(multiply(m2, m2.duplicate(), bufferSize));
|
||||
double unionSum = m1m1 + m2m2 - intersectionSum;
|
||||
return unionSum > 0.0 ? intersectionSum / unionSum : 0.0;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user