Adds a LanguageDetector Java API.
PiperOrigin-RevId: 522895455
This commit is contained in:
parent
4f77504af6
commit
c036b9f408
|
@ -40,7 +40,7 @@ public class MediaPipeException extends RuntimeException {
|
||||||
return statusMessage;
|
return statusMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The 17 canonical status codes. */
|
/** The 18 canonical status codes. */
|
||||||
public enum StatusCode {
|
public enum StatusCode {
|
||||||
OK("ok"),
|
OK("ok"),
|
||||||
CANCELLED("canceled"),
|
CANCELLED("canceled"),
|
||||||
|
@ -58,7 +58,8 @@ public class MediaPipeException extends RuntimeException {
|
||||||
INTERNAL("internal"),
|
INTERNAL("internal"),
|
||||||
UNAVAILABLE("unavailable"),
|
UNAVAILABLE("unavailable"),
|
||||||
DATA_LOSS("data loss"),
|
DATA_LOSS("data loss"),
|
||||||
UNAUTHENTICATED("unauthenticated");
|
UNAUTHENTICATED("unauthenticated"),
|
||||||
|
IO_EXCEPTION("i/o exception");
|
||||||
|
|
||||||
StatusCode(String description) {
|
StatusCode(String description) {
|
||||||
this.description = description;
|
this.description = description;
|
||||||
|
|
|
@ -92,6 +92,34 @@ android_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "languagedetector",
|
||||||
|
srcs = [
|
||||||
|
"languagedetector/LanguageDetector.java",
|
||||||
|
"languagedetector/LanguageDetectorResult.java",
|
||||||
|
"languagedetector/LanguagePrediction.java",
|
||||||
|
],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
manifest = "textembedder/AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
":libmediapipe_tasks_text_jni_lib",
|
||||||
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
|
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
|
||||||
|
|
||||||
mediapipe_tasks_text_aar(
|
mediapipe_tasks_text_aar(
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.text.languagedetector">
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,330 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.languagedetector;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.os.ParcelFileDescriptor;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
|
import com.google.mediapipe.tasks.text.textclassifier.proto.TextClassifierGraphOptionsProto;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predicts the language of an input text.
|
||||||
|
*
|
||||||
|
* <p>This API expects a TFLite model with <a
|
||||||
|
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a> that contains
|
||||||
|
* the mandatory (described below) input tensors, output tensor, and the language codes in an
|
||||||
|
* AssociatedFile.
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>Input tensor
|
||||||
|
* <ul>
|
||||||
|
* <li>One input tensor ({@code kTfLiteString}) of shape [1] containing the input string.
|
||||||
|
* </ul>
|
||||||
|
* <li>Output tensor
|
||||||
|
* <ul>
|
||||||
|
* <li>One output tensor ({@code kTfLiteFloat32}) of shape {@code [1 x N]} where {@code N}
|
||||||
|
* is the number of languages.
|
||||||
|
* </ul>
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public final class LanguageDetector implements AutoCloseable {
|
||||||
|
private static final String TAG = LanguageDetector.class.getSimpleName();
|
||||||
|
private static final String TEXT_IN_STREAM_NAME = "text_in";
|
||||||
|
|
||||||
|
private static final List<String> inputStreams =
|
||||||
|
Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
|
||||||
|
|
||||||
|
private static final List<String> outputStreams =
|
||||||
|
Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out"));
|
||||||
|
|
||||||
|
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||||
|
private static final String TASK_GRAPH_NAME =
|
||||||
|
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
|
private final TaskRunner runner;
|
||||||
|
|
||||||
|
static {
|
||||||
|
System.loadLibrary("mediapipe_tasks_text_jni");
|
||||||
|
ProtoUtil.registerTypeName(
|
||||||
|
ClassificationsProto.ClassificationResult.class,
|
||||||
|
"mediapipe.tasks.components.containers.proto.ClassificationResult");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link LanguageDetector} instance from a model file and the default {@link
|
||||||
|
* LanguageDetectorOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelPath path to the text model with metadata in the assets.
|
||||||
|
* @throws MediaPipeException if there is is an error during {@link LanguageDetector} creation.
|
||||||
|
*/
|
||||||
|
public static LanguageDetector createFromFile(Context context, String modelPath) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, LanguageDetectorOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link LanguageDetector} instance from a model file and the default {@link
|
||||||
|
* LanguageDetectorOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelFile the text model {@link File} instance.
|
||||||
|
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link LanguageDetector} creation.
|
||||||
|
*/
|
||||||
|
public static LanguageDetector createFromFile(Context context, File modelFile)
|
||||||
|
throws IOException {
|
||||||
|
try (ParcelFileDescriptor descriptor =
|
||||||
|
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||||
|
BaseOptions baseOptions =
|
||||||
|
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, LanguageDetectorOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link LanguageDetector} instance from {@link LanguageDetectorOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param options a {@link LanguageDetectorOptions} instance.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link LanguageDetector} creation.
|
||||||
|
*/
|
||||||
|
public static LanguageDetector createFromOptions(
|
||||||
|
Context context, LanguageDetectorOptions options) {
|
||||||
|
OutputHandler<LanguageDetectorResult, Void> handler = new OutputHandler<>();
|
||||||
|
handler.setOutputPacketConverter(
|
||||||
|
new OutputHandler.OutputPacketConverter<LanguageDetectorResult, Void>() {
|
||||||
|
@Override
|
||||||
|
public LanguageDetectorResult convertToTaskResult(List<Packet> packets) {
|
||||||
|
if (packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
|
List<Category> emptyCategories = new ArrayList<>();
|
||||||
|
Classifications emptyClassifications =
|
||||||
|
Classifications.create(emptyCategories, 0, Optional.empty());
|
||||||
|
ClassificationResult classificationResult =
|
||||||
|
ClassificationResult.create(
|
||||||
|
Arrays.asList(emptyClassifications), Optional.empty());
|
||||||
|
return LanguageDetectorResult.create(
|
||||||
|
classificationResult,
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return LanguageDetectorResult.create(
|
||||||
|
ClassificationResult.createFromProto(
|
||||||
|
PacketGetter.getProto(
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
|
ClassificationsProto.ClassificationResult.getDefaultInstance())),
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.IO_EXCEPTION.ordinal(), e.getMessage());
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(), e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Void convertToTaskInput(List<Packet> packets) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
TaskRunner runner =
|
||||||
|
TaskRunner.create(
|
||||||
|
context,
|
||||||
|
TaskInfo.<LanguageDetectorOptions>builder()
|
||||||
|
.setTaskName(LanguageDetector.class.getSimpleName())
|
||||||
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
|
.setInputStreams(inputStreams)
|
||||||
|
.setOutputStreams(outputStreams)
|
||||||
|
.setTaskOptions(options)
|
||||||
|
.setEnableFlowLimiting(false)
|
||||||
|
.build(),
|
||||||
|
handler);
|
||||||
|
return new LanguageDetector(runner);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize a {@link LanguageDetector} from a {@link TaskRunner}.
|
||||||
|
*
|
||||||
|
* @param runner a {@link TaskRunner}.
|
||||||
|
*/
|
||||||
|
private LanguageDetector(TaskRunner runner) {
|
||||||
|
this.runner = runner;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predicts the language of the input text.
|
||||||
|
*
|
||||||
|
* @param inputText a {@link String} for processing.
|
||||||
|
*/
|
||||||
|
public LanguageDetectorResult detect(String inputText) {
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
|
||||||
|
return (LanguageDetectorResult) runner.process(inputPackets);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Closes and cleans up the {@link LanguageDetector}. */
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
runner.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Options for setting up a {@link LanguageDetector}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract static class LanguageDetectorOptions extends TaskOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link LanguageDetectorOptions}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/** Sets the base options for the text classifier task. */
|
||||||
|
public abstract Builder setBaseOptions(BaseOptions value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||||
|
* Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setDisplayNamesLocale(String locale);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional maximum number of top-scored classification results to return.
|
||||||
|
*
|
||||||
|
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||||
|
*/
|
||||||
|
public abstract Builder setMaxResults(Integer maxResults);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||||
|
*
|
||||||
|
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional allowlist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||||
|
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryDenylist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional denylist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||||
|
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryAllowlist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||||
|
|
||||||
|
abstract LanguageDetectorOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link LanguageDetectorOptions} instance.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if any of the set options are invalid.
|
||||||
|
*/
|
||||||
|
public final LanguageDetectorOptions build() {
|
||||||
|
LanguageDetectorOptions options = autoBuild();
|
||||||
|
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||||
|
throw new IllegalArgumentException("If specified, maxResults must be > 0.");
|
||||||
|
}
|
||||||
|
if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Category allowlist and denylist are mutually exclusive.");
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
|
abstract Optional<String> displayNamesLocale();
|
||||||
|
|
||||||
|
abstract Optional<Integer> maxResults();
|
||||||
|
|
||||||
|
abstract Optional<Float> scoreThreshold();
|
||||||
|
|
||||||
|
// For backwards-compatibility reasons in OSS we want to avoid dependencies on libraries like
|
||||||
|
// Guava, so we don't use ImmutableList.
|
||||||
|
@SuppressWarnings("AutoValueImmutableFields")
|
||||||
|
abstract List<String> categoryAllowlist();
|
||||||
|
|
||||||
|
@SuppressWarnings("AutoValueImmutableFields")
|
||||||
|
abstract List<String> categoryDenylist();
|
||||||
|
|
||||||
|
@SuppressWarnings("AutoValueImmutableFields")
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_LanguageDetector_LanguageDetectorOptions.Builder()
|
||||||
|
.setCategoryAllowlist(Collections.emptyList())
|
||||||
|
.setCategoryDenylist(Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link LanguageDetectorOptions} to a {@link CalculatorOptions} protobuf message.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||||
|
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||||
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder =
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||||
|
displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale);
|
||||||
|
maxResults().ifPresent(classifierOptionsBuilder::setMaxResults);
|
||||||
|
scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold);
|
||||||
|
if (!categoryAllowlist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
|
||||||
|
}
|
||||||
|
if (!categoryDenylist().isEmpty()) {
|
||||||
|
classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
|
||||||
|
}
|
||||||
|
TextClassifierGraphOptionsProto.TextClassifierGraphOptions taskOptions =
|
||||||
|
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
|
||||||
|
.setBaseOptions(baseOptionsBuilder)
|
||||||
|
.setClassifierOptions(classifierOptionsBuilder)
|
||||||
|
.build();
|
||||||
|
return CalculatorOptions.newBuilder()
|
||||||
|
.setExtension(TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, taskOptions)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.languagedetector;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/** Represents the prediction results generated by {@link LanguageDetector}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class LanguageDetectorResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link LanguageDetectorResult} instance.
|
||||||
|
*
|
||||||
|
* @param classificationResult the {@link ClassificationResult} object containing one set of
|
||||||
|
* results per classifier head.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static LanguageDetectorResult create(
|
||||||
|
ClassificationResult classificationResult, long timestampMs) {
|
||||||
|
if (classificationResult.classifications().size() != 1) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Expected 1 classification head, got " + classificationResult.classifications().size());
|
||||||
|
}
|
||||||
|
Classifications classifications = classificationResult.classifications().get(0);
|
||||||
|
List<LanguagePrediction> languagePredictions = new ArrayList<>();
|
||||||
|
for (Category category : classifications.categories()) {
|
||||||
|
languagePredictions.add(LanguagePrediction.create(category.categoryName(), category.score()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return new AutoValue_LanguageDetectorResult(languagePredictions, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link LanguageDetectorResult} instance from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static LanguageDetectorResult createFromProto(
|
||||||
|
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||||
|
return create(ClassificationResult.createFromProto(proto), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** A list of predictions from the LanguageDetector. */
|
||||||
|
// For backwards-compatibility reasons in OSS we want to avoid dependencies on libraries like
|
||||||
|
// Guava, so we don't use ImmutableList.
|
||||||
|
@SuppressWarnings("AutoValueImmutableFields")
|
||||||
|
public abstract List<LanguagePrediction> languagesAndScores();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.languagedetector;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
|
||||||
|
/** A language code and its probability. Used as part of the output of {@link LanguageDetector}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class LanguagePrediction {
|
||||||
|
/**
|
||||||
|
* Creates a {@link LanguageDetectorPrediction} instance.
|
||||||
|
*
|
||||||
|
* @param languageCode An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
|
||||||
|
* "ja"-Latn for Japanese (romaji).
|
||||||
|
* @param probability The probability for the prediction.
|
||||||
|
*/
|
||||||
|
public static LanguagePrediction create(String languageCode, float probability) {
|
||||||
|
return new AutoValue_LanguagePrediction(languageCode, probability);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The i18n language / locale code for the prediction. */
|
||||||
|
public abstract String languageCode();
|
||||||
|
|
||||||
|
/** The probability for the prediction. */
|
||||||
|
public abstract float probability();
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.text.languagedetectortest"
|
||||||
|
android:versionCode="1"
|
||||||
|
android:versionName="1.0" >
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||||
|
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:label="languagedetectortest"
|
||||||
|
android:name="android.support.multidex.MultiDexApplication"
|
||||||
|
android:taskAffinity="">
|
||||||
|
<uses-library android:name="android.test.runner" />
|
||||||
|
</application>
|
||||||
|
|
||||||
|
<instrumentation
|
||||||
|
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||||
|
android:targetPackage="com.google.mediapipe.tasks.text.languagedetectortest" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TODO: Enable this in OSS
|
|
@ -0,0 +1,110 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.text.languagedetector;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
|
import androidx.test.core.app.ApplicationProvider;
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TestUtils;
|
||||||
|
import com.google.mediapipe.tasks.text.languagedetector.LanguageDetector.LanguageDetectorOptions;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
|
||||||
|
/** Test for {@link LanguageDetector}/ */
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public class LanguageDetectorTest {
|
||||||
|
private static final String MODEL_FILE = "language_detector.tflite";
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void options_failsWithNegativeMaxResults() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
LanguageDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setMaxResults(-1)
|
||||||
|
.build());
|
||||||
|
assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void options_failsWithBothAllowlistAndDenylist() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
LanguageDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setCategoryAllowlist(Arrays.asList("foo"))
|
||||||
|
.setCategoryDenylist(Arrays.asList("bar"))
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("Category allowlist and denylist are mutually exclusive");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithMissingModel() throws Exception {
|
||||||
|
String nonExistentFile = "/path/to/non/existent/file";
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
LanguageDetector.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(), nonExistentFile));
|
||||||
|
assertThat(exception).hasMessageThat().contains(nonExistentFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithL2CModel() throws Exception {
|
||||||
|
LanguageDetector languageDetector =
|
||||||
|
LanguageDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||||
|
LanguageDetectorResult enResult =
|
||||||
|
languageDetector.detect("To be, or not to be, that is the question");
|
||||||
|
assertThat(enResult.languagesAndScores().size()).isEqualTo(1);
|
||||||
|
assertThat(enResult.languagesAndScores().get(0))
|
||||||
|
.isEqualTo(LanguagePrediction.create("en", 0.9998559f));
|
||||||
|
LanguageDetectorResult frResult =
|
||||||
|
languageDetector.detect(
|
||||||
|
"Il y a beaucoup de bouches qui parlent et fort peu de têtes qui pensent.");
|
||||||
|
assertThat(frResult.languagesAndScores().size()).isEqualTo(1);
|
||||||
|
assertThat(frResult.languagesAndScores().get(0))
|
||||||
|
.isEqualTo(LanguagePrediction.create("fr", 0.9997813f));
|
||||||
|
LanguageDetectorResult ruResult = languageDetector.detect("это какой-то английский язык");
|
||||||
|
assertThat(ruResult.languagesAndScores().size()).isEqualTo(1);
|
||||||
|
assertThat(ruResult.languagesAndScores().get(0))
|
||||||
|
.isEqualTo(LanguagePrediction.create("ru", 0.9933616f));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithFileObject() throws Exception {
|
||||||
|
LanguageDetector languageDetector =
|
||||||
|
LanguageDetector.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(),
|
||||||
|
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||||
|
LanguageDetectorResult mixedResult = languageDetector.detect("分久必合合久必分");
|
||||||
|
assertThat(mixedResult.languagesAndScores().size()).isEqualTo(2);
|
||||||
|
assertThat(mixedResult.languagesAndScores().get(0))
|
||||||
|
.isEqualTo(LanguagePrediction.create("zh", 0.50542367f));
|
||||||
|
assertThat(mixedResult.languagesAndScores().get(1))
|
||||||
|
.isEqualTo(LanguagePrediction.create("ja", 0.4816168f));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user