From 5e6842aa5cf9a57ec234d029ea872658f7aa9ad7 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 8 Nov 2022 09:44:15 -0800 Subject: [PATCH] Migrate Java ImageClassifier / TextClassifier to new result format. PiperOrigin-RevId: 486976459 --- .../tasks/components/containers/BUILD | 20 +++- .../tasks/components/containers/Category.java | 11 ++ .../containers/ClassificationEntry.java | 48 -------- .../containers/ClassificationResult.java | 76 +++++++++++++ .../containers/Classifications.java | 38 +++++-- .../com/google/mediapipe/tasks/text/BUILD | 6 +- .../TextClassificationResult.java | 103 ------------------ .../text/textclassifier/TextClassifier.java | 27 ++--- .../textclassifier/TextClassifierResult.java | 55 ++++++++++ .../com/google/mediapipe/tasks/vision/BUILD | 6 +- .../ImageClassificationResult.java | 102 ----------------- .../imageclassifier/ImageClassifier.java | 38 ++++--- .../ImageClassifierResult.java | 55 ++++++++++ .../textclassifier/TextClassifierTest.java | 27 +++-- .../imageclassifier/ImageClassifierTest.java | 73 ++++++------- 15 files changed, 325 insertions(+), 360 deletions(-) delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 9dfa53031..63697229f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -20,6 +20,7 @@ android_library( name = "category", srcs = ["Category.java"], deps = [ + "//mediapipe/framework/formats:classification_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], @@ -36,20 +37,29 @@ android_library( ) android_library( - name = "classification_entry", - srcs = ["ClassificationEntry.java"], + name = "classifications", + srcs = ["Classifications.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ ":category", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], ) android_library( - name = "classifications", - srcs = ["Classifications.java"], + name = "classificationresult", + srcs = ["ClassificationResult.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ - ":classification_entry", + ":classifications", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java index 3b7c41fbe..e955605e4 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java @@ -15,6 +15,7 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.ClassificationProto; import java.util.Objects; /** @@ -38,6 +39,16 @@ public abstract class Category { return new AutoValue_Category(score, index, categoryName, displayName); } + /** + * Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf + * message. + * + * @param proto the {@link ClassificationProto.Classification} protobuf message to convert. + */ + public static Category createFromProto(ClassificationProto.Classification proto) { + return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName()); + } + /** The probability score of this label category. */ public abstract float score(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java deleted file mode 100644 index 8fc1daa03..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.containers; - -import com.google.auto.value.AutoValue; -import java.util.Collections; -import java.util.List; - -/** - * Represents a list of predicted categories with an optional timestamp. Typically used as result - * for classification tasks. - */ -@AutoValue -public abstract class ClassificationEntry { - /** - * Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional - * timestamp. - * - * @param categories the list of {@link Category} objects that contain category name, display - * name, score and label index. - * @param timestampMs the {@link long} representing the timestamp for which these categories were - * obtained. - */ - public static ClassificationEntry create(List categories, long timestampMs) { - return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs); - } - - /** The list of predicted {@link Category} objects, sorted by descending score. */ - public abstract List categories(); - - /** - * The timestamp (in milliseconds) associated to the classification entry. This is useful for time - * series use cases, e.g. audio classification. - */ - public abstract long timestampMs(); -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java new file mode 100644 index 000000000..d30099d8b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java @@ -0,0 +1,76 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Represents the classification results of a model. Typically used as a result for classification + * tasks. + */ +@AutoValue +public abstract class ClassificationResult { + + /** + * Creates a {@link ClassificationResult} instance. + * + * @param classifications the list of {@link Classifications} objects containing the predicted + * categories for each head of the model. + * @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + */ + public static ClassificationResult create( + List classifications, Optional timestampMs) { + return new AutoValue_ClassificationResult( + Collections.unmodifiableList(classifications), timestampMs); + } + + /** + * Creates a {@link ClassificationResult} object from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert. + */ + public static ClassificationResult createFromProto( + ClassificationsProto.ClassificationResult proto) { + List classifications = new ArrayList<>(); + for (ClassificationsProto.Classifications classificationsProto : + proto.getClassificationsList()) { + classifications.add(Classifications.createFromProto(classificationsProto)); + } + Optional timestampMs = + proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty(); + return create(classifications, timestampMs); + } + + /** The classification results for each head of the model. */ + public abstract List classifications(); + + /** + * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. + * + *

This is only used for classification on time series (e.g. audio classification). In these + * use cases, the amount of data to process might exceed the maximum size that the model can + * process: to solve this, the input data is split into multiple chunks starting at different + * timestamps. + */ + public abstract Optional timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java index 726578729..12f14e628 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java @@ -15,8 +15,12 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.ClassificationProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; /** * Represents the list of classification for a given classifier head. Typically used as a result for @@ -28,25 +32,41 @@ public abstract class Classifications { /** * Creates a {@link Classifications} instance. * - * @param entries the list of {@link ClassificationEntry} objects containing the predicted - * categories. + * @param categories the list of {@link Category} objects containing the predicted categories. * @param headIndex the index of the classifier head. - * @param headName the name of the classifier head. + * @param headName the optional name of the classifier head. */ public static Classifications create( - List entries, int headIndex, String headName) { + List categories, int headIndex, Optional headName) { return new AutoValue_Classifications( - Collections.unmodifiableList(entries), headIndex, headName); + Collections.unmodifiableList(categories), headIndex, headName); } - /** A list of {@link ClassificationEntry} objects. */ - public abstract List entries(); + /** + * Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications} + * protobuf message. + * + * @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert. + */ + public static Classifications createFromProto(ClassificationsProto.Classifications proto) { + List categories = new ArrayList<>(); + for (ClassificationProto.Classification classificationProto : + proto.getClassificationList().getClassificationList()) { + categories.add(Category.createFromProto(classificationProto)); + } + Optional headName = + proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty(); + return create(categories, proto.getHeadIndex(), headName); + } + + /** A list of {@link Category} objects. */ + public abstract List categories(); /** * The index of the classifier head these entries refer to. This is useful for multi-head models. */ public abstract int headIndex(); - /** The name of the classifier head, which is the corresponding tensor metadata name. */ - public abstract String headName(); + /** The optional name of the classifier head, which is the corresponding tensor metadata name. */ + public abstract Optional headName(); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index fa2a547c2..b49169529 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -37,8 +37,8 @@ cc_library( android_library( name = "textclassifier", srcs = [ - "textclassifier/TextClassificationResult.java", "textclassifier/TextClassifier.java", + "textclassifier/TextClassifierResult.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -51,9 +51,7 @@ android_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java deleted file mode 100644 index c1e2446cd..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.text.textclassifier; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.ClassificationEntry; -import com.google.mediapipe.tasks.components.containers.Classifications; -import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; -import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.core.TaskResult; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** Represents the classification results generated by {@link TextClassifier}. */ -@AutoValue -public abstract class TextClassificationResult implements TaskResult { - - /** - * Creates an {@link TextClassificationResult} instance from a {@link - * ClassificationsProto.ClassificationResult} protobuf message. - * - * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf - * message. - * @param timestampMs a timestamp for this result. - */ - // TODO: consolidate output formats across platforms. - static TextClassificationResult create( - ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { - List classifications = new ArrayList<>(); - for (ClassificationsProto.Classifications classificationsProto : - classificationResult.getClassificationsList()) { - classifications.add(classificationsFromProto(classificationsProto)); - } - return new AutoValue_TextClassificationResult( - timestampMs, Collections.unmodifiableList(classifications)); - } - - @Override - public abstract long timestampMs(); - - /** Contains one set of results per classifier head. */ - @SuppressWarnings("AutoValueImmutableFields") - public abstract List classifications(); - - /** - * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. - * - * @param category the {@link CategoryProto.Category} protobuf message to convert. - */ - static Category categoryFromProto(CategoryProto.Category category) { - return Category.create( - category.getScore(), - category.getIndex(), - category.getCategoryName(), - category.getDisplayName()); - } - - /** - * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link - * ClassificationEntry} object. - * - * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. - */ - static ClassificationEntry classificationEntryFromProto( - ClassificationsProto.ClassificationEntry entry) { - List categories = new ArrayList<>(); - for (CategoryProto.Category category : entry.getCategoriesList()) { - categories.add(categoryFromProto(category)); - } - return ClassificationEntry.create(categories, entry.getTimestampMs()); - } - - /** - * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link - * Classifications} object. - * - * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to - * convert. - */ - static Classifications classificationsFromProto( - ClassificationsProto.Classifications classifications) { - List entries = new ArrayList<>(); - for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { - entries.add(classificationEntryFromProto(entry)); - } - return Classifications.create( - entries, classifications.getHeadIndex(), classifications.getHeadName()); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 07a4fa48f..341d6bf91 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable { @SuppressWarnings("ConstantCaseForConstants") private static final List OUTPUT_STREAMS = - Collections.unmodifiableList( - Arrays.asList("CLASSIFICATION_RESULT:classification_result_out")); + Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out")); - private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; private final TaskRunner runner; @@ -142,17 +142,18 @@ public final class TextClassifier implements AutoCloseable { * @throws MediaPipeException if there is an error during {@link TextClassifier} creation. */ public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) { - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override - public TextClassificationResult convertToTaskResult(List packets) { + public TextClassifierResult convertToTaskResult(List packets) { try { - return TextClassificationResult.create( - PacketGetter.getProto( - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), - ClassificationsProto.ClassificationResult.getDefaultInstance()), - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + return TextClassifierResult.create( + ClassificationResult.createFromProto( + PacketGetter.getProto( + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance())), + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()); } catch (IOException e) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); @@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable { * * @param inputText a {@link String} for processing. */ - public TextClassificationResult classify(String inputText) { + public TextClassifierResult classify(String inputText) { Map inputPackets = new HashMap<>(); inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); - return (TextClassificationResult) runner.process(inputPackets); + return (TextClassifierResult) runner.process(inputPackets); } /** Closes and cleans up the {@link TextClassifier}. */ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java new file mode 100644 index 000000000..64de0ee8d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java @@ -0,0 +1,55 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the classification results generated by {@link TextClassifier}. */ +@AutoValue +public abstract class TextClassifierResult implements TaskResult { + + /** + * Creates an {@link TextClassifierResult} instance. + * + * @param classificationResult the {@link ClassificationResult} object containing one set of + * results per classifier head. + * @param timestampMs a timestamp for this result. + */ + static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) { + return new AutoValue_TextClassifierResult(classificationResult, timestampMs); + } + + /** + * Creates an {@link TextClassifierResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static TextClassifierResult createFromProto( + ClassificationsProto.ClassificationResult proto, long timestampMs) { + return create(ClassificationResult.createFromProto(proto), timestampMs); + } + + /** Contains one set of results per classifier head. */ + public abstract ClassificationResult classificationResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index d15040ae7..146097bbd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -84,8 +84,8 @@ android_library( android_library( name = "imageclassifier", srcs = [ - "imageclassifier/ImageClassificationResult.java", "imageclassifier/ImageClassifier.java", + "imageclassifier/ImageClassifierResult.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -100,9 +100,7 @@ android_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java deleted file mode 100644 index d82a47b86..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.vision.imageclassifier; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.ClassificationEntry; -import com.google.mediapipe.tasks.components.containers.Classifications; -import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; -import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.core.TaskResult; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** Represents the classification results generated by {@link ImageClassifier}. */ -@AutoValue -public abstract class ImageClassificationResult implements TaskResult { - - /** - * Creates an {@link ImageClassificationResult} instance from a {@link - * ClassificationsProto.ClassificationResult} protobuf message. - * - * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf - * message. - * @param timestampMs a timestamp for this result. - */ - // TODO: consolidate output formats across platforms. - static ImageClassificationResult create( - ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { - List classifications = new ArrayList<>(); - for (ClassificationsProto.Classifications classificationsProto : - classificationResult.getClassificationsList()) { - classifications.add(classificationsFromProto(classificationsProto)); - } - return new AutoValue_ImageClassificationResult( - timestampMs, Collections.unmodifiableList(classifications)); - } - - @Override - public abstract long timestampMs(); - - /** Contains one set of results per classifier head. */ - public abstract List classifications(); - - /** - * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. - * - * @param category the {@link CategoryProto.Category} protobuf message to convert. - */ - static Category categoryFromProto(CategoryProto.Category category) { - return Category.create( - category.getScore(), - category.getIndex(), - category.getCategoryName(), - category.getDisplayName()); - } - - /** - * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link - * ClassificationEntry} object. - * - * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. - */ - static ClassificationEntry classificationEntryFromProto( - ClassificationsProto.ClassificationEntry entry) { - List categories = new ArrayList<>(); - for (CategoryProto.Category category : entry.getCategoriesList()) { - categories.add(categoryFromProto(category)); - } - return ClassificationEntry.create(categories, entry.getTimestampMs()); - } - - /** - * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link - * Classifications} object. - * - * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to - * convert. - */ - static Classifications classificationsFromProto( - ClassificationsProto.Classifications classifications) { - List entries = new ArrayList<>(); - for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { - entries.add(classificationEntryFromProto(entry)); - } - return Classifications.create( - entries, classifications.getHeadIndex(), classifications.getHeadName()); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 3863b6fe0..f01546ffc 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi { Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); private static final List OUTPUT_STREAMS = Collections.unmodifiableList( - Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out")); - private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out")); + private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0; private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; @@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. */ public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override - public ImageClassificationResult convertToTaskResult(List packets) { + public ImageClassifierResult convertToTaskResult(List packets) { try { - return ImageClassificationResult.create( - PacketGetter.getProto( - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), - ClassificationsProto.ClassificationResult.getDefaultInstance()), - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + return ImageClassifierResult.create( + ClassificationResult.createFromProto( + PacketGetter.getProto( + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance())), + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()); } catch (IOException e) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); @@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(MPImage image) { + public ImageClassifierResult classify(MPImage image) { return classify(image, ImageProcessingOptions.builder().build()); } @@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { * input image before running inference. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify( + public ImageClassifierResult classify( MPImage image, ImageProcessingOptions imageProcessingOptions) { - return (ImageClassificationResult) processImageData(image, imageProcessingOptions); + return (ImageClassifierResult) processImageData(image, imageProcessingOptions); } /** @@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { + public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) { return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } @@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo( + public ImageClassifierResult classifyForVideo( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { - return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); + return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs); } /** @@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * the image classifier is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener resultListener); + ResultListener resultListener); /** Sets an optional {@link ErrorListener}. */ public abstract Builder setErrorListener(ErrorListener errorListener); @@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract Optional classifierOptions(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java new file mode 100644 index 000000000..924542158 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java @@ -0,0 +1,55 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the classification results generated by {@link ImageClassifier}. */ +@AutoValue +public abstract class ImageClassifierResult implements TaskResult { + + /** + * Creates an {@link ImageClassifierResult} instance. + * + * @param classificationResult the {@link ClassificationResult} object containing one set of + * results per classifier head. + * @param timestampMs a timestamp for this result. + */ + static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) { + return new AutoValue_ImageClassifierResult(classificationResult, timestampMs); + } + + /** + * Creates an {@link ImageClassifierResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static ImageClassifierResult createFromProto( + ClassificationsProto.ClassificationResult proto, long timestampMs) { + return create(ClassificationResult.createFromProto(proto), timestampMs); + } + + /** Contains one set of results per classifier head. */ + public abstract ClassificationResult classificationResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index bfca79ced..d3f0e90f3 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -76,7 +76,7 @@ public class TextClassifierTest { public void classify_succeedsWithBert() throws Exception { TextClassifier textClassifier = TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); - TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); assertHasOneHead(negativeResults); assertCategoriesAre( negativeResults, @@ -84,7 +84,7 @@ public class TextClassifierTest { Category.create(0.95630914f, 0, "negative", ""), Category.create(0.04369091f, 1, "positive", ""))); - TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT); assertHasOneHead(positiveResults); assertCategoriesAre( positiveResults, @@ -99,7 +99,7 @@ public class TextClassifierTest { TextClassifier.createFromFile( ApplicationProvider.getApplicationContext(), TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE)); - TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); assertHasOneHead(negativeResults); assertCategoriesAre( negativeResults, @@ -107,7 +107,7 @@ public class TextClassifierTest { Category.create(0.95630914f, 0, "negative", ""), Category.create(0.04369091f, 1, "positive", ""))); - TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT); assertHasOneHead(positiveResults); assertHasOneHead(positiveResults); assertCategoriesAre( @@ -122,7 +122,7 @@ public class TextClassifierTest { TextClassifier textClassifier = TextClassifier.createFromFile( ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); - TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); assertHasOneHead(negativeResults); assertCategoriesAre( negativeResults, @@ -130,7 +130,7 @@ public class TextClassifierTest { Category.create(0.6647746f, 0, "Negative", ""), Category.create(0.33522537f, 1, "Positive", ""))); - TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT); assertHasOneHead(positiveResults); assertCategoriesAre( positiveResults, @@ -139,16 +139,15 @@ public class TextClassifierTest { Category.create(0.48799595f, 1, "Positive", ""))); } - private static void assertHasOneHead(TextClassificationResult results) { - assertThat(results.classifications()).hasSize(1); - assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); - assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); - assertThat(results.classifications().get(0).entries()).hasSize(1); + private static void assertHasOneHead(TextClassifierResult results) { + assertThat(results.classificationResult().classifications()).hasSize(1); + assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classificationResult().classifications().get(0).headName().get()) + .isEqualTo("probability"); } - private static void assertCategoriesAre( - TextClassificationResult results, List categories) { - assertThat(results.classifications().get(0).entries().get(0).categories()) + private static void assertCategoriesAre(TextClassifierResult results, List categories) { + assertThat(results.classificationResult().classifications().get(0).categories()) .isEqualTo(categories); } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 99ebd9777..69820ce2d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -91,11 +91,12 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromFile( ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); - assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001); - assertThat(results.classifications().get(0).entries().get(0).categories().get(0)) + assertHasOneHead(results); + assertThat(results.classificationResult().classifications().get(0).categories()) + .hasSize(1001); + assertThat(results.classificationResult().classifications().get(0).categories().get(0)) .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", "")); } @@ -108,9 +109,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -128,9 +129,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", ""))); } @@ -144,9 +145,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -166,9 +167,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -190,9 +191,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -214,10 +215,10 @@ public class ImageClassifierTest { RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); - ImageClassificationResult results = + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); } @@ -233,10 +234,10 @@ public class ImageClassifierTest { ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRotationDegrees(-90).build(); - ImageClassificationResult results = + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -258,11 +259,11 @@ public class ImageClassifierTest { RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); - ImageClassificationResult results = + ImageClassifierResult results = imageClassifier.classify( getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); } @@ -391,9 +392,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); } @@ -410,9 +411,8 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassificationResult results = - imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); - assertHasOneHeadAndOneTimestamp(results, i); + ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); } @@ -478,24 +478,17 @@ public class ImageClassifierTest { return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); } - private static void assertHasOneHeadAndOneTimestamp( - ImageClassificationResult results, long timestampMs) { - assertThat(results.classifications()).hasSize(1); - assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); - assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); - assertThat(results.classifications().get(0).entries()).hasSize(1); - assertThat(results.classifications().get(0).entries().get(0).timestampMs()) - .isEqualTo(timestampMs); + private static void assertHasOneHead(ImageClassifierResult results) { + assertThat(results.classificationResult().classifications()).hasSize(1); + assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classificationResult().classifications().get(0).headName().get()) + .isEqualTo("probability"); } private static void assertCategoriesAre( - ImageClassificationResult results, List categories) { - assertThat(results.classifications().get(0).entries().get(0).categories()) - .hasSize(categories.size()); - for (int i = 0; i < categories.size(); i++) { - assertThat(results.classifications().get(0).entries().get(0).categories().get(i)) - .isEqualTo(categories.get(i)); - } + ImageClassifierResult results, List categories) { + assertThat(results.classificationResult().classifications().get(0).categories()) + .isEqualTo(categories); } private static void assertImageSizeIsExpected(MPImage inputImage) {