Migrate Java ImageClassifier / TextClassifier to new result format.
PiperOrigin-RevId: 486976459
This commit is contained in:
parent
26066787b3
commit
5e6842aa5c
|
@ -20,6 +20,7 @@ android_library(
|
||||||
name = "category",
|
name = "category",
|
||||||
srcs = ["Category.java"],
|
srcs = ["Category.java"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
|
@ -36,20 +37,29 @@ android_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
android_library(
|
android_library(
|
||||||
name = "classification_entry",
|
name = "classifications",
|
||||||
srcs = ["ClassificationEntry.java"],
|
srcs = ["Classifications.java"],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":category",
|
":category",
|
||||||
|
"//mediapipe/framework/formats:classification_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
android_library(
|
android_library(
|
||||||
name = "classifications",
|
name = "classificationresult",
|
||||||
srcs = ["Classifications.java"],
|
srcs = ["ClassificationResult.java"],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":classification_entry",
|
":classifications",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package com.google.mediapipe.tasks.components.containers;
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -38,6 +39,16 @@ public abstract class Category {
|
||||||
return new AutoValue_Category(score, index, categoryName, displayName);
|
return new AutoValue_Category(score, index, categoryName, displayName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf
|
||||||
|
* message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationProto.Classification} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
public static Category createFromProto(ClassificationProto.Classification proto) {
|
||||||
|
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
|
||||||
|
}
|
||||||
|
|
||||||
/** The probability score of this label category. */
|
/** The probability score of this label category. */
|
||||||
public abstract float score();
|
public abstract float score();
|
||||||
|
|
||||||
|
|
|
@ -1,48 +0,0 @@
|
||||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package com.google.mediapipe.tasks.components.containers;
|
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a list of predicted categories with an optional timestamp. Typically used as result
|
|
||||||
* for classification tasks.
|
|
||||||
*/
|
|
||||||
@AutoValue
|
|
||||||
public abstract class ClassificationEntry {
|
|
||||||
/**
|
|
||||||
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
|
|
||||||
* timestamp.
|
|
||||||
*
|
|
||||||
* @param categories the list of {@link Category} objects that contain category name, display
|
|
||||||
* name, score and label index.
|
|
||||||
* @param timestampMs the {@link long} representing the timestamp for which these categories were
|
|
||||||
* obtained.
|
|
||||||
*/
|
|
||||||
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
|
|
||||||
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** The list of predicted {@link Category} objects, sorted by descending score. */
|
|
||||||
public abstract List<Category> categories();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
|
|
||||||
* series use cases, e.g. audio classification.
|
|
||||||
*/
|
|
||||||
public abstract long timestampMs();
|
|
||||||
}
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the classification results of a model. Typically used as a result for classification
|
||||||
|
* tasks.
|
||||||
|
*/
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ClassificationResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link ClassificationResult} instance.
|
||||||
|
*
|
||||||
|
* @param classifications the list of {@link Classifications} objects containing the predicted
|
||||||
|
* categories for each head of the model.
|
||||||
|
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
* corresponding to these results.
|
||||||
|
*/
|
||||||
|
public static ClassificationResult create(
|
||||||
|
List<Classifications> classifications, Optional<Long> timestampMs) {
|
||||||
|
return new AutoValue_ClassificationResult(
|
||||||
|
Collections.unmodifiableList(classifications), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link ClassificationResult} object from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
public static ClassificationResult createFromProto(
|
||||||
|
ClassificationsProto.ClassificationResult proto) {
|
||||||
|
List<Classifications> classifications = new ArrayList<>();
|
||||||
|
for (ClassificationsProto.Classifications classificationsProto :
|
||||||
|
proto.getClassificationsList()) {
|
||||||
|
classifications.add(Classifications.createFromProto(classificationsProto));
|
||||||
|
}
|
||||||
|
Optional<Long> timestampMs =
|
||||||
|
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
|
||||||
|
return create(classifications, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The classification results for each head of the model. */
|
||||||
|
public abstract List<Classifications> classifications();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
|
||||||
|
* these results.
|
||||||
|
*
|
||||||
|
* <p>This is only used for classification on time series (e.g. audio classification). In these
|
||||||
|
* use cases, the amount of data to process might exceed the maximum size that the model can
|
||||||
|
* process: to solve this, the input data is split into multiple chunks starting at different
|
||||||
|
* timestamps.
|
||||||
|
*/
|
||||||
|
public abstract Optional<Long> timestampMs();
|
||||||
|
}
|
|
@ -15,8 +15,12 @@
|
||||||
package com.google.mediapipe.tasks.components.containers;
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.formats.proto.ClassificationProto;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents the list of classification for a given classifier head. Typically used as a result for
|
* Represents the list of classification for a given classifier head. Typically used as a result for
|
||||||
|
@ -28,25 +32,41 @@ public abstract class Classifications {
|
||||||
/**
|
/**
|
||||||
* Creates a {@link Classifications} instance.
|
* Creates a {@link Classifications} instance.
|
||||||
*
|
*
|
||||||
* @param entries the list of {@link ClassificationEntry} objects containing the predicted
|
* @param categories the list of {@link Category} objects containing the predicted categories.
|
||||||
* categories.
|
|
||||||
* @param headIndex the index of the classifier head.
|
* @param headIndex the index of the classifier head.
|
||||||
* @param headName the name of the classifier head.
|
* @param headName the optional name of the classifier head.
|
||||||
*/
|
*/
|
||||||
public static Classifications create(
|
public static Classifications create(
|
||||||
List<ClassificationEntry> entries, int headIndex, String headName) {
|
List<Category> categories, int headIndex, Optional<String> headName) {
|
||||||
return new AutoValue_Classifications(
|
return new AutoValue_Classifications(
|
||||||
Collections.unmodifiableList(entries), headIndex, headName);
|
Collections.unmodifiableList(categories), headIndex, headName);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** A list of {@link ClassificationEntry} objects. */
|
/**
|
||||||
public abstract List<ClassificationEntry> entries();
|
* Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications}
|
||||||
|
* protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
|
||||||
|
*/
|
||||||
|
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
|
||||||
|
List<Category> categories = new ArrayList<>();
|
||||||
|
for (ClassificationProto.Classification classificationProto :
|
||||||
|
proto.getClassificationList().getClassificationList()) {
|
||||||
|
categories.add(Category.createFromProto(classificationProto));
|
||||||
|
}
|
||||||
|
Optional<String> headName =
|
||||||
|
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
|
||||||
|
return create(categories, proto.getHeadIndex(), headName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** A list of {@link Category} objects. */
|
||||||
|
public abstract List<Category> categories();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The index of the classifier head these entries refer to. This is useful for multi-head models.
|
* The index of the classifier head these entries refer to. This is useful for multi-head models.
|
||||||
*/
|
*/
|
||||||
public abstract int headIndex();
|
public abstract int headIndex();
|
||||||
|
|
||||||
/** The name of the classifier head, which is the corresponding tensor metadata name. */
|
/** The optional name of the classifier head, which is the corresponding tensor metadata name. */
|
||||||
public abstract String headName();
|
public abstract Optional<String> headName();
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,8 +37,8 @@ cc_library(
|
||||||
android_library(
|
android_library(
|
||||||
name = "textclassifier",
|
name = "textclassifier",
|
||||||
srcs = [
|
srcs = [
|
||||||
"textclassifier/TextClassificationResult.java",
|
|
||||||
"textclassifier/TextClassifier.java",
|
"textclassifier/TextClassifier.java",
|
||||||
|
"textclassifier/TextClassifierResult.java",
|
||||||
],
|
],
|
||||||
javacopts = [
|
javacopts = [
|
||||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
@ -51,9 +51,7 @@ android_library(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||||
|
|
|
@ -1,103 +0,0 @@
|
||||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package com.google.mediapipe.tasks.text.textclassifier;
|
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/** Represents the classification results generated by {@link TextClassifier}. */
|
|
||||||
@AutoValue
|
|
||||||
public abstract class TextClassificationResult implements TaskResult {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an {@link TextClassificationResult} instance from a {@link
|
|
||||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
|
||||||
*
|
|
||||||
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
|
|
||||||
* message.
|
|
||||||
* @param timestampMs a timestamp for this result.
|
|
||||||
*/
|
|
||||||
// TODO: consolidate output formats across platforms.
|
|
||||||
static TextClassificationResult create(
|
|
||||||
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
|
|
||||||
List<Classifications> classifications = new ArrayList<>();
|
|
||||||
for (ClassificationsProto.Classifications classificationsProto :
|
|
||||||
classificationResult.getClassificationsList()) {
|
|
||||||
classifications.add(classificationsFromProto(classificationsProto));
|
|
||||||
}
|
|
||||||
return new AutoValue_TextClassificationResult(
|
|
||||||
timestampMs, Collections.unmodifiableList(classifications));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public abstract long timestampMs();
|
|
||||||
|
|
||||||
/** Contains one set of results per classifier head. */
|
|
||||||
@SuppressWarnings("AutoValueImmutableFields")
|
|
||||||
public abstract List<Classifications> classifications();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
|
|
||||||
*
|
|
||||||
* @param category the {@link CategoryProto.Category} protobuf message to convert.
|
|
||||||
*/
|
|
||||||
static Category categoryFromProto(CategoryProto.Category category) {
|
|
||||||
return Category.create(
|
|
||||||
category.getScore(),
|
|
||||||
category.getIndex(),
|
|
||||||
category.getCategoryName(),
|
|
||||||
category.getDisplayName());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
|
|
||||||
* ClassificationEntry} object.
|
|
||||||
*
|
|
||||||
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
|
|
||||||
*/
|
|
||||||
static ClassificationEntry classificationEntryFromProto(
|
|
||||||
ClassificationsProto.ClassificationEntry entry) {
|
|
||||||
List<Category> categories = new ArrayList<>();
|
|
||||||
for (CategoryProto.Category category : entry.getCategoriesList()) {
|
|
||||||
categories.add(categoryFromProto(category));
|
|
||||||
}
|
|
||||||
return ClassificationEntry.create(categories, entry.getTimestampMs());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
|
|
||||||
* Classifications} object.
|
|
||||||
*
|
|
||||||
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
|
|
||||||
* convert.
|
|
||||||
*/
|
|
||||||
static Classifications classificationsFromProto(
|
|
||||||
ClassificationsProto.Classifications classifications) {
|
|
||||||
List<ClassificationEntry> entries = new ArrayList<>();
|
|
||||||
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
|
|
||||||
entries.add(classificationEntryFromProto(entry));
|
|
||||||
}
|
|
||||||
return Classifications.create(
|
|
||||||
entries, classifications.getHeadIndex(), classifications.getHeadName());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
|
||||||
import com.google.mediapipe.framework.Packet;
|
import com.google.mediapipe.framework.Packet;
|
||||||
import com.google.mediapipe.framework.PacketGetter;
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
import com.google.mediapipe.framework.ProtoUtil;
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable {
|
||||||
|
|
||||||
@SuppressWarnings("ConstantCaseForConstants")
|
@SuppressWarnings("ConstantCaseForConstants")
|
||||||
private static final List<String> OUTPUT_STREAMS =
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
Collections.unmodifiableList(
|
Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out"));
|
||||||
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
|
|
||||||
|
|
||||||
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
|
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
private final TaskRunner runner;
|
private final TaskRunner runner;
|
||||||
|
@ -142,17 +142,18 @@ public final class TextClassifier implements AutoCloseable {
|
||||||
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
|
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
|
||||||
*/
|
*/
|
||||||
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
|
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
|
||||||
OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>();
|
OutputHandler<TextClassifierResult, Void> handler = new OutputHandler<>();
|
||||||
handler.setOutputPacketConverter(
|
handler.setOutputPacketConverter(
|
||||||
new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() {
|
new OutputHandler.OutputPacketConverter<TextClassifierResult, Void>() {
|
||||||
@Override
|
@Override
|
||||||
public TextClassificationResult convertToTaskResult(List<Packet> packets) {
|
public TextClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||||
try {
|
try {
|
||||||
return TextClassificationResult.create(
|
return TextClassifierResult.create(
|
||||||
PacketGetter.getProto(
|
ClassificationResult.createFromProto(
|
||||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
|
PacketGetter.getProto(
|
||||||
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
|
ClassificationsProto.ClassificationResult.getDefaultInstance())),
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new MediaPipeException(
|
throw new MediaPipeException(
|
||||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||||
|
@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable {
|
||||||
*
|
*
|
||||||
* @param inputText a {@link String} for processing.
|
* @param inputText a {@link String} for processing.
|
||||||
*/
|
*/
|
||||||
public TextClassificationResult classify(String inputText) {
|
public TextClassifierResult classify(String inputText) {
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
|
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
|
||||||
return (TextClassificationResult) runner.process(inputPackets);
|
return (TextClassifierResult) runner.process(inputPackets);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Closes and cleans up the {@link TextClassifier}. */
|
/** Closes and cleans up the {@link TextClassifier}. */
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.textclassifier;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
|
||||||
|
/** Represents the classification results generated by {@link TextClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class TextClassifierResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link TextClassifierResult} instance.
|
||||||
|
*
|
||||||
|
* @param classificationResult the {@link ClassificationResult} object containing one set of
|
||||||
|
* results per classifier head.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
|
||||||
|
return new AutoValue_TextClassifierResult(classificationResult, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link TextClassifierResult} instance from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
// TODO: consolidate output formats across platforms.
|
||||||
|
static TextClassifierResult createFromProto(
|
||||||
|
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||||
|
return create(ClassificationResult.createFromProto(proto), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Contains one set of results per classifier head. */
|
||||||
|
public abstract ClassificationResult classificationResult();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -84,8 +84,8 @@ android_library(
|
||||||
android_library(
|
android_library(
|
||||||
name = "imageclassifier",
|
name = "imageclassifier",
|
||||||
srcs = [
|
srcs = [
|
||||||
"imageclassifier/ImageClassificationResult.java",
|
|
||||||
"imageclassifier/ImageClassifier.java",
|
"imageclassifier/ImageClassifier.java",
|
||||||
|
"imageclassifier/ImageClassifierResult.java",
|
||||||
],
|
],
|
||||||
javacopts = [
|
javacopts = [
|
||||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
@ -100,9 +100,7 @@ android_library(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
"//third_party:autovalue",
|
"//third_party:autovalue",
|
||||||
|
|
|
@ -1,102 +0,0 @@
|
||||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package com.google.mediapipe.tasks.vision.imageclassifier;
|
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/** Represents the classification results generated by {@link ImageClassifier}. */
|
|
||||||
@AutoValue
|
|
||||||
public abstract class ImageClassificationResult implements TaskResult {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an {@link ImageClassificationResult} instance from a {@link
|
|
||||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
|
||||||
*
|
|
||||||
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
|
|
||||||
* message.
|
|
||||||
* @param timestampMs a timestamp for this result.
|
|
||||||
*/
|
|
||||||
// TODO: consolidate output formats across platforms.
|
|
||||||
static ImageClassificationResult create(
|
|
||||||
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
|
|
||||||
List<Classifications> classifications = new ArrayList<>();
|
|
||||||
for (ClassificationsProto.Classifications classificationsProto :
|
|
||||||
classificationResult.getClassificationsList()) {
|
|
||||||
classifications.add(classificationsFromProto(classificationsProto));
|
|
||||||
}
|
|
||||||
return new AutoValue_ImageClassificationResult(
|
|
||||||
timestampMs, Collections.unmodifiableList(classifications));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public abstract long timestampMs();
|
|
||||||
|
|
||||||
/** Contains one set of results per classifier head. */
|
|
||||||
public abstract List<Classifications> classifications();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
|
|
||||||
*
|
|
||||||
* @param category the {@link CategoryProto.Category} protobuf message to convert.
|
|
||||||
*/
|
|
||||||
static Category categoryFromProto(CategoryProto.Category category) {
|
|
||||||
return Category.create(
|
|
||||||
category.getScore(),
|
|
||||||
category.getIndex(),
|
|
||||||
category.getCategoryName(),
|
|
||||||
category.getDisplayName());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
|
|
||||||
* ClassificationEntry} object.
|
|
||||||
*
|
|
||||||
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
|
|
||||||
*/
|
|
||||||
static ClassificationEntry classificationEntryFromProto(
|
|
||||||
ClassificationsProto.ClassificationEntry entry) {
|
|
||||||
List<Category> categories = new ArrayList<>();
|
|
||||||
for (CategoryProto.Category category : entry.getCategoriesList()) {
|
|
||||||
categories.add(categoryFromProto(category));
|
|
||||||
}
|
|
||||||
return ClassificationEntry.create(categories, entry.getTimestampMs());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
|
|
||||||
* Classifications} object.
|
|
||||||
*
|
|
||||||
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
|
|
||||||
* convert.
|
|
||||||
*/
|
|
||||||
static Classifications classificationsFromProto(
|
|
||||||
ClassificationsProto.Classifications classifications) {
|
|
||||||
List<ClassificationEntry> entries = new ArrayList<>();
|
|
||||||
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
|
|
||||||
entries.add(classificationEntryFromProto(entry));
|
|
||||||
}
|
|
||||||
return Classifications.create(
|
|
||||||
entries, classifications.getHeadIndex(), classifications.getHeadName());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
||||||
import com.google.mediapipe.framework.ProtoUtil;
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.MPImage;
|
import com.google.mediapipe.framework.image.MPImage;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
private static final List<String> OUTPUT_STREAMS =
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
Collections.unmodifiableList(
|
Collections.unmodifiableList(
|
||||||
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out"));
|
Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out"));
|
||||||
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
|
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||||
|
@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
|
||||||
*/
|
*/
|
||||||
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
|
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
|
||||||
OutputHandler<ImageClassificationResult, MPImage> handler = new OutputHandler<>();
|
OutputHandler<ImageClassifierResult, MPImage> handler = new OutputHandler<>();
|
||||||
handler.setOutputPacketConverter(
|
handler.setOutputPacketConverter(
|
||||||
new OutputHandler.OutputPacketConverter<ImageClassificationResult, MPImage>() {
|
new OutputHandler.OutputPacketConverter<ImageClassifierResult, MPImage>() {
|
||||||
@Override
|
@Override
|
||||||
public ImageClassificationResult convertToTaskResult(List<Packet> packets) {
|
public ImageClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||||
try {
|
try {
|
||||||
return ImageClassificationResult.create(
|
return ImageClassifierResult.create(
|
||||||
PacketGetter.getProto(
|
ClassificationResult.createFromProto(
|
||||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
|
PacketGetter.getProto(
|
||||||
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
|
ClassificationsProto.ClassificationResult.getDefaultInstance())),
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new MediaPipeException(
|
throw new MediaPipeException(
|
||||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||||
|
@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classify(MPImage image) {
|
public ImageClassifierResult classify(MPImage image) {
|
||||||
return classify(image, ImageProcessingOptions.builder().build());
|
return classify(image, ImageProcessingOptions.builder().build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* input image before running inference.
|
* input image before running inference.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classify(
|
public ImageClassifierResult classify(
|
||||||
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||||
return (ImageClassificationResult) processImageData(image, imageProcessingOptions);
|
return (ImageClassifierResult) processImageData(image, imageProcessingOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* @param timestampMs the input timestamp (in milliseconds).
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) {
|
public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) {
|
||||||
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* @param timestampMs the input timestamp (in milliseconds).
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classifyForVideo(
|
public ImageClassifierResult classifyForVideo(
|
||||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* the image classifier is in the live stream mode.
|
* the image classifier is in the live stream mode.
|
||||||
*/
|
*/
|
||||||
public abstract Builder setResultListener(
|
public abstract Builder setResultListener(
|
||||||
ResultListener<ImageClassificationResult, MPImage> resultListener);
|
ResultListener<ImageClassifierResult, MPImage> resultListener);
|
||||||
|
|
||||||
/** Sets an optional {@link ErrorListener}. */
|
/** Sets an optional {@link ErrorListener}. */
|
||||||
public abstract Builder setErrorListener(ErrorListener errorListener);
|
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||||
|
@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
|
|
||||||
abstract Optional<ClassifierOptions> classifierOptions();
|
abstract Optional<ClassifierOptions> classifierOptions();
|
||||||
|
|
||||||
abstract Optional<ResultListener<ImageClassificationResult, MPImage>> resultListener();
|
abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
|
||||||
|
|
||||||
abstract Optional<ErrorListener> errorListener();
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
|
||||||
|
/** Represents the classification results generated by {@link ImageClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ImageClassifierResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassifierResult} instance.
|
||||||
|
*
|
||||||
|
* @param classificationResult the {@link ClassificationResult} object containing one set of
|
||||||
|
* results per classifier head.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
|
||||||
|
return new AutoValue_ImageClassifierResult(classificationResult, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link ImageClassifierResult} instance from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
// TODO: consolidate output formats across platforms.
|
||||||
|
static ImageClassifierResult createFromProto(
|
||||||
|
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||||
|
return create(ClassificationResult.createFromProto(proto), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Contains one set of results per classifier head. */
|
||||||
|
public abstract ClassificationResult classificationResult();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -76,7 +76,7 @@ public class TextClassifierTest {
|
||||||
public void classify_succeedsWithBert() throws Exception {
|
public void classify_succeedsWithBert() throws Exception {
|
||||||
TextClassifier textClassifier =
|
TextClassifier textClassifier =
|
||||||
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||||
assertHasOneHead(negativeResults);
|
assertHasOneHead(negativeResults);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
negativeResults,
|
negativeResults,
|
||||||
|
@ -84,7 +84,7 @@ public class TextClassifierTest {
|
||||||
Category.create(0.95630914f, 0, "negative", ""),
|
Category.create(0.95630914f, 0, "negative", ""),
|
||||||
Category.create(0.04369091f, 1, "positive", "")));
|
Category.create(0.04369091f, 1, "positive", "")));
|
||||||
|
|
||||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||||
assertHasOneHead(positiveResults);
|
assertHasOneHead(positiveResults);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
positiveResults,
|
positiveResults,
|
||||||
|
@ -99,7 +99,7 @@ public class TextClassifierTest {
|
||||||
TextClassifier.createFromFile(
|
TextClassifier.createFromFile(
|
||||||
ApplicationProvider.getApplicationContext(),
|
ApplicationProvider.getApplicationContext(),
|
||||||
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
|
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
|
||||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||||
assertHasOneHead(negativeResults);
|
assertHasOneHead(negativeResults);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
negativeResults,
|
negativeResults,
|
||||||
|
@ -107,7 +107,7 @@ public class TextClassifierTest {
|
||||||
Category.create(0.95630914f, 0, "negative", ""),
|
Category.create(0.95630914f, 0, "negative", ""),
|
||||||
Category.create(0.04369091f, 1, "positive", "")));
|
Category.create(0.04369091f, 1, "positive", "")));
|
||||||
|
|
||||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||||
assertHasOneHead(positiveResults);
|
assertHasOneHead(positiveResults);
|
||||||
assertHasOneHead(positiveResults);
|
assertHasOneHead(positiveResults);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
|
@ -122,7 +122,7 @@ public class TextClassifierTest {
|
||||||
TextClassifier textClassifier =
|
TextClassifier textClassifier =
|
||||||
TextClassifier.createFromFile(
|
TextClassifier.createFromFile(
|
||||||
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
|
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
|
||||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||||
assertHasOneHead(negativeResults);
|
assertHasOneHead(negativeResults);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
negativeResults,
|
negativeResults,
|
||||||
|
@ -130,7 +130,7 @@ public class TextClassifierTest {
|
||||||
Category.create(0.6647746f, 0, "Negative", ""),
|
Category.create(0.6647746f, 0, "Negative", ""),
|
||||||
Category.create(0.33522537f, 1, "Positive", "")));
|
Category.create(0.33522537f, 1, "Positive", "")));
|
||||||
|
|
||||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||||
assertHasOneHead(positiveResults);
|
assertHasOneHead(positiveResults);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
positiveResults,
|
positiveResults,
|
||||||
|
@ -139,16 +139,15 @@ public class TextClassifierTest {
|
||||||
Category.create(0.48799595f, 1, "Positive", "")));
|
Category.create(0.48799595f, 1, "Positive", "")));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertHasOneHead(TextClassificationResult results) {
|
private static void assertHasOneHead(TextClassifierResult results) {
|
||||||
assertThat(results.classifications()).hasSize(1);
|
assertThat(results.classificationResult().classifications()).hasSize(1);
|
||||||
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
|
assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
|
||||||
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
|
assertThat(results.classificationResult().classifications().get(0).headName().get())
|
||||||
assertThat(results.classifications().get(0).entries()).hasSize(1);
|
.isEqualTo("probability");
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertCategoriesAre(
|
private static void assertCategoriesAre(TextClassifierResult results, List<Category> categories) {
|
||||||
TextClassificationResult results, List<Category> categories) {
|
assertThat(results.classificationResult().classifications().get(0).categories())
|
||||||
assertThat(results.classifications().get(0).entries().get(0).categories())
|
|
||||||
.isEqualTo(categories);
|
.isEqualTo(categories);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,11 +91,12 @@ public class ImageClassifierTest {
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromFile(
|
ImageClassifier.createFromFile(
|
||||||
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
|
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001);
|
assertThat(results.classificationResult().classifications().get(0).categories())
|
||||||
assertThat(results.classifications().get(0).entries().get(0).categories().get(0))
|
.hasSize(1001);
|
||||||
|
assertThat(results.classificationResult().classifications().get(0).categories().get(0))
|
||||||
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
|
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,9 +109,9 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results,
|
results,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -128,9 +129,9 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
|
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
|
||||||
}
|
}
|
||||||
|
@ -144,9 +145,9 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results,
|
results,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -166,9 +167,9 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results,
|
results,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -190,9 +191,9 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results,
|
results,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -214,10 +215,10 @@ public class ImageClassifierTest {
|
||||||
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
|
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
|
||||||
ImageProcessingOptions imageProcessingOptions =
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
|
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
|
||||||
ImageClassificationResult results =
|
ImageClassifierResult results =
|
||||||
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
|
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
|
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
|
||||||
}
|
}
|
||||||
|
@ -233,10 +234,10 @@ public class ImageClassifierTest {
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageProcessingOptions imageProcessingOptions =
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||||
ImageClassificationResult results =
|
ImageClassifierResult results =
|
||||||
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results,
|
results,
|
||||||
Arrays.asList(
|
Arrays.asList(
|
||||||
|
@ -258,11 +259,11 @@ public class ImageClassifierTest {
|
||||||
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
|
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
|
||||||
ImageProcessingOptions imageProcessingOptions =
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
|
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
|
||||||
ImageClassificationResult results =
|
ImageClassifierResult results =
|
||||||
imageClassifier.classify(
|
imageClassifier.classify(
|
||||||
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
|
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
|
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
|
||||||
}
|
}
|
||||||
|
@ -391,9 +392,9 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHead(results);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
}
|
}
|
||||||
|
@ -410,9 +411,8 @@ public class ImageClassifierTest {
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
ImageClassificationResult results =
|
ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
||||||
imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
assertHasOneHead(results);
|
||||||
assertHasOneHeadAndOneTimestamp(results, i);
|
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
}
|
}
|
||||||
|
@ -478,24 +478,17 @@ public class ImageClassifierTest {
|
||||||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertHasOneHeadAndOneTimestamp(
|
private static void assertHasOneHead(ImageClassifierResult results) {
|
||||||
ImageClassificationResult results, long timestampMs) {
|
assertThat(results.classificationResult().classifications()).hasSize(1);
|
||||||
assertThat(results.classifications()).hasSize(1);
|
assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
|
||||||
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
|
assertThat(results.classificationResult().classifications().get(0).headName().get())
|
||||||
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
|
.isEqualTo("probability");
|
||||||
assertThat(results.classifications().get(0).entries()).hasSize(1);
|
|
||||||
assertThat(results.classifications().get(0).entries().get(0).timestampMs())
|
|
||||||
.isEqualTo(timestampMs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertCategoriesAre(
|
private static void assertCategoriesAre(
|
||||||
ImageClassificationResult results, List<Category> categories) {
|
ImageClassifierResult results, List<Category> categories) {
|
||||||
assertThat(results.classifications().get(0).entries().get(0).categories())
|
assertThat(results.classificationResult().classifications().get(0).categories())
|
||||||
.hasSize(categories.size());
|
.isEqualTo(categories);
|
||||||
for (int i = 0; i < categories.size(); i++) {
|
|
||||||
assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
|
|
||||||
.isEqualTo(categories.get(i));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void assertImageSizeIsExpected(MPImage inputImage) {
|
private static void assertImageSizeIsExpected(MPImage inputImage) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user