diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 9dfa53031..63697229f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -20,6 +20,7 @@ android_library( name = "category", srcs = ["Category.java"], deps = [ + "//mediapipe/framework/formats:classification_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], @@ -36,20 +37,29 @@ android_library( ) android_library( - name = "classification_entry", - srcs = ["ClassificationEntry.java"], + name = "classifications", + srcs = ["Classifications.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ ":category", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], ) android_library( - name = "classifications", - srcs = ["Classifications.java"], + name = "classificationresult", + srcs = ["ClassificationResult.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ - ":classification_entry", + ":classifications", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//third_party:autovalue", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java index 3b7c41fbe..e955605e4 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java @@ -15,6 +15,7 @@ package com.google.mediapipe.tasks.components.containers; import 
com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.ClassificationProto; import java.util.Objects; /** @@ -38,6 +39,16 @@ public abstract class Category { return new AutoValue_Category(score, index, categoryName, displayName); } + /** + * Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf + * message. + * + * @param proto the {@link ClassificationProto.Classification} protobuf message to convert. + */ + public static Category createFromProto(ClassificationProto.Classification proto) { + return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName()); + } + /** The probability score of this label category. */ public abstract float score(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java deleted file mode 100644 index 8fc1daa03..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.containers; - -import com.google.auto.value.AutoValue; -import java.util.Collections; -import java.util.List; - -/** - * Represents a list of predicted categories with an optional timestamp. 
Typically used as result - * for classification tasks. - */ -@AutoValue -public abstract class ClassificationEntry { - /** - * Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional - * timestamp. - * - * @param categories the list of {@link Category} objects that contain category name, display - * name, score and label index. - * @param timestampMs the {@link long} representing the timestamp for which these categories were - * obtained. - */ - public static ClassificationEntry create(List categories, long timestampMs) { - return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs); - } - - /** The list of predicted {@link Category} objects, sorted by descending score. */ - public abstract List categories(); - - /** - * The timestamp (in milliseconds) associated to the classification entry. This is useful for time - * series use cases, e.g. audio classification. - */ - public abstract long timestampMs(); -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java new file mode 100644 index 000000000..d30099d8b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationResult.java @@ -0,0 +1,76 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Represents the classification results of a model. Typically used as a result for classification + * tasks. + */ +@AutoValue +public abstract class ClassificationResult { + + /** + * Creates a {@link ClassificationResult} instance. + * + * @param classifications the list of {@link Classifications} objects containing the predicted + * categories for each head of the model. + * @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + */ + public static ClassificationResult create( + List classifications, Optional timestampMs) { + return new AutoValue_ClassificationResult( + Collections.unmodifiableList(classifications), timestampMs); + } + + /** + * Creates a {@link ClassificationResult} object from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert. + */ + public static ClassificationResult createFromProto( + ClassificationsProto.ClassificationResult proto) { + List classifications = new ArrayList<>(); + for (ClassificationsProto.Classifications classificationsProto : + proto.getClassificationsList()) { + classifications.add(Classifications.createFromProto(classificationsProto)); + } + Optional timestampMs = + proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty(); + return create(classifications, timestampMs); + } + + /** The classification results for each head of the model. 
*/ + public abstract List classifications(); + + /** + * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to + * these results. + * + *
<p>
This is only used for classification on time series (e.g. audio classification). In these + * use cases, the amount of data to process might exceed the maximum size that the model can + * process: to solve this, the input data is split into multiple chunks starting at different + * timestamps. + */ + public abstract Optional timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java index 726578729..12f14e628 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java @@ -15,8 +15,12 @@ package com.google.mediapipe.tasks.components.containers; import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.ClassificationProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Optional; /** * Represents the list of classification for a given classifier head. Typically used as a result for @@ -28,25 +32,41 @@ public abstract class Classifications { /** * Creates a {@link Classifications} instance. * - * @param entries the list of {@link ClassificationEntry} objects containing the predicted - * categories. + * @param categories the list of {@link Category} objects containing the predicted categories. * @param headIndex the index of the classifier head. - * @param headName the name of the classifier head. + * @param headName the optional name of the classifier head. 
*/ public static Classifications create( - List entries, int headIndex, String headName) { + List categories, int headIndex, Optional headName) { return new AutoValue_Classifications( - Collections.unmodifiableList(entries), headIndex, headName); + Collections.unmodifiableList(categories), headIndex, headName); } - /** A list of {@link ClassificationEntry} objects. */ - public abstract List entries(); + /** + * Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications} + * protobuf message. + * + * @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert. + */ + public static Classifications createFromProto(ClassificationsProto.Classifications proto) { + List categories = new ArrayList<>(); + for (ClassificationProto.Classification classificationProto : + proto.getClassificationList().getClassificationList()) { + categories.add(Category.createFromProto(classificationProto)); + } + Optional headName = + proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty(); + return create(categories, proto.getHeadIndex(), headName); + } + + /** A list of {@link Category} objects. */ + public abstract List categories(); /** * The index of the classifier head these entries refer to. This is useful for multi-head models. */ public abstract int headIndex(); - /** The name of the classifier head, which is the corresponding tensor metadata name. */ - public abstract String headName(); + /** The optional name of the classifier head, which is the corresponding tensor metadata name. 
*/ + public abstract Optional headName(); } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 16308e71f..62c424f66 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -234,7 +234,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index fa2a547c2..b49169529 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -37,8 +37,8 @@ cc_library( android_library( name = "textclassifier", srcs = [ - "textclassifier/TextClassificationResult.java", "textclassifier/TextClassifier.java", + "textclassifier/TextClassifierResult.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -51,9 +51,7 @@ android_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", 
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java deleted file mode 100644 index c1e2446cd..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package com.google.mediapipe.tasks.text.textclassifier; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.ClassificationEntry; -import com.google.mediapipe.tasks.components.containers.Classifications; -import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; -import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.core.TaskResult; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** Represents the classification results generated by {@link TextClassifier}. */ -@AutoValue -public abstract class TextClassificationResult implements TaskResult { - - /** - * Creates an {@link TextClassificationResult} instance from a {@link - * ClassificationsProto.ClassificationResult} protobuf message. - * - * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf - * message. - * @param timestampMs a timestamp for this result. - */ - // TODO: consolidate output formats across platforms. - static TextClassificationResult create( - ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { - List classifications = new ArrayList<>(); - for (ClassificationsProto.Classifications classificationsProto : - classificationResult.getClassificationsList()) { - classifications.add(classificationsFromProto(classificationsProto)); - } - return new AutoValue_TextClassificationResult( - timestampMs, Collections.unmodifiableList(classifications)); - } - - @Override - public abstract long timestampMs(); - - /** Contains one set of results per classifier head. */ - @SuppressWarnings("AutoValueImmutableFields") - public abstract List classifications(); - - /** - * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. 
- * - * @param category the {@link CategoryProto.Category} protobuf message to convert. - */ - static Category categoryFromProto(CategoryProto.Category category) { - return Category.create( - category.getScore(), - category.getIndex(), - category.getCategoryName(), - category.getDisplayName()); - } - - /** - * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link - * ClassificationEntry} object. - * - * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. - */ - static ClassificationEntry classificationEntryFromProto( - ClassificationsProto.ClassificationEntry entry) { - List categories = new ArrayList<>(); - for (CategoryProto.Category category : entry.getCategoriesList()) { - categories.add(categoryFromProto(category)); - } - return ClassificationEntry.create(categories, entry.getTimestampMs()); - } - - /** - * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link - * Classifications} object. - * - * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to - * convert. 
- */ - static Classifications classificationsFromProto( - ClassificationsProto.Classifications classifications) { - List entries = new ArrayList<>(); - for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { - entries.add(classificationEntryFromProto(entry)); - } - return Classifications.create( - entries, classifications.getHeadIndex(), classifications.getHeadName()); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 07a4fa48f..341d6bf91 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable { @SuppressWarnings("ConstantCaseForConstants") private static final List OUTPUT_STREAMS = - Collections.unmodifiableList( - Arrays.asList("CLASSIFICATION_RESULT:classification_result_out")); + Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out")); - private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; private final TaskRunner runner; @@ -142,17 +142,18 @@ public 
final class TextClassifier implements AutoCloseable { * @throws MediaPipeException if there is an error during {@link TextClassifier} creation. */ public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) { - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override - public TextClassificationResult convertToTaskResult(List packets) { + public TextClassifierResult convertToTaskResult(List packets) { try { - return TextClassificationResult.create( - PacketGetter.getProto( - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), - ClassificationsProto.ClassificationResult.getDefaultInstance()), - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + return TextClassifierResult.create( + ClassificationResult.createFromProto( + PacketGetter.getProto( + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance())), + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()); } catch (IOException e) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); @@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable { * * @param inputText a {@link String} for processing. */ - public TextClassificationResult classify(String inputText) { + public TextClassifierResult classify(String inputText) { Map inputPackets = new HashMap<>(); inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); - return (TextClassificationResult) runner.process(inputPackets); + return (TextClassifierResult) runner.process(inputPackets); } /** Closes and cleans up the {@link TextClassifier}. 
*/ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java new file mode 100644 index 000000000..64de0ee8d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifierResult.java @@ -0,0 +1,55 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the classification results generated by {@link TextClassifier}. */ +@AutoValue +public abstract class TextClassifierResult implements TaskResult { + + /** + * Creates an {@link TextClassifierResult} instance. + * + * @param classificationResult the {@link ClassificationResult} object containing one set of + * results per classifier head. + * @param timestampMs a timestamp for this result. 
+ */ + static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) { + return new AutoValue_TextClassifierResult(classificationResult, timestampMs); + } + + /** + * Creates an {@link TextClassifierResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static TextClassifierResult createFromProto( + ClassificationsProto.ClassificationResult proto, long timestampMs) { + return create(ClassificationResult.createFromProto(proto), timestampMs); + } + + /** Contains one set of results per classifier head. */ + public abstract ClassificationResult classificationResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index d15040ae7..146097bbd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -84,8 +84,8 @@ android_library( android_library( name = "imageclassifier", srcs = [ - "imageclassifier/ImageClassificationResult.java", "imageclassifier/ImageClassifier.java", + "imageclassifier/ImageClassifierResult.java", ], javacopts = [ "-Xep:AndroidJdkLibsChecker:OFF", @@ -100,9 +100,7 @@ android_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", - 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java deleted file mode 100644 index d82a47b86..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package com.google.mediapipe.tasks.vision.imageclassifier; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.ClassificationEntry; -import com.google.mediapipe.tasks.components.containers.Classifications; -import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; -import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.core.TaskResult; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** Represents the classification results generated by {@link ImageClassifier}. */ -@AutoValue -public abstract class ImageClassificationResult implements TaskResult { - - /** - * Creates an {@link ImageClassificationResult} instance from a {@link - * ClassificationsProto.ClassificationResult} protobuf message. - * - * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf - * message. - * @param timestampMs a timestamp for this result. - */ - // TODO: consolidate output formats across platforms. - static ImageClassificationResult create( - ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { - List classifications = new ArrayList<>(); - for (ClassificationsProto.Classifications classificationsProto : - classificationResult.getClassificationsList()) { - classifications.add(classificationsFromProto(classificationsProto)); - } - return new AutoValue_ImageClassificationResult( - timestampMs, Collections.unmodifiableList(classifications)); - } - - @Override - public abstract long timestampMs(); - - /** Contains one set of results per classifier head. */ - public abstract List classifications(); - - /** - * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. - * - * @param category the {@link CategoryProto.Category} protobuf message to convert. 
- */ - static Category categoryFromProto(CategoryProto.Category category) { - return Category.create( - category.getScore(), - category.getIndex(), - category.getCategoryName(), - category.getDisplayName()); - } - - /** - * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link - * ClassificationEntry} object. - * - * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. - */ - static ClassificationEntry classificationEntryFromProto( - ClassificationsProto.ClassificationEntry entry) { - List categories = new ArrayList<>(); - for (CategoryProto.Category category : entry.getCategoriesList()) { - categories.add(categoryFromProto(category)); - } - return ClassificationEntry.create(categories, entry.getTimestampMs()); - } - - /** - * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link - * Classifications} object. - * - * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to - * convert. 
- */ - static Classifications classificationsFromProto( - ClassificationsProto.Classifications classifications) { - List entries = new ArrayList<>(); - for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { - entries.add(classificationEntryFromProto(entry)); - } - return Classifications.create( - entries, classifications.getHeadIndex(), classifications.getHeadName()); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 3863b6fe0..f01546ffc 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi { Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); private static final List OUTPUT_STREAMS = Collections.unmodifiableList( - Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out")); - private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out")); + private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0; private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final String 
TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; @@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. */ public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override - public ImageClassificationResult convertToTaskResult(List packets) { + public ImageClassifierResult convertToTaskResult(List packets) { try { - return ImageClassificationResult.create( - PacketGetter.getProto( - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), - ClassificationsProto.ClassificationResult.getDefaultInstance()), - packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + return ImageClassifierResult.create( + ClassificationResult.createFromProto( + PacketGetter.getProto( + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance())), + packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()); } catch (IOException e) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); @@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @param image a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(MPImage image) { + public ImageClassifierResult classify(MPImage image) { return classify(image, ImageProcessingOptions.builder().build()); } @@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { * input image before running inference. * @throws MediaPipeException if there is an internal error. 
*/ - public ImageClassificationResult classify( + public ImageClassifierResult classify( MPImage image, ImageProcessingOptions imageProcessingOptions) { - return (ImageClassificationResult) processImageData(image, imageProcessingOptions); + return (ImageClassifierResult) processImageData(image, imageProcessingOptions); } /** @@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { + public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) { return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } @@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo( + public ImageClassifierResult classifyForVideo( MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { - return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); + return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs); } /** @@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * the image classifier is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener resultListener); + ResultListener resultListener); /** Sets an optional {@link ErrorListener}. 
*/ public abstract Builder setErrorListener(ErrorListener errorListener); @@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract Optional classifierOptions(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java new file mode 100644 index 000000000..924542158 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierResult.java @@ -0,0 +1,55 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.ClassificationResult; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the classification results generated by {@link ImageClassifier}. */ +@AutoValue +public abstract class ImageClassifierResult implements TaskResult { + + /** + * Creates an {@link ImageClassifierResult} instance. 
+ * + * @param classificationResult the {@link ClassificationResult} object containing one set of + * results per classifier head. + * @param timestampMs a timestamp for this result. + */ + static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) { + return new AutoValue_ImageClassifierResult(classificationResult, timestampMs); + } + + /** + * Creates an {@link ImageClassifierResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static ImageClassifierResult createFromProto( + ClassificationsProto.ClassificationResult proto, long timestampMs) { + return create(ClassificationResult.createFromProto(proto), timestampMs); + } + + /** Contains one set of results per classifier head. */ + public abstract ClassificationResult classificationResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index bfca79ced..d3f0e90f3 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -76,7 +76,7 @@ public class TextClassifierTest { public void classify_succeedsWithBert() throws Exception { TextClassifier textClassifier = TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); - TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); assertHasOneHead(negativeResults); 
assertCategoriesAre( negativeResults, @@ -84,7 +84,7 @@ public class TextClassifierTest { Category.create(0.95630914f, 0, "negative", ""), Category.create(0.04369091f, 1, "positive", ""))); - TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT); assertHasOneHead(positiveResults); assertCategoriesAre( positiveResults, @@ -99,7 +99,7 @@ public class TextClassifierTest { TextClassifier.createFromFile( ApplicationProvider.getApplicationContext(), TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE)); - TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); assertHasOneHead(negativeResults); assertCategoriesAre( negativeResults, @@ -107,7 +107,7 @@ public class TextClassifierTest { Category.create(0.95630914f, 0, "negative", ""), Category.create(0.04369091f, 1, "positive", ""))); - TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT); assertHasOneHead(positiveResults); assertHasOneHead(positiveResults); assertCategoriesAre( @@ -122,7 +122,7 @@ public class TextClassifierTest { TextClassifier textClassifier = TextClassifier.createFromFile( ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); - TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); assertHasOneHead(negativeResults); assertCategoriesAre( negativeResults, @@ -130,7 +130,7 @@ public class TextClassifierTest { Category.create(0.6647746f, 0, "Negative", ""), Category.create(0.33522537f, 1, "Positive", ""))); - TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + TextClassifierResult positiveResults = 
textClassifier.classify(POSITIVE_TEXT); assertHasOneHead(positiveResults); assertCategoriesAre( positiveResults, @@ -139,16 +139,15 @@ public class TextClassifierTest { Category.create(0.48799595f, 1, "Positive", ""))); } - private static void assertHasOneHead(TextClassificationResult results) { - assertThat(results.classifications()).hasSize(1); - assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); - assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); - assertThat(results.classifications().get(0).entries()).hasSize(1); + private static void assertHasOneHead(TextClassifierResult results) { + assertThat(results.classificationResult().classifications()).hasSize(1); + assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classificationResult().classifications().get(0).headName().get()) + .isEqualTo("probability"); } - private static void assertCategoriesAre( - TextClassificationResult results, List categories) { - assertThat(results.classifications().get(0).entries().get(0).categories()) + private static void assertCategoriesAre(TextClassifierResult results, List categories) { + assertThat(results.classificationResult().classifications().get(0).categories()) .isEqualTo(categories); } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 99ebd9777..69820ce2d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -91,11 +91,12 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromFile( ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); - ImageClassificationResult results = 
imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); - assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001); - assertThat(results.classifications().get(0).entries().get(0).categories().get(0)) + assertHasOneHead(results); + assertThat(results.classificationResult().classifications().get(0).categories()) + .hasSize(1001); + assertThat(results.classificationResult().classifications().get(0).categories().get(0)) .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", "")); } @@ -108,9 +109,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -128,9 +129,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", ""))); } @@ -144,9 +145,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = 
imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -166,9 +167,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -190,9 +191,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -214,10 +215,10 @@ public class ImageClassifierTest { RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); - ImageClassificationResult results = + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); } @@ -233,10 +234,10 @@ public class ImageClassifierTest { ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); 
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRotationDegrees(-90).build(); - ImageClassificationResult results = + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList( @@ -258,11 +259,11 @@ public class ImageClassifierTest { RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); - ImageClassificationResult results = + ImageClassifierResult results = imageClassifier.classify( getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); } @@ -391,9 +392,9 @@ public class ImageClassifierTest { .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); - ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); + ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); - assertHasOneHeadAndOneTimestamp(results, 0); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); } @@ -410,9 +411,8 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassificationResult results = - imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); - assertHasOneHeadAndOneTimestamp(results, i); + ImageClassifierResult results = imageClassifier.classifyForVideo(image, 
/*timestampMs=*/ i); + assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); } @@ -478,24 +478,17 @@ public class ImageClassifierTest { return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); } - private static void assertHasOneHeadAndOneTimestamp( - ImageClassificationResult results, long timestampMs) { - assertThat(results.classifications()).hasSize(1); - assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); - assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); - assertThat(results.classifications().get(0).entries()).hasSize(1); - assertThat(results.classifications().get(0).entries().get(0).timestampMs()) - .isEqualTo(timestampMs); + private static void assertHasOneHead(ImageClassifierResult results) { + assertThat(results.classificationResult().classifications()).hasSize(1); + assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classificationResult().classifications().get(0).headName().get()) + .isEqualTo("probability"); } private static void assertCategoriesAre( - ImageClassificationResult results, List categories) { - assertThat(results.classifications().get(0).entries().get(0).categories()) - .hasSize(categories.size()); - for (int i = 0; i < categories.size(); i++) { - assertThat(results.classifications().get(0).entries().get(0).categories().get(i)) - .isEqualTo(categories.get(i)); - } + ImageClassifierResult results, List categories) { + assertThat(results.classificationResult().classifications().get(0).categories()) + .isEqualTo(categories); } private static void assertImageSizeIsExpected(MPImage inputImage) { diff --git a/mediapipe/tasks/python/test/audio/audio_classifier_test.py b/mediapipe/tasks/python/test/audio/audio_classifier_test.py index 983e922e7..0d067e587 100644 --- a/mediapipe/tasks/python/test/audio/audio_classifier_test.py +++ 
b/mediapipe/tasks/python/test/audio/audio_classifier_test.py @@ -174,12 +174,10 @@ class AudioClassifierTest(parameterized.TestCase): self.assertIsInstance(classifier, _AudioClassifier) def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model path. with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): - base_options = _BaseOptions(model_asset_path='') + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') options = _AudioClassifierOptions(base_options=base_options) _AudioClassifier.create_from_options(options) diff --git a/mediapipe/tasks/python/test/text/text_classifier_test.py b/mediapipe/tasks/python/test/text/text_classifier_test.py index c93def48e..434181fbe 100644 --- a/mediapipe/tasks/python/test/text/text_classifier_test.py +++ b/mediapipe/tasks/python/test/text/text_classifier_test.py @@ -154,12 +154,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertIsInstance(classifier, _TextClassifier) def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model path. 
with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): - base_options = _BaseOptions(model_asset_path='') + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') options = _TextClassifierOptions(base_options=base_options) _TextClassifier.create_from_options(options) diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py index 11941ce23..97fb30b32 100644 --- a/mediapipe/tasks/python/test/vision/image_classifier_test.py +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -147,12 +147,10 @@ class ImageClassifierTest(parameterized.TestCase): self.assertIsInstance(classifier, _ImageClassifier) def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model path. 
with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): - base_options = _BaseOptions(model_asset_path='') + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') options = _ImageClassifierOptions(base_options=base_options) _ImageClassifier.create_from_options(options) diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 5072d3482..7f0b47eb7 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -97,12 +97,10 @@ class ImageSegmenterTest(parameterized.TestCase): self.assertIsInstance(segmenter, _ImageSegmenter) def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model path. with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): - base_options = _BaseOptions(model_asset_path='') + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') options = _ImageSegmenterOptions(base_options=base_options) _ImageSegmenter.create_from_options(options) diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index 53c64427f..5afa31459 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -119,12 +119,10 @@ class ObjectDetectorTest(parameterized.TestCase): self.assertIsInstance(detector, _ObjectDetector) def test_create_from_options_fails_with_invalid_model_path(self): - # Invalid empty model 
path. with self.assertRaisesRegex( - ValueError, - r"ExternalFile must specify at least one of 'file_content', " - r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): - base_options = _BaseOptions(model_asset_path='') + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite') options = _ObjectDetectorOptions(base_options=base_options) _ObjectDetector.create_from_options(options) diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 14999a03e..081e63c2c 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -28,6 +28,7 @@ mediapipe_files(srcs = [ "bert_text_classifier.tflite", "mobilebert_embedding_with_metadata.tflite", "mobilebert_with_metadata.tflite", + "regex_one_embedding_with_metadata.tflite", "test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite", "universal_sentence_encoder_qa_with_metadata.tflite", @@ -92,6 +93,11 @@ filegroup( srcs = ["mobilebert_embedding_with_metadata.tflite"], ) +filegroup( + name = "regex_embedding_with_metadata", + srcs = ["regex_one_embedding_with_metadata.tflite"], +) + filegroup( name = "universal_sentence_encoder_qa", data = ["universal_sentence_encoder_qa_with_metadata.tflite"], diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index fd79487a4..39353b226 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner { */ async setOptions(options: AudioClassifierOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + 
options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index e6b9adf20..cd7190dd9 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -26,6 +26,8 @@ mediapipe_ts_library( name = "base_options", srcs = ["base_options.ts"], deps = [ + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index 2f7d0db37..a7f7bd280 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -14,6 +14,8 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; import {BaseOptions} from '../../../../tasks/web/core/base_options'; @@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; * Converts a BaseOptions API object to its Protobuf representation. 
* @throws If neither a model assset path or buffer is provided */ -export async function convertBaseOptionsToProto(baseOptions: BaseOptions): - Promise { - if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); +export async function convertBaseOptionsToProto( + updatedOptions: BaseOptions, + currentOptions?: BaseOptionsProto): Promise { + const result = + currentOptions ? currentOptions.clone() : new BaseOptionsProto(); + + await configureExternalFile(updatedOptions, result); + configureAcceleration(updatedOptions, result); + + return result; +} + +/** + * Configues the `externalFile` option and validates that a single model is + * provided. + */ +async function configureExternalFile( + options: BaseOptions, proto: BaseOptionsProto) { + const externalFile = proto.getModelAsset() || new ExternalFile(); + proto.setModelAsset(externalFile); + + if (options.modelAssetPath || options.modelAssetBuffer) { + if (options.modelAssetPath && options.modelAssetBuffer) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } + + let modelAssetBuffer = options.modelAssetBuffer; + if (!modelAssetBuffer) { + const response = await fetch(options.modelAssetPath!.toString()); + modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); + } + externalFile.setFileContent(modelAssetBuffer); } - if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) { + + if (!externalFile.hasFileContent()) { throw new Error( 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); } - - let modelAssetBuffer = baseOptions.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(baseOptions.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - - const proto = new BaseOptionsProto(); - const externalFile = new ExternalFile(); - 
externalFile.setFileContent(modelAssetBuffer); - proto.setModelAsset(externalFile); - return proto; +} + +/** Configues the `acceleration` option. */ +function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { + if ('delegate' in options) { + const acceleration = new Acceleration(); + if (options.delegate === 'cpu') { + acceleration.setXnnpack( + new InferenceCalculatorOptions.Delegate.Xnnpack()); + proto.setAcceleration(acceleration); + } else if (options.delegate === 'gpu') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + proto.setAcceleration(acceleration); + } else { + proto.clearAcceleration(); + } + } } diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/base_options.d.ts index 02a288a87..54a59a21b 100644 --- a/mediapipe/tasks/web/core/base_options.d.ts +++ b/mediapipe/tasks/web/core/base_options.d.ts @@ -22,10 +22,14 @@ export interface BaseOptions { * The model path to the model asset file. Only one of `modelAssetPath` or * `modelAssetBuffer` can be set. */ - modelAssetPath?: string; + modelAssetPath?: string|undefined; + /** * A buffer containing the model aaset. Only one of `modelAssetPath` or * `modelAssetBuffer` can be set. */ - modelAssetBuffer?: Uint8Array; + modelAssetBuffer?: Uint8Array|undefined; + + /** Overrides the default backend to use for the provided model. 
*/ + delegate?: 'cpu'|'gpu'|undefined; } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index ff36bb9e0..d92248b80 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner { */ async setOptions(options: TextClassifierOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index ad8db1477..1275ae875 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner { */ async setOptions(options: GestureRecognizerOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 39674e85c..cb63874c4 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner { */ async setOptions(options: ImageClassifierOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - 
await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index c3bb21baa..022bf6301 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner { */ async setOptions(options: ObjectDetectorOptions): Promise { if (options.baseOptions) { - const baseOptionsProto = - await convertBaseOptionsToProto(options.baseOptions); + const baseOptionsProto = await convertBaseOptionsToProto( + options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index de1d0a976..d5ddc8d78 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -402,8 +402,8 @@ def external_files(): http_file( name = "com_google_mediapipe_labels_txt", - sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a", - urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667855388142641"], + sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9", + urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667888034706429"], ) http_file( @@ -552,8 +552,14 @@ def external_files(): http_file( name = "com_google_mediapipe_movie_review_json", - sha256 = "89ad347ad1cb7c587da144de6efbadec1d3e8ff0cd13e379dd16661a8186fbb5", - urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667855392734031"], + sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e", + urls = 
["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667888039053188"], + ) + + http_file( + name = "com_google_mediapipe_movie_review_labels_txt", + sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a", + urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667888041670721"], ) http_file( @@ -688,6 +694,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"], ) + http_file( + name = "com_google_mediapipe_regex_one_embedding_with_metadata_tflite", + sha256 = "b8f5d6d090c2c73984b2b92cd2fda27e5630562741a93d127b9a744d60505bc0", + urls = ["https://storage.googleapis.com/mediapipe-assets/regex_one_embedding_with_metadata.tflite?generation=1667888045310541"], + ) + + http_file( + name = "com_google_mediapipe_regex_vocab_txt", + sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923", + urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667888047885461"], + ) + http_file( name = "com_google_mediapipe_right_hands_jpg", sha256 = "240c082e80128ff1ca8a83ce645e2ba4d8bc30f0967b7991cf5fa375bab489e1",