Merge branch 'google:master' into hand-landmarker-python

This commit is contained in:
Kinar R 2022-11-09 00:52:20 +05:30 committed by GitHub
commit 6dd6d8921f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 441 additions and 420 deletions

View File

@ -20,6 +20,7 @@ android_library(
name = "category", name = "category",
srcs = ["Category.java"], srcs = ["Category.java"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
@ -36,20 +37,29 @@ android_library(
) )
android_library( android_library(
name = "classification_entry", name = "classifications",
srcs = ["ClassificationEntry.java"], srcs = ["Classifications.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
":category", ":category",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
android_library( android_library(
name = "classifications", name = "classificationresult",
srcs = ["Classifications.java"], srcs = ["ClassificationResult.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
":classification_entry", ":classifications",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//third_party:autovalue", "//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],

View File

@ -15,6 +15,7 @@
package com.google.mediapipe.tasks.components.containers; package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto;
import java.util.Objects; import java.util.Objects;
/** /**
@ -38,6 +39,16 @@ public abstract class Category {
return new AutoValue_Category(score, index, categoryName, displayName); return new AutoValue_Category(score, index, categoryName, displayName);
} }
/**
* Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf
* message.
*
* @param proto the {@link ClassificationProto.Classification} protobuf message to convert.
*/
public static Category createFromProto(ClassificationProto.Classification proto) {
return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
}
/** The probability score of this label category. */ /** The probability score of this label category. */
public abstract float score(); public abstract float score();

View File

@ -1,48 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents a list of predicted categories with an optional timestamp. Typically used as result
* for classification tasks.
*/
@AutoValue
public abstract class ClassificationEntry {
/**
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
* timestamp.
*
* @param categories the list of {@link Category} objects that contain category name, display
* name, score and label index.
* @param timestampMs the {@link long} representing the timestamp for which these categories were
* obtained.
*/
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
}
/** The list of predicted {@link Category} objects, sorted by descending score. */
public abstract List<Category> categories();
/**
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
* series use cases, e.g. audio classification.
*/
public abstract long timestampMs();
}

View File

@ -0,0 +1,76 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Represents the classification results of a model. Typically used as a result for classification
* tasks.
*/
@AutoValue
public abstract class ClassificationResult {
/**
* Creates a {@link ClassificationResult} instance.
*
* @param classifications the list of {@link Classifications} objects containing the predicted
* categories for each head of the model.
* @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*/
public static ClassificationResult create(
List<Classifications> classifications, Optional<Long> timestampMs) {
return new AutoValue_ClassificationResult(
Collections.unmodifiableList(classifications), timestampMs);
}
/**
* Creates a {@link ClassificationResult} object from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
*/
public static ClassificationResult createFromProto(
ClassificationsProto.ClassificationResult proto) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
proto.getClassificationsList()) {
classifications.add(Classifications.createFromProto(classificationsProto));
}
Optional<Long> timestampMs =
proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
return create(classifications, timestampMs);
}
/** The classification results for each head of the model. */
public abstract List<Classifications> classifications();
/**
* The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
* these results.
*
* <p>This is only used for classification on time series (e.g. audio classification). In these
* use cases, the amount of data to process might exceed the maximum size that the model can
* process: to solve this, the input data is split into multiple chunks starting at different
* timestamps.
*/
public abstract Optional<Long> timestampMs();
}

View File

@ -15,8 +15,12 @@
package com.google.mediapipe.tasks.components.containers; package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
/** /**
* Represents the list of classification for a given classifier head. Typically used as a result for * Represents the list of classification for a given classifier head. Typically used as a result for
@ -28,25 +32,41 @@ public abstract class Classifications {
/** /**
* Creates a {@link Classifications} instance. * Creates a {@link Classifications} instance.
* *
* @param entries the list of {@link ClassificationEntry} objects containing the predicted * @param categories the list of {@link Category} objects containing the predicted categories.
* categories.
* @param headIndex the index of the classifier head. * @param headIndex the index of the classifier head.
* @param headName the name of the classifier head. * @param headName the optional name of the classifier head.
*/ */
public static Classifications create( public static Classifications create(
List<ClassificationEntry> entries, int headIndex, String headName) { List<Category> categories, int headIndex, Optional<String> headName) {
return new AutoValue_Classifications( return new AutoValue_Classifications(
Collections.unmodifiableList(entries), headIndex, headName); Collections.unmodifiableList(categories), headIndex, headName);
} }
/** A list of {@link ClassificationEntry} objects. */ /**
public abstract List<ClassificationEntry> entries(); * Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications}
* protobuf message.
*
* @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
*/
public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
List<Category> categories = new ArrayList<>();
for (ClassificationProto.Classification classificationProto :
proto.getClassificationList().getClassificationList()) {
categories.add(Category.createFromProto(classificationProto));
}
Optional<String> headName =
proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
return create(categories, proto.getHeadIndex(), headName);
}
/** A list of {@link Category} objects. */
public abstract List<Category> categories();
/** /**
* The index of the classifier head these entries refer to. This is useful for multi-head models. * The index of the classifier head these entries refer to. This is useful for multi-head models.
*/ */
public abstract int headIndex(); public abstract int headIndex();
/** The name of the classifier head, which is the corresponding tensor metadata name. */ /** The optional name of the classifier head, which is the corresponding tensor metadata name. */
public abstract String headName(); public abstract Optional<String> headName();
} }

View File

@ -234,7 +234,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
"//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",

View File

@ -37,8 +37,8 @@ cc_library(
android_library( android_library(
name = "textclassifier", name = "textclassifier",
srcs = [ srcs = [
"textclassifier/TextClassificationResult.java",
"textclassifier/TextClassifier.java", "textclassifier/TextClassifier.java",
"textclassifier/TextClassifierResult.java",
], ],
javacopts = [ javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",
@ -51,9 +51,7 @@ android_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",

View File

@ -1,103 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the classification results generated by {@link TextClassifier}. */
@AutoValue
public abstract class TextClassificationResult implements TaskResult {
/**
* Creates an {@link TextClassificationResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
* message.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static TextClassificationResult create(
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
classificationResult.getClassificationsList()) {
classifications.add(classificationsFromProto(classificationsProto));
}
return new AutoValue_TextClassificationResult(
timestampMs, Collections.unmodifiableList(classifications));
}
@Override
public abstract long timestampMs();
/** Contains one set of results per classifier head. */
@SuppressWarnings("AutoValueImmutableFields")
public abstract List<Classifications> classifications();
/**
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
*
* @param category the {@link CategoryProto.Category} protobuf message to convert.
*/
static Category categoryFromProto(CategoryProto.Category category) {
return Category.create(
category.getScore(),
category.getIndex(),
category.getCategoryName(),
category.getDisplayName());
}
/**
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
* ClassificationEntry} object.
*
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
*/
static ClassificationEntry classificationEntryFromProto(
ClassificationsProto.ClassificationEntry entry) {
List<Category> categories = new ArrayList<>();
for (CategoryProto.Category category : entry.getCategoriesList()) {
categories.add(categoryFromProto(category));
}
return ClassificationEntry.create(categories, entry.getTimestampMs());
}
/**
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
* Classifications} object.
*
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
* convert.
*/
static Classifications classificationsFromProto(
ClassificationsProto.Classifications classifications) {
List<ClassificationEntry> entries = new ArrayList<>();
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
entries.add(classificationEntryFromProto(entry));
}
return Classifications.create(
entries, classifications.getHeadIndex(), classifications.getHeadName());
}
}

View File

@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable {
@SuppressWarnings("ConstantCaseForConstants") @SuppressWarnings("ConstantCaseForConstants")
private static final List<String> OUTPUT_STREAMS = private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out"));
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
private final TaskRunner runner; private final TaskRunner runner;
@ -142,17 +142,18 @@ public final class TextClassifier implements AutoCloseable {
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation. * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
*/ */
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) { public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>(); OutputHandler<TextClassifierResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() { new OutputHandler.OutputPacketConverter<TextClassifierResult, Void>() {
@Override @Override
public TextClassificationResult convertToTaskResult(List<Packet> packets) { public TextClassifierResult convertToTaskResult(List<Packet> packets) {
try { try {
return TextClassificationResult.create( return TextClassifierResult.create(
ClassificationResult.createFromProto(
PacketGetter.getProto( PacketGetter.getProto(
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.getDefaultInstance()), ClassificationsProto.ClassificationResult.getDefaultInstance())),
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
} catch (IOException e) { } catch (IOException e) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable {
* *
* @param inputText a {@link String} for processing. * @param inputText a {@link String} for processing.
*/ */
public TextClassificationResult classify(String inputText) { public TextClassifierResult classify(String inputText) {
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
return (TextClassificationResult) runner.process(inputPackets); return (TextClassifierResult) runner.process(inputPackets);
} }
/** Closes and cleans up the {@link TextClassifier}. */ /** Closes and cleans up the {@link TextClassifier}. */

View File

@ -0,0 +1,55 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
/** Represents the classification results generated by {@link TextClassifier}. */
@AutoValue
public abstract class TextClassifierResult implements TaskResult {
/**
* Creates an {@link TextClassifierResult} instance.
*
* @param classificationResult the {@link ClassificationResult} object containing one set of
* results per classifier head.
* @param timestampMs a timestamp for this result.
*/
static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
return new AutoValue_TextClassifierResult(classificationResult, timestampMs);
}
/**
* Creates an {@link TextClassifierResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static TextClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return create(ClassificationResult.createFromProto(proto), timestampMs);
}
/** Contains one set of results per classifier head. */
public abstract ClassificationResult classificationResult();
@Override
public abstract long timestampMs();
}

View File

@ -84,8 +84,8 @@ android_library(
android_library( android_library(
name = "imageclassifier", name = "imageclassifier",
srcs = [ srcs = [
"imageclassifier/ImageClassificationResult.java",
"imageclassifier/ImageClassifier.java", "imageclassifier/ImageClassifier.java",
"imageclassifier/ImageClassifierResult.java",
], ],
javacopts = [ javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF", "-Xep:AndroidJdkLibsChecker:OFF",
@ -100,9 +100,7 @@ android_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue", "//third_party:autovalue",

View File

@ -1,102 +0,0 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the classification results generated by {@link ImageClassifier}. */
@AutoValue
public abstract class ImageClassificationResult implements TaskResult {
/**
* Creates an {@link ImageClassificationResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
* message.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static ImageClassificationResult create(
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
classificationResult.getClassificationsList()) {
classifications.add(classificationsFromProto(classificationsProto));
}
return new AutoValue_ImageClassificationResult(
timestampMs, Collections.unmodifiableList(classifications));
}
@Override
public abstract long timestampMs();
/** Contains one set of results per classifier head. */
public abstract List<Classifications> classifications();
/**
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
*
* @param category the {@link CategoryProto.Category} protobuf message to convert.
*/
static Category categoryFromProto(CategoryProto.Category category) {
return Category.create(
category.getScore(),
category.getIndex(),
category.getCategoryName(),
category.getDisplayName());
}
/**
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
* ClassificationEntry} object.
*
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
*/
static ClassificationEntry classificationEntryFromProto(
ClassificationsProto.ClassificationEntry entry) {
List<Category> categories = new ArrayList<>();
for (CategoryProto.Category category : entry.getCategoriesList()) {
categories.add(categoryFromProto(category));
}
return ClassificationEntry.create(categories, entry.getTimestampMs());
}
/**
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
* Classifications} object.
*
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
* convert.
*/
static Classifications classificationsFromProto(
ClassificationsProto.Classifications classifications) {
List<ClassificationEntry> entries = new ArrayList<>();
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
entries.add(classificationEntryFromProto(entry));
}
return Classifications.create(
entries, classifications.getHeadIndex(), classifications.getHeadName());
}
}

View File

@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS = private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList( Collections.unmodifiableList(
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out")); Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out"));
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1; private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
*/ */
public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
OutputHandler<ImageClassificationResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageClassifierResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ImageClassificationResult, MPImage>() { new OutputHandler.OutputPacketConverter<ImageClassifierResult, MPImage>() {
@Override @Override
public ImageClassificationResult convertToTaskResult(List<Packet> packets) { public ImageClassifierResult convertToTaskResult(List<Packet> packets) {
try { try {
return ImageClassificationResult.create( return ImageClassifierResult.create(
ClassificationResult.createFromProto(
PacketGetter.getProto( PacketGetter.getProto(
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.getDefaultInstance()), ClassificationsProto.ClassificationResult.getDefaultInstance())),
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
} catch (IOException e) { } catch (IOException e) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classify(MPImage image) { public ImageClassifierResult classify(MPImage image) {
return classify(image, ImageProcessingOptions.builder().build()); return classify(image, ImageProcessingOptions.builder().build());
} }
@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* input image before running inference. * input image before running inference.
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classify( public ImageClassifierResult classify(
MPImage image, ImageProcessingOptions imageProcessingOptions) { MPImage image, ImageProcessingOptions imageProcessingOptions) {
return (ImageClassificationResult) processImageData(image, imageProcessingOptions); return (ImageClassifierResult) processImageData(image, imageProcessingOptions);
} }
/** /**
@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) {
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
} }
@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ImageClassificationResult classifyForVideo( public ImageClassifierResult classifyForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs);
} }
/** /**
@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* the image classifier is in the live stream mode. * the image classifier is in the live stream mode.
*/ */
public abstract Builder setResultListener( public abstract Builder setResultListener(
ResultListener<ImageClassificationResult, MPImage> resultListener); ResultListener<ImageClassifierResult, MPImage> resultListener);
/** Sets an optional {@link ErrorListener}. */ /** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener errorListener); public abstract Builder setErrorListener(ErrorListener errorListener);
@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
abstract Optional<ClassifierOptions> classifierOptions(); abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<ResultListener<ImageClassificationResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();

View File

@ -0,0 +1,55 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imageclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
/** Represents the classification results generated by {@link ImageClassifier}. */
@AutoValue
public abstract class ImageClassifierResult implements TaskResult {
/**
* Creates an {@link ImageClassifierResult} instance.
*
* @param classificationResult the {@link ClassificationResult} object containing one set of
* results per classifier head.
* @param timestampMs a timestamp for this result.
*/
static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
return new AutoValue_ImageClassifierResult(classificationResult, timestampMs);
}
/**
* Creates an {@link ImageClassifierResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static ImageClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return create(ClassificationResult.createFromProto(proto), timestampMs);
}
/** Contains one set of results per classifier head. */
public abstract ClassificationResult classificationResult();
@Override
public abstract long timestampMs();
}

View File

@ -76,7 +76,7 @@ public class TextClassifierTest {
public void classify_succeedsWithBert() throws Exception { public void classify_succeedsWithBert() throws Exception {
TextClassifier textClassifier = TextClassifier textClassifier =
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults); assertHasOneHead(negativeResults);
assertCategoriesAre( assertCategoriesAre(
negativeResults, negativeResults,
@ -84,7 +84,7 @@ public class TextClassifierTest {
Category.create(0.95630914f, 0, "negative", ""), Category.create(0.95630914f, 0, "negative", ""),
Category.create(0.04369091f, 1, "positive", ""))); Category.create(0.04369091f, 1, "positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertCategoriesAre( assertCategoriesAre(
positiveResults, positiveResults,
@ -99,7 +99,7 @@ public class TextClassifierTest {
TextClassifier.createFromFile( TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE)); TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults); assertHasOneHead(negativeResults);
assertCategoriesAre( assertCategoriesAre(
negativeResults, negativeResults,
@ -107,7 +107,7 @@ public class TextClassifierTest {
Category.create(0.95630914f, 0, "negative", ""), Category.create(0.95630914f, 0, "negative", ""),
Category.create(0.04369091f, 1, "positive", ""))); Category.create(0.04369091f, 1, "positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertCategoriesAre( assertCategoriesAre(
@ -122,7 +122,7 @@ public class TextClassifierTest {
TextClassifier textClassifier = TextClassifier textClassifier =
TextClassifier.createFromFile( TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults); assertHasOneHead(negativeResults);
assertCategoriesAre( assertCategoriesAre(
negativeResults, negativeResults,
@ -130,7 +130,7 @@ public class TextClassifierTest {
Category.create(0.6647746f, 0, "Negative", ""), Category.create(0.6647746f, 0, "Negative", ""),
Category.create(0.33522537f, 1, "Positive", ""))); Category.create(0.33522537f, 1, "Positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults); assertHasOneHead(positiveResults);
assertCategoriesAre( assertCategoriesAre(
positiveResults, positiveResults,
@ -139,16 +139,15 @@ public class TextClassifierTest {
Category.create(0.48799595f, 1, "Positive", ""))); Category.create(0.48799595f, 1, "Positive", "")));
} }
private static void assertHasOneHead(TextClassificationResult results) { private static void assertHasOneHead(TextClassifierResult results) {
assertThat(results.classifications()).hasSize(1); assertThat(results.classificationResult().classifications()).hasSize(1);
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); assertThat(results.classificationResult().classifications().get(0).headName().get())
assertThat(results.classifications().get(0).entries()).hasSize(1); .isEqualTo("probability");
} }
private static void assertCategoriesAre( private static void assertCategoriesAre(TextClassifierResult results, List<Category> categories) {
TextClassificationResult results, List<Category> categories) { assertThat(results.classificationResult().classifications().get(0).categories())
assertThat(results.classifications().get(0).entries().get(0).categories())
.isEqualTo(categories); .isEqualTo(categories);
} }
} }

View File

@ -91,11 +91,12 @@ public class ImageClassifierTest {
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromFile( ImageClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE); ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001); assertThat(results.classificationResult().classifications().get(0).categories())
assertThat(results.classifications().get(0).entries().get(0).categories().get(0)) .hasSize(1001);
assertThat(results.classificationResult().classifications().get(0).categories().get(0))
.isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", "")); .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
} }
@ -108,9 +109,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -128,9 +129,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
} }
@ -144,9 +145,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -166,9 +167,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -190,9 +191,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -214,10 +215,10 @@ public class ImageClassifierTest {
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
ImageClassificationResult results = ImageClassifierResult results =
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
} }
@ -233,10 +234,10 @@ public class ImageClassifierTest {
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ImageClassificationResult results = ImageClassifierResult results =
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, results,
Arrays.asList( Arrays.asList(
@ -258,11 +259,11 @@ public class ImageClassifierTest {
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
ImageProcessingOptions imageProcessingOptions = ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
ImageClassificationResult results = ImageClassifierResult results =
imageClassifier.classify( imageClassifier.classify(
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
} }
@ -391,9 +392,9 @@ public class ImageClassifierTest {
.build(); .build();
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE)); ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
assertHasOneHeadAndOneTimestamp(results, 0); assertHasOneHead(results);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
} }
@ -410,9 +411,8 @@ public class ImageClassifierTest {
ImageClassifier imageClassifier = ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
ImageClassificationResult results = ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); assertHasOneHead(results);
assertHasOneHeadAndOneTimestamp(results, i);
assertCategoriesAre( assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
} }
@ -478,24 +478,17 @@ public class ImageClassifierTest {
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
} }
private static void assertHasOneHeadAndOneTimestamp( private static void assertHasOneHead(ImageClassifierResult results) {
ImageClassificationResult results, long timestampMs) { assertThat(results.classificationResult().classifications()).hasSize(1);
assertThat(results.classifications()).hasSize(1); assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); assertThat(results.classificationResult().classifications().get(0).headName().get())
assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); .isEqualTo("probability");
assertThat(results.classifications().get(0).entries()).hasSize(1);
assertThat(results.classifications().get(0).entries().get(0).timestampMs())
.isEqualTo(timestampMs);
} }
private static void assertCategoriesAre( private static void assertCategoriesAre(
ImageClassificationResult results, List<Category> categories) { ImageClassifierResult results, List<Category> categories) {
assertThat(results.classifications().get(0).entries().get(0).categories()) assertThat(results.classificationResult().classifications().get(0).categories())
.hasSize(categories.size()); .isEqualTo(categories);
for (int i = 0; i < categories.size(); i++) {
assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
.isEqualTo(categories.get(i));
}
} }
private static void assertImageSizeIsExpected(MPImage inputImage) { private static void assertImageSizeIsExpected(MPImage inputImage) {

View File

@ -174,12 +174,10 @@ class AudioClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _AudioClassifier) self.assertIsInstance(classifier, _AudioClassifier)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _AudioClassifierOptions(base_options=base_options) options = _AudioClassifierOptions(base_options=base_options)
_AudioClassifier.create_from_options(options) _AudioClassifier.create_from_options(options)

View File

@ -154,12 +154,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _TextClassifier) self.assertIsInstance(classifier, _TextClassifier)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _TextClassifierOptions(base_options=base_options) options = _TextClassifierOptions(base_options=base_options)
_TextClassifier.create_from_options(options) _TextClassifier.create_from_options(options)

View File

@ -147,12 +147,10 @@ class ImageClassifierTest(parameterized.TestCase):
self.assertIsInstance(classifier, _ImageClassifier) self.assertIsInstance(classifier, _ImageClassifier)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _ImageClassifierOptions(base_options=base_options) options = _ImageClassifierOptions(base_options=base_options)
_ImageClassifier.create_from_options(options) _ImageClassifier.create_from_options(options)

View File

@ -97,12 +97,10 @@ class ImageSegmenterTest(parameterized.TestCase):
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options) _ImageSegmenter.create_from_options(options)

View File

@ -119,12 +119,10 @@ class ObjectDetectorTest(parameterized.TestCase):
self.assertIsInstance(detector, _ObjectDetector) self.assertIsInstance(detector, _ObjectDetector)
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
r"ExternalFile must specify at least one of 'file_content', " base_options = _BaseOptions(
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): model_asset_path='/path/to/invalid/model.tflite')
base_options = _BaseOptions(model_asset_path='')
options = _ObjectDetectorOptions(base_options=base_options) options = _ObjectDetectorOptions(base_options=base_options)
_ObjectDetector.create_from_options(options) _ObjectDetector.create_from_options(options)

View File

@ -28,6 +28,7 @@ mediapipe_files(srcs = [
"bert_text_classifier.tflite", "bert_text_classifier.tflite",
"mobilebert_embedding_with_metadata.tflite", "mobilebert_embedding_with_metadata.tflite",
"mobilebert_with_metadata.tflite", "mobilebert_with_metadata.tflite",
"regex_one_embedding_with_metadata.tflite",
"test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite",
"universal_sentence_encoder_qa_with_metadata.tflite", "universal_sentence_encoder_qa_with_metadata.tflite",
@ -92,6 +93,11 @@ filegroup(
srcs = ["mobilebert_embedding_with_metadata.tflite"], srcs = ["mobilebert_embedding_with_metadata.tflite"],
) )
filegroup(
name = "regex_embedding_with_metadata",
srcs = ["regex_one_embedding_with_metadata.tflite"],
)
filegroup( filegroup(
name = "universal_sentence_encoder_qa", name = "universal_sentence_encoder_qa",
data = ["universal_sentence_encoder_qa_with_metadata.tflite"], data = ["universal_sentence_encoder_qa_with_metadata.tflite"],

View File

@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
*/ */
async setOptions(options: AudioClassifierOptions): Promise<void> { async setOptions(options: AudioClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -26,6 +26,8 @@ mediapipe_ts_library(
name = "base_options", name = "base_options",
srcs = ["base_options.ts"], srcs = ["base_options.ts"],
deps = [ deps = [
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",

View File

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {BaseOptions} from '../../../../tasks/web/core/base_options';
@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
* Converts a BaseOptions API object to its Protobuf representation. * Converts a BaseOptions API object to its Protobuf representation.
* @throws If neither a model assset path or buffer is provided * @throws If neither a model assset path or buffer is provided
*/ */
export async function convertBaseOptionsToProto(baseOptions: BaseOptions): export async function convertBaseOptionsToProto(
Promise<BaseOptionsProto> { updatedOptions: BaseOptions,
if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) { currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
const result =
currentOptions ? currentOptions.clone() : new BaseOptionsProto();
await configureExternalFile(updatedOptions, result);
configureAcceleration(updatedOptions, result);
return result;
}
/**
* Configues the `externalFile` option and validates that a single model is
* provided.
*/
async function configureExternalFile(
options: BaseOptions, proto: BaseOptionsProto) {
const externalFile = proto.getModelAsset() || new ExternalFile();
proto.setModelAsset(externalFile);
if (options.modelAssetPath || options.modelAssetBuffer) {
if (options.modelAssetPath && options.modelAssetBuffer) {
throw new Error( throw new Error(
'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
} }
if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
let modelAssetBuffer = options.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(options.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
}
externalFile.setFileContent(modelAssetBuffer);
}
if (!externalFile.hasFileContent()) {
throw new Error( throw new Error(
'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
} }
let modelAssetBuffer = baseOptions.modelAssetBuffer;
if (!modelAssetBuffer) {
const response = await fetch(baseOptions.modelAssetPath!.toString());
modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
} }
const proto = new BaseOptionsProto(); /** Configues the `acceleration` option. */
const externalFile = new ExternalFile(); function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
externalFile.setFileContent(modelAssetBuffer); if ('delegate' in options) {
proto.setModelAsset(externalFile); const acceleration = new Acceleration();
return proto; if (options.delegate === 'cpu') {
acceleration.setXnnpack(
new InferenceCalculatorOptions.Delegate.Xnnpack());
proto.setAcceleration(acceleration);
} else if (options.delegate === 'gpu') {
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
proto.setAcceleration(acceleration);
} else {
proto.clearAcceleration();
}
}
} }

View File

@ -22,10 +22,14 @@ export interface BaseOptions {
* The model path to the model asset file. Only one of `modelAssetPath` or * The model path to the model asset file. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set. * `modelAssetBuffer` can be set.
*/ */
modelAssetPath?: string; modelAssetPath?: string|undefined;
/** /**
* A buffer containing the model aaset. Only one of `modelAssetPath` or * A buffer containing the model aaset. Only one of `modelAssetPath` or
* `modelAssetBuffer` can be set. * `modelAssetBuffer` can be set.
*/ */
modelAssetBuffer?: Uint8Array; modelAssetBuffer?: Uint8Array|undefined;
/** Overrides the default backend to use for the provided model. */
delegate?: 'cpu'|'gpu'|undefined;
} }

View File

@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner {
*/ */
async setOptions(options: TextClassifierOptions): Promise<void> { async setOptions(options: TextClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner {
*/ */
async setOptions(options: GestureRecognizerOptions): Promise<void> { async setOptions(options: GestureRecognizerOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner {
*/ */
async setOptions(options: ImageClassifierOptions): Promise<void> { async setOptions(options: ImageClassifierOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner {
*/ */
async setOptions(options: ObjectDetectorOptions): Promise<void> { async setOptions(options: ObjectDetectorOptions): Promise<void> {
if (options.baseOptions) { if (options.baseOptions) {
const baseOptionsProto = const baseOptionsProto = await convertBaseOptionsToProto(
await convertBaseOptionsToProto(options.baseOptions); options.baseOptions, this.options.getBaseOptions());
this.options.setBaseOptions(baseOptionsProto); this.options.setBaseOptions(baseOptionsProto);
} }

View File

@ -402,8 +402,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_labels_txt", name = "com_google_mediapipe_labels_txt",
sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a", sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667855388142641"], urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667888034706429"],
) )
http_file( http_file(
@ -552,8 +552,14 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_movie_review_json", name = "com_google_mediapipe_movie_review_json",
sha256 = "89ad347ad1cb7c587da144de6efbadec1d3e8ff0cd13e379dd16661a8186fbb5", sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e",
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667855392734031"], urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667888039053188"],
)
http_file(
name = "com_google_mediapipe_movie_review_labels_txt",
sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667888041670721"],
) )
http_file( http_file(
@ -688,6 +694,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"], urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"],
) )
http_file(
name = "com_google_mediapipe_regex_one_embedding_with_metadata_tflite",
sha256 = "b8f5d6d090c2c73984b2b92cd2fda27e5630562741a93d127b9a744d60505bc0",
urls = ["https://storage.googleapis.com/mediapipe-assets/regex_one_embedding_with_metadata.tflite?generation=1667888045310541"],
)
http_file(
name = "com_google_mediapipe_regex_vocab_txt",
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667888047885461"],
)
http_file( http_file(
name = "com_google_mediapipe_right_hands_jpg", name = "com_google_mediapipe_right_hands_jpg",
sha256 = "240c082e80128ff1ca8a83ce645e2ba4d8bc30f0967b7991cf5fa375bab489e1", sha256 = "240c082e80128ff1ca8a83ce645e2ba4d8bc30f0967b7991cf5fa375bab489e1",