Adds a Java API for TextClassifier.
PiperOrigin-RevId: 482394706
This commit is contained in:
parent
3d588bae8b
commit
41d6f6d005
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.text.textclassifier.proto";
|
||||
option java_outer_classname = "TextClassifierGraphOptionsProto";
|
||||
|
||||
message TextClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional TextClassifierGraphOptions ext = 462704549;
|
||||
|
|
|
@ -117,7 +117,7 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
|||
if (errorListener != null) {
|
||||
errorListener.onError(e);
|
||||
} else {
|
||||
Log.e(TAG, "Error occurs when getting MediaPipe vision task result. " + e);
|
||||
Log.e(TAG, "Error occurs when getting MediaPipe task result. " + e);
|
||||
}
|
||||
} finally {
|
||||
for (Packet packet : packets) {
|
||||
|
|
63
mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
Normal file
63
mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
# The native library of all MediaPipe text tasks.
|
||||
cc_binary(
|
||||
name = "libmediapipe_tasks_text_jni.so",
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "libmediapipe_tasks_text_jni_lib",
|
||||
srcs = [":libmediapipe_tasks_text_jni.so"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "textclassifier",
|
||||
srcs = [
|
||||
"textclassifier/TextClassificationResult.java",
|
||||
"textclassifier/TextClassifier.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
manifest = "textclassifier/AndroidManifest.xml",
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.text.textclassifier">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.text.textclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
|
||||
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the classification results generated by {@link TextClassifier}. */
|
||||
@AutoValue
|
||||
public abstract class TextClassificationResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link TextClassificationResult} instance from a {@link
|
||||
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||
*
|
||||
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
|
||||
* message.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static TextClassificationResult create(
|
||||
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
|
||||
List<Classifications> classifications = new ArrayList<>();
|
||||
for (ClassificationsProto.Classifications classificationsProto :
|
||||
classificationResult.getClassificationsList()) {
|
||||
classifications.add(classificationsFromProto(classificationsProto));
|
||||
}
|
||||
return new AutoValue_TextClassificationResult(
|
||||
timestampMs, Collections.unmodifiableList(classifications));
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
|
||||
/** Contains one set of results per classifier head. */
|
||||
@SuppressWarnings("AutoValueImmutableFields")
|
||||
public abstract List<Classifications> classifications();
|
||||
|
||||
/**
|
||||
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
|
||||
*
|
||||
* @param category the {@link CategoryProto.Category} protobuf message to convert.
|
||||
*/
|
||||
static Category categoryFromProto(CategoryProto.Category category) {
|
||||
return Category.create(
|
||||
category.getScore(),
|
||||
category.getIndex(),
|
||||
category.getCategoryName(),
|
||||
category.getDisplayName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
|
||||
* ClassificationEntry} object.
|
||||
*
|
||||
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
|
||||
*/
|
||||
static ClassificationEntry classificationEntryFromProto(
|
||||
ClassificationsProto.ClassificationEntry entry) {
|
||||
List<Category> categories = new ArrayList<>();
|
||||
for (CategoryProto.Category category : entry.getCategoriesList()) {
|
||||
categories.add(categoryFromProto(category));
|
||||
}
|
||||
return ClassificationEntry.create(categories, entry.getTimestampMs());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
|
||||
* Classifications} object.
|
||||
*
|
||||
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
|
||||
* convert.
|
||||
*/
|
||||
static Classifications classificationsFromProto(
|
||||
ClassificationsProto.Classifications classifications) {
|
||||
List<ClassificationEntry> entries = new ArrayList<>();
|
||||
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
|
||||
entries.add(classificationEntryFromProto(entry));
|
||||
}
|
||||
return Classifications.create(
|
||||
entries, classifications.getHeadIndex(), classifications.getHeadName());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,254 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.text.textclassifier;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.text.textclassifier.proto.TextClassifierGraphOptionsProto;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs classification on text.
|
||||
*
|
||||
* <p>This API expects a TFLite model with (optional) <a
|
||||
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a> that contains
|
||||
* the mandatory (described below) input tensors, output tensor, and the optional (but recommended)
|
||||
* label items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
|
||||
*
|
||||
* <p>Metadata is required for models with int32 input tensors because it contains the input process
|
||||
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input tensors
|
||||
* <ul>
|
||||
* <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
|
||||
* bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
|
||||
* signature requires a Bert Tokenizer process unit in the model metadata.
|
||||
* <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
|
||||
* max_seq_len]} representing the input ids. This input signature requires a Regex
|
||||
* Tokenizer process unit in the model metadata.
|
||||
* <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
|
||||
* [1]} containing the input string.
|
||||
* </ul>
|
||||
* <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kBool}) with:
|
||||
* <ul>
|
||||
* <li>{@code N} classes and shape {@code [1 x N]}
|
||||
* <li>optional (but recommended) label map(s) as AssociatedFile-s with type
|
||||
* TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
|
||||
* any) is used to fill the {@code class_name} field of the results. The {@code
|
||||
* display_name} field is filled from the AssociatedFile (if any) whose locale matches
|
||||
* the {@code display_names_locale} field of the {@code TextClassifierOptions} used at
|
||||
* creation time ("en" by default, i.e. English). If none of these are available, only
|
||||
* the {@code index} field of the results will be filled.
|
||||
* </ul>
|
||||
* </ul>
|
||||
*/
|
||||
public final class TextClassifier implements AutoCloseable {
|
||||
private static final String TAG = TextClassifier.class.getSimpleName();
|
||||
private static final String TEXT_IN_STREAM_NAME = "text_in";
|
||||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
|
||||
|
||||
@SuppressWarnings("ConstantCaseForConstants")
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
|
||||
|
||||
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||
private final TaskRunner runner;
|
||||
|
||||
static {
|
||||
System.loadLibrary("mediapipe_tasks_text_jni");
|
||||
ProtoUtil.registerTypeName(
|
||||
ClassificationsProto.ClassificationResult.class,
|
||||
"mediapipe.tasks.components.containers.proto.ClassificationResult");
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link TextClassifier} instance from a model file and the default {@link
|
||||
* TextClassifierOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelPath path to the text model with metadata in the assets.
|
||||
* @throws MediaPipeException if there is is an error during {@link TextClassifier} creation.
|
||||
*/
|
||||
public static TextClassifier createFromFile(Context context, String modelPath) {
|
||||
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||
return createFromOptions(
|
||||
context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link TextClassifier} instance from a model file and the default {@link
|
||||
* TextClassifierOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param modelFile the text model {@link File} instance.
|
||||
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
|
||||
*/
|
||||
public static TextClassifier createFromFile(Context context, File modelFile) throws IOException {
|
||||
try (ParcelFileDescriptor descriptor =
|
||||
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||
BaseOptions baseOptions =
|
||||
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||
return createFromOptions(
|
||||
context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link TextClassifier} instance from {@link TextClassifierOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param options a {@link TextClassifierOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
|
||||
*/
|
||||
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
|
||||
OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() {
|
||||
@Override
|
||||
public TextClassificationResult convertToTaskResult(List<Packet> packets) {
|
||||
try {
|
||||
return TextClassificationResult.create(
|
||||
PacketGetter.getProto(
|
||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
|
||||
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void convertToTaskInput(List<Packet> packets) {
|
||||
return null;
|
||||
}
|
||||
});
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<TextClassifierOptions>builder()
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setTaskOptions(options)
|
||||
.setEnableFlowLimiting(false)
|
||||
.build(),
|
||||
handler);
|
||||
return new TextClassifier(runner);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize a {@link TextClassifier} from a {@link TaskRunner}.
|
||||
*
|
||||
* @param runner a {@link TaskRunner}.
|
||||
*/
|
||||
private TextClassifier(TaskRunner runner) {
|
||||
this.runner = runner;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs classification on the input text.
|
||||
*
|
||||
* @param inputText a {@link String} for processing.
|
||||
*/
|
||||
public TextClassificationResult classify(String inputText) {
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
|
||||
return (TextClassificationResult) runner.process(inputPackets);
|
||||
}
|
||||
|
||||
/** Closes and cleans up the {@link TextClassifier}. */
|
||||
@Override
|
||||
public void close() {
|
||||
runner.close();
|
||||
}
|
||||
|
||||
/** Options for setting up a {@link TextClassifier}. */
|
||||
@AutoValue
|
||||
public abstract static class TextClassifierOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link TextClassifierOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the base options for the text classifier task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||
* score threshold, number of results, etc.
|
||||
*/
|
||||
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||
|
||||
public abstract TextClassifierOptions build();
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract Optional<ClassifierOptions> classifierOptions();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
|
||||
}
|
||||
|
||||
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||
BaseOptionsProto.BaseOptions.newBuilder();
|
||||
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
|
||||
.setBaseOptions(baseOptionsBuilder);
|
||||
if (classifierOptions().isPresent()) {
|
||||
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||
}
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
|
||||
taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.text.textclassifiertest"
|
||||
android:versionCode="1"
|
||||
android:versionName="1.0" >
|
||||
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
<application
|
||||
android:label="textclassifiertest"
|
||||
android:name="android.support.multidex.MultiDexApplication"
|
||||
android:taskAffinity="">
|
||||
<uses-library android:name="android.test.runner" />
|
||||
</application>
|
||||
|
||||
<instrumentation
|
||||
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||
android:targetPackage="com.google.mediapipe.tasks.text.textclassifiertest" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
# TODO: Enable this in OSS
|
|
@ -0,0 +1,154 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.text.textclassifier;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.tasks.components.containers.Category;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.TestUtils;
|
||||
import com.google.mediapipe.tasks.text.textclassifier.TextClassifier.TextClassifierOptions;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
/** Test for {@link TextClassifier}/ */
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class TextClassifierTest {
|
||||
private static final String BERT_MODEL_FILE = "bert_text_classifier.tflite";
|
||||
private static final String REGEX_MODEL_FILE =
|
||||
"test_model_text_classifier_with_regex_tokenizer.tflite";
|
||||
private static final String STRING_TO_BOOL_MODEL_FILE =
|
||||
"test_model_text_classifier_bool_output.tflite";
|
||||
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
|
||||
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingModel() throws Exception {
|
||||
String nonExistentFile = "/path/to/non/existent/file";
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
TextClassifier.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), nonExistentFile));
|
||||
assertThat(exception).hasMessageThat().contains(nonExistentFile);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void create_failsWithMissingOpResolver() throws Exception {
|
||||
TextClassifierOptions options =
|
||||
TextClassifierOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(STRING_TO_BOOL_MODEL_FILE).build())
|
||||
.build();
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
TextClassifier.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options));
|
||||
// TODO: Make MediaPipe InferenceCalculator report the detailed.
|
||||
// interpreter errors (e.g., "Encountered unresolved custom op").
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("interpreter_builder(&interpreter) == kTfLiteOk");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_succeedsWithBert() throws Exception {
|
||||
TextClassifier textClassifier =
|
||||
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
assertHasOneHead(negativeResults);
|
||||
assertCategoriesAre(
|
||||
negativeResults,
|
||||
Arrays.asList(
|
||||
Category.create(0.95630914f, 0, "negative", ""),
|
||||
Category.create(0.04369091f, 1, "positive", "")));
|
||||
|
||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertCategoriesAre(
|
||||
positiveResults,
|
||||
Arrays.asList(
|
||||
Category.create(0.99997187f, 1, "positive", ""),
|
||||
Category.create(2.8132641E-5f, 0, "negative", "")));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_succeedsWithFileObject() throws Exception {
|
||||
TextClassifier textClassifier =
|
||||
TextClassifier.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(),
|
||||
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
|
||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
assertHasOneHead(negativeResults);
|
||||
assertCategoriesAre(
|
||||
negativeResults,
|
||||
Arrays.asList(
|
||||
Category.create(0.95630914f, 0, "negative", ""),
|
||||
Category.create(0.04369091f, 1, "positive", "")));
|
||||
|
||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertCategoriesAre(
|
||||
positiveResults,
|
||||
Arrays.asList(
|
||||
Category.create(0.99997187f, 1, "positive", ""),
|
||||
Category.create(2.8132641E-5f, 0, "negative", "")));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_succeedsWithRegex() throws Exception {
|
||||
TextClassifier textClassifier =
|
||||
TextClassifier.createFromFile(
|
||||
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
|
||||
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
|
||||
assertHasOneHead(negativeResults);
|
||||
assertCategoriesAre(
|
||||
negativeResults,
|
||||
Arrays.asList(
|
||||
Category.create(0.6647746f, 0, "Negative", ""),
|
||||
Category.create(0.33522537f, 1, "Positive", "")));
|
||||
|
||||
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
|
||||
assertHasOneHead(positiveResults);
|
||||
assertCategoriesAre(
|
||||
positiveResults,
|
||||
Arrays.asList(
|
||||
Category.create(0.5120041f, 0, "Negative", ""),
|
||||
Category.create(0.48799595f, 1, "Positive", "")));
|
||||
}
|
||||
|
||||
private static void assertHasOneHead(TextClassificationResult results) {
|
||||
assertThat(results.classifications()).hasSize(1);
|
||||
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
|
||||
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
|
||||
assertThat(results.classifications().get(0).entries()).hasSize(1);
|
||||
}
|
||||
|
||||
private static void assertCategoriesAre(
|
||||
TextClassificationResult results, List<Category> categories) {
|
||||
assertThat(results.classifications().get(0).entries().get(0).categories())
|
||||
.isEqualTo(categories);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user