Add Java ImageClassifier API.
PiperOrigin-RevId: 480656683
This commit is contained in:
parent
cbbd4718a0
commit
51a7606083
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.vision.imageclassifier">
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "imageclassifier",
|
||||||
|
srcs = [
|
||||||
|
"ImageClassificationResult.java",
|
||||||
|
"ImageClassifier.java",
|
||||||
|
],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
manifest = ":AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,102 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
|
||||||
|
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/** Represents the classification results generated by {@link ImageClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ImageClassificationResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassificationResult} instance from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
|
||||||
|
* message.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
// TODO: consolidate output formats across platforms.
|
||||||
|
static ImageClassificationResult create(
|
||||||
|
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
|
||||||
|
List<Classifications> classifications = new ArrayList<>();
|
||||||
|
for (ClassificationsProto.Classifications classificationsProto :
|
||||||
|
classificationResult.getClassificationsList()) {
|
||||||
|
classifications.add(classificationsFromProto(classificationsProto));
|
||||||
|
}
|
||||||
|
return new AutoValue_ImageClassificationResult(
|
||||||
|
timestampMs, Collections.unmodifiableList(classifications));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
|
||||||
|
/** Contains one set of results per classifier head. */
|
||||||
|
public abstract List<Classifications> classifications();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
|
||||||
|
*
|
||||||
|
* @param category the {@link CategoryProto.Category} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
static Category categoryFromProto(CategoryProto.Category category) {
|
||||||
|
return Category.create(
|
||||||
|
category.getScore(),
|
||||||
|
category.getIndex(),
|
||||||
|
category.getCategoryName(),
|
||||||
|
category.getDisplayName());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
|
||||||
|
* ClassificationEntry} object.
|
||||||
|
*
|
||||||
|
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
static ClassificationEntry classificationEntryFromProto(
|
||||||
|
ClassificationsProto.ClassificationEntry entry) {
|
||||||
|
List<Category> categories = new ArrayList<>();
|
||||||
|
for (CategoryProto.Category category : entry.getCategoriesList()) {
|
||||||
|
categories.add(categoryFromProto(category));
|
||||||
|
}
|
||||||
|
return ClassificationEntry.create(categories, entry.getTimestampMs());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
|
||||||
|
* Classifications} object.
|
||||||
|
*
|
||||||
|
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
|
||||||
|
* convert.
|
||||||
|
*/
|
||||||
|
static Classifications classificationsFromProto(
|
||||||
|
ClassificationsProto.Classifications classifications) {
|
||||||
|
List<ClassificationEntry> entries = new ArrayList<>();
|
||||||
|
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
|
||||||
|
entries.add(classificationEntryFromProto(entry));
|
||||||
|
}
|
||||||
|
return Classifications.create(
|
||||||
|
entries, classifications.getHeadIndex(), classifications.getHeadName());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,456 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.graphics.RectF;
|
||||||
|
import android.os.ParcelFileDescriptor;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
|
import com.google.mediapipe.framework.image.Image;
|
||||||
|
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
|
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
|
||||||
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs classification on images.
|
||||||
|
*
|
||||||
|
* <p>The API expects a TFLite model with optional, but strongly recommended, <a
|
||||||
|
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
|
||||||
|
*
|
||||||
|
* <p>The API supports models with one image input tensor and one or more output tensors. To be more
|
||||||
|
* specific, here are the requirements.
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
|
||||||
|
* <ul>
|
||||||
|
* <li>image input of size {@code [batch x height x width x channels]}.
|
||||||
|
* <li>batch inference is not supported ({@code batch} is required to be 1).
|
||||||
|
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
|
||||||
|
* <li>if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the
|
||||||
|
* metadata for input normalization.
|
||||||
|
* </ul>
|
||||||
|
* <li>At least one output tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) with:
|
||||||
|
* <ul>
|
||||||
|
* <li>{@code N} classes and either 2 or 4 dimensions, i.e. {@code [1 x N]} or {@code [1 x 1
|
||||||
|
* x 1 x N]}
|
||||||
|
* <li>optional (but recommended) label map(s) as AssociatedFile-s with type
|
||||||
|
* TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
|
||||||
|
* any) is used to fill the {@code class_name} field of the results. The {@code
|
||||||
|
* display_name} field is filled from the AssociatedFile (if any) whose locale matches
|
||||||
|
* the {@code display_names_locale} field of the {@code ImageClassifierOptions} used at
|
||||||
|
* creation time ("en" by default, i.e. English). If none of these are available, only
|
||||||
|
* the {@code index} field of the results will be filled.
|
||||||
|
* <li>optional score calibration can be attached using ScoreCalibrationOptions and an
|
||||||
|
* AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION. See <a
|
||||||
|
* href="https://github.com/google/mediapipe/blob/master/mediapipe/tasks/metadata/metadata_schema.fbs">
|
||||||
|
* metadata_schema.fbs</a> for more details.
|
||||||
|
* </ul>
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* <p>An example of such model can be found <a
|
||||||
|
* href="https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1">
|
||||||
|
* TensorFlow Hub</a>.
|
||||||
|
*/
|
||||||
|
public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
|
private static final String TAG = ImageClassifier.class.getSimpleName();
|
||||||
|
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||||
|
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||||
|
private static final List<String> INPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out"));
|
||||||
|
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
|
||||||
|
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||||
|
private static final String TASK_GRAPH_NAME =
|
||||||
|
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||||
|
|
||||||
|
static {
|
||||||
|
ProtoUtil.registerTypeName(
|
||||||
|
ClassificationsProto.ClassificationResult.class,
|
||||||
|
"mediapipe.tasks.components.containers.proto.ClassificationResult");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassifier} instance from a model file and default {@link
|
||||||
|
* ImageClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelPath path to the classification model in the assets.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static ImageClassifier createFromFile(Context context, String modelPath) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassifier} instance from a model file and default {@link
|
||||||
|
* ImageClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelFile the classification model {@link File} instance.
|
||||||
|
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static ImageClassifier createFromFile(Context context, File modelFile) throws IOException {
|
||||||
|
try (ParcelFileDescriptor descriptor =
|
||||||
|
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||||
|
BaseOptions baseOptions =
|
||||||
|
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassifier} instance from a model buffer and default {@link
|
||||||
|
* ImageClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
|
||||||
|
* classification model.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static ImageClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, ImageClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassifier} instance from an {@link ImageClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param options an {@link ImageClassifierOptions} instance.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
|
||||||
|
OutputHandler<ImageClassificationResult, Image> handler = new OutputHandler<>();
|
||||||
|
handler.setOutputPacketConverter(
|
||||||
|
new OutputHandler.OutputPacketConverter<ImageClassificationResult, Image>() {
|
||||||
|
@Override
|
||||||
|
public ImageClassificationResult convertToTaskResult(List<Packet> packets) {
|
||||||
|
try {
|
||||||
|
return ImageClassificationResult.create(
|
||||||
|
PacketGetter.getProto(
|
||||||
|
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
|
||||||
|
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||||
|
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
|
||||||
|
} catch (InvalidProtocolBufferException e) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image convertToTaskInput(List<Packet> packets) {
|
||||||
|
return new BitmapImageBuilder(
|
||||||
|
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
options.resultListener().ifPresent(handler::setResultListener);
|
||||||
|
options.errorListener().ifPresent(handler::setErrorListener);
|
||||||
|
TaskRunner runner =
|
||||||
|
TaskRunner.create(
|
||||||
|
context,
|
||||||
|
TaskInfo.<ImageClassifierOptions>builder()
|
||||||
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
|
.setInputStreams(INPUT_STREAMS)
|
||||||
|
.setOutputStreams(OUTPUT_STREAMS)
|
||||||
|
.setTaskOptions(options)
|
||||||
|
.setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM)
|
||||||
|
.build(),
|
||||||
|
handler);
|
||||||
|
return new ImageClassifier(runner, options.runningMode());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize an {@link ImageClassifier} from a {@link TaskRunner} and {@link
|
||||||
|
* RunningMode}.
|
||||||
|
*
|
||||||
|
* @param taskRunner a {@link TaskRunner}.
|
||||||
|
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||||
|
*/
|
||||||
|
private ImageClassifier(TaskRunner taskRunner, RunningMode runningMode) {
|
||||||
|
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs classification on the provided single image. Only use this method when the {@link
|
||||||
|
* ImageClassifier} is created with {@link RunningMode.IMAGE}.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ImageClassificationResult classify(Image inputImage) {
|
||||||
|
return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs classification on the provided single image and region-of-interest. Only use this
|
||||||
|
* method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||||
|
* @param roi a {@link RectF} specifying the region of interest on which to perform
|
||||||
|
* classification. Coordinates are expected to be specified as normalized values in [0,1].
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ImageClassificationResult classify(Image inputImage, RectF roi) {
|
||||||
|
return (ImageClassificationResult) processImageData(inputImage, roi);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs classification on the provided video frame. Only use this method when the {@link
|
||||||
|
* ImageClassifier} is created with {@link RunningMode.VIDEO}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
|
* must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||||
|
* @param inputTimestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) {
|
||||||
|
return (ImageClassificationResult)
|
||||||
|
processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs classification on the provided video frame with additional region-of-interest. Only
|
||||||
|
* use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
|
* must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||||
|
* @param roi a {@link RectF} specifying the region of interest on which to perform
|
||||||
|
* classification. Coordinates are expected to be specified as normalized values in [0,1].
|
||||||
|
* @param inputTimestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ImageClassificationResult classifyForVideo(
|
||||||
|
Image inputImage, RectF roi, long inputTimestampMs) {
|
||||||
|
return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends live image data to perform classification, and the results will be available via the
|
||||||
|
* {@link ResultListener} provided in the {@link ImageClassifierOptions}. Only use this method
|
||||||
|
* when the {@link ImageClassifier} is created with {@link RunningMode.LIVE_STREAM}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||||
|
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||||
|
* @param inputTimestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public void classifyAsync(Image inputImage, long inputTimestampMs) {
|
||||||
|
sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends live image data and additional region-of-interest to perform classification, and the
|
||||||
|
* results will be available via the {@link ResultListener} provided in the {@link
|
||||||
|
* ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
|
||||||
|
* {@link RunningMode.LIVE_STREAM}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||||
|
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param inputImage a MediaPipe {@link Image} object for processing.
|
||||||
|
* @param roi a {@link RectF} specifying the region of interest on which to perform
|
||||||
|
* classification. Coordinates are expected to be specified as normalized values in [0,1].
|
||||||
|
* @param inputTimestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) {
|
||||||
|
sendLiveStreamData(inputImage, roi, inputTimestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Options for setting up and {@link ImageClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract static class ImageClassifierOptions extends TaskOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link ImageClassifierOptions}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/** Sets the {@link BaseOptions} for the image classifier task. */
|
||||||
|
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link RunningMode} for the image classifier task. Default to the image mode.
|
||||||
|
* Image classifier has three modes:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>IMAGE: The mode for performing classification on single image inputs.
|
||||||
|
* <li>VIDEO: The mode for performing classification on the decoded frames of a video.
|
||||||
|
* <li>LIVE_STREAM: The mode for for performing classification on a live stream of input
|
||||||
|
* data, such as from camera. In this mode, {@code setResultListener} must be called to
|
||||||
|
* set up a listener to receive the classification results asynchronously.
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||||
|
* score threshold, number of results, etc.
|
||||||
|
*/
|
||||||
|
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||||
|
* the image classifier is in the live stream mode.
|
||||||
|
*/
|
||||||
|
public abstract Builder setResultListener(
|
||||||
|
ResultListener<ImageClassificationResult, Image> resultListener);
|
||||||
|
|
||||||
|
/** Sets an optional {@link ErrorListener}. */
|
||||||
|
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||||
|
|
||||||
|
abstract ImageClassifierOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link ImageClassifierOptions} instance. *
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||||
|
* properly configured. The result listener should only be set when the image classifier
|
||||||
|
* is in the live stream mode.
|
||||||
|
*/
|
||||||
|
public final ImageClassifierOptions build() {
|
||||||
|
ImageClassifierOptions options = autoBuild();
|
||||||
|
if (options.runningMode() == RunningMode.LIVE_STREAM) {
|
||||||
|
if (!options.resultListener().isPresent()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"The image classifier is in the live stream mode, a user-defined result listener"
|
||||||
|
+ " must be provided in the ImageClassifierOptions.");
|
||||||
|
}
|
||||||
|
} else if (options.resultListener().isPresent()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"The image classifier is in the image or video mode, a user-defined result listener"
|
||||||
|
+ " shouldn't be provided in ImageClassifierOptions.");
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
|
abstract RunningMode runningMode();
|
||||||
|
|
||||||
|
abstract Optional<ClassifierOptions> classifierOptions();
|
||||||
|
|
||||||
|
abstract Optional<ResultListener<ImageClassificationResult, Image>> resultListener();
|
||||||
|
|
||||||
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder()
|
||||||
|
.setRunningMode(RunningMode.IMAGE);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link ImageClassifierOptions} to a {@link CalculatorOptions} protobuf message.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||||
|
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||||
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
|
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
|
||||||
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||||
|
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder()
|
||||||
|
.setBaseOptions(baseOptionsBuilder);
|
||||||
|
if (classifierOptions().isPresent()) {
|
||||||
|
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||||
|
}
|
||||||
|
return CalculatorOptions.newBuilder()
|
||||||
|
.setExtension(
|
||||||
|
ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext,
|
||||||
|
taskOptionsBuilder.build())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Creates a RectF covering the full image. */
|
||||||
|
private static RectF buildFullImageRectF() {
|
||||||
|
return new RectF(0, 0, 1, 1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.vision.imageclassifiertest"
|
||||||
|
android:versionCode="1"
|
||||||
|
android:versionName="1.0" >
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||||
|
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:label="imageclassifiertest"
|
||||||
|
android:name="android.support.multidex.MultiDexApplication"
|
||||||
|
android:taskAffinity="">
|
||||||
|
<uses-library android:name="android.test.runner" />
|
||||||
|
</application>
|
||||||
|
|
||||||
|
<instrumentation
|
||||||
|
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||||
|
android:targetPackage="com.google.mediapipe.tasks.vision.imageclassifiertest" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TODO: Enable this in OSS
|
|
@ -0,0 +1,445 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
|
import android.content.res.AssetManager;
|
||||||
|
import android.graphics.BitmapFactory;
|
||||||
|
import android.graphics.RectF;
|
||||||
|
import androidx.test.core.app.ApplicationProvider;
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
|
import com.google.mediapipe.framework.image.Image;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TestUtils;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
|
import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Suite;
|
||||||
|
import org.junit.runners.Suite.SuiteClasses;
|
||||||
|
|
||||||
|
/** Test for {@link ImageClassifier}/ */
|
||||||
|
@RunWith(Suite.class)
|
||||||
|
@SuiteClasses({ImageClassifierTest.General.class, ImageClassifierTest.RunningModeTest.class})
|
||||||
|
public class ImageClassifierTest {
|
||||||
|
private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
|
||||||
|
private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
|
||||||
|
private static final String BURGER_IMAGE = "burger.jpg";
|
||||||
|
private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public static final class General extends ImageClassifierTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithMissingModel() throws Exception {
|
||||||
|
String nonExistentFile = "/path/to/non/existent/file";
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
ImageClassifier.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(), nonExistentFile));
|
||||||
|
assertThat(exception).hasMessageThat().contains(nonExistentFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithInvalidModelBuffer() throws Exception {
|
||||||
|
// Create a non-direct model ByteBuffer.
|
||||||
|
ByteBuffer modelBuffer =
|
||||||
|
TestUtils.loadToNonDirectByteBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
|
||||||
|
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageClassifier.createFromBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(), modelBuffer));
|
||||||
|
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithNoOptions() throws Exception {
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001);
|
||||||
|
assertThat(results.classifications().get(0).entries().get(0).categories().get(0))
|
||||||
|
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithFloatModel() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results,
|
||||||
|
Arrays.asList(
|
||||||
|
Category.create(0.7952058f, 934, "cheeseburger", ""),
|
||||||
|
Category.create(0.027329788f, 932, "bagel", ""),
|
||||||
|
Category.create(0.019334773f, 925, "guacamole", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithQuantizedModel() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithScoreThreshold() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results,
|
||||||
|
Arrays.asList(
|
||||||
|
Category.create(0.7952058f, 934, "cheeseburger", ""),
|
||||||
|
Category.create(0.027329788f, 932, "bagel", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithAllowlist() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(
|
||||||
|
ClassifierOptions.builder()
|
||||||
|
.setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf"))
|
||||||
|
.build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results,
|
||||||
|
Arrays.asList(
|
||||||
|
Category.create(0.7952058f, 934, "cheeseburger", ""),
|
||||||
|
Category.create(0.019334773f, 925, "guacamole", ""),
|
||||||
|
Category.create(0.006279315f, 963, "meat loaf", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithDenylist() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(
|
||||||
|
ClassifierOptions.builder()
|
||||||
|
.setMaxResults(3)
|
||||||
|
.setCategoryDenylist(Arrays.asList("bagel"))
|
||||||
|
.build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results,
|
||||||
|
Arrays.asList(
|
||||||
|
Category.create(0.7952058f, 934, "cheeseburger", ""),
|
||||||
|
Category.create(0.019334773f, 925, "guacamole", ""),
|
||||||
|
Category.create(0.006279315f, 963, "meat loaf", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithRegionOfInterest() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
// RectF around the soccer ball.
|
||||||
|
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
|
||||||
|
ImageClassificationResult results =
|
||||||
|
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi);
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public static final class RunningModeTest extends ImageClassifierTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
|
||||||
|
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setRunningMode(mode)
|
||||||
|
.setResultListener((imageClassificationResult, inputImage) -> {})
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("a user-defined result listener shouldn't be provided");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("a user-defined result listener must be provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
|
exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||||
|
exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.setResultListener((imageClassificationResult, inputImage) -> {})
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||||
|
exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithImageMode() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithVideoMode() throws Exception {
|
||||||
|
Image image = getImageFromAsset(BURGER_IMAGE);
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
ImageClassificationResult results = imageClassifier.classifyForVideo(image, i);
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, i);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||||
|
Image image = getImageFromAsset(BURGER_IMAGE);
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.setResultListener(
|
||||||
|
(imageClassificationResult, inputImage) -> {
|
||||||
|
assertCategoriesAre(
|
||||||
|
imageClassificationResult,
|
||||||
|
Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
|
assertImageSizeIsExpected(inputImage);
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
try (ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
|
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("having a smaller timestamp than the processed timestamp");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithLiveStreamMode() throws Exception {
|
||||||
|
Image image = getImageFromAsset(BURGER_IMAGE);
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.setResultListener(
|
||||||
|
(imageClassificationResult, inputImage) -> {
|
||||||
|
assertCategoriesAre(
|
||||||
|
imageClassificationResult,
|
||||||
|
Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
|
assertImageSizeIsExpected(inputImage);
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
try (ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
imageClassifier.classifyAsync(image, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Image getImageFromAsset(String filePath) throws Exception {
|
||||||
|
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||||
|
InputStream istr = assetManager.open(filePath);
|
||||||
|
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void assertHasOneHeadAndOneTimestamp(
|
||||||
|
ImageClassificationResult results, long timestampMs) {
|
||||||
|
assertThat(results.classifications()).hasSize(1);
|
||||||
|
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
|
||||||
|
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
|
||||||
|
assertThat(results.classifications().get(0).entries()).hasSize(1);
|
||||||
|
assertThat(results.classifications().get(0).entries().get(0).timestampMs())
|
||||||
|
.isEqualTo(timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void assertCategoriesAre(
|
||||||
|
ImageClassificationResult results, List<Category> categories) {
|
||||||
|
assertThat(results.classifications().get(0).entries().get(0).categories())
|
||||||
|
.hasSize(categories.size());
|
||||||
|
for (int i = 0; i < categories.size(); i++) {
|
||||||
|
assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
|
||||||
|
.isEqualTo(categories.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void assertImageSizeIsExpected(Image inputImage) {
|
||||||
|
assertThat(inputImage).isNotNull();
|
||||||
|
assertThat(inputImage.getWidth()).isEqualTo(480);
|
||||||
|
assertThat(inputImage.getHeight()).isEqualTo(325);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user