Create Language Detection Web API
PiperOrigin-RevId: 523445524
This commit is contained in:
parent
b66b0e0c72
commit
221d545080
|
@ -19,6 +19,7 @@ mediapipe_files(srcs = [
|
||||||
|
|
||||||
TEXT_LIBS = [
|
TEXT_LIBS = [
|
||||||
"//mediapipe/tasks/web/core:fileset_resolver",
|
"//mediapipe/tasks/web/core:fileset_resolver",
|
||||||
|
"//mediapipe/tasks/web/text/language_detector",
|
||||||
"//mediapipe/tasks/web/text/text_classifier",
|
"//mediapipe/tasks/web/text/text_classifier",
|
||||||
"//mediapipe/tasks/web/text/text_embedder",
|
"//mediapipe/tasks/web/text/text_embedder",
|
||||||
]
|
]
|
||||||
|
|
|
@ -2,9 +2,23 @@
|
||||||
|
|
||||||
This package contains the text tasks for MediaPipe.
|
This package contains the text tasks for MediaPipe.
|
||||||
|
|
||||||
|
## Language Detection
|
||||||
|
|
||||||
|
The MediaPipe Language Detector task predicts the language of an input text.
|
||||||
|
|
||||||
|
```
|
||||||
|
const text = await FilesetResolver.forTextTasks(
|
||||||
|
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm"
|
||||||
|
);
|
||||||
|
const languageDetector = await LanguageDetector.createFromModelPath(text,
|
||||||
|
"model.tflite"
|
||||||
|
);
|
||||||
|
const result = languageDetector.detect(textData);
|
||||||
|
```
|
||||||
|
|
||||||
## Text Classification
|
## Text Classification
|
||||||
|
|
||||||
MediaPipe Text Classifier task lets you classify text into a set of defined
|
The MediaPipe Text Classifier task lets you classify text into a set of defined
|
||||||
categories, such as positive or negative sentiment.
|
categories, such as positive or negative sentiment.
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -15,13 +15,15 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
|
||||||
|
import {LanguageDetector as LanguageDetectorImpl} from '../../../tasks/web/text/language_detector/language_detector';
|
||||||
import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier';
|
import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier';
|
||||||
import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder';
|
import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder';
|
||||||
|
|
||||||
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
||||||
// as exports.
|
// as exports.
|
||||||
const FilesetResolver = FilesetResolverImpl;
|
const FilesetResolver = FilesetResolverImpl;
|
||||||
|
const LanguageDetector = LanguageDetectorImpl;
|
||||||
const TextClassifier = TextClassifierImpl;
|
const TextClassifier = TextClassifierImpl;
|
||||||
const TextEmbedder = TextEmbedderImpl;
|
const TextEmbedder = TextEmbedderImpl;
|
||||||
|
|
||||||
export {FilesetResolver, TextClassifier, TextEmbedder};
|
export {LanguageDetector, FilesetResolver, TextClassifier, TextEmbedder};
|
||||||
|
|
66
mediapipe/tasks/web/text/language_detector/BUILD
Normal file
66
mediapipe/tasks/web/text/language_detector/BUILD
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
# This contains the MediaPipe Language Detector Task.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
||||||
|
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "language_detector",
|
||||||
|
srcs = ["language_detector.ts"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":language_detector_types",
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/core:task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_declaration(
|
||||||
|
name = "language_detector_types",
|
||||||
|
srcs = [
|
||||||
|
"language_detector_options.d.ts",
|
||||||
|
"language_detector_result.d.ts",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "language_detector_test_lib",
|
||||||
|
testonly = True,
|
||||||
|
srcs = [
|
||||||
|
"language_detector_test.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":language_detector",
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/framework/formats:classification_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
jasmine_node_test(
|
||||||
|
name = "language_detector_test",
|
||||||
|
deps = [":language_detector_test_lib"],
|
||||||
|
)
|
182
mediapipe/tasks/web/text/language_detector/language_detector.ts
Normal file
182
mediapipe/tasks/web/text/language_detector/language_detector.ts
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
|
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
|
||||||
|
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
|
||||||
|
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
|
||||||
|
import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {LanguageDetectorOptions} from './language_detector_options';
|
||||||
|
import {LanguageDetectorResult} from './language_detector_result';
|
||||||
|
|
||||||
|
export * from './language_detector_options';
|
||||||
|
export * from './language_detector_result';
|
||||||
|
|
||||||
|
const INPUT_STREAM = 'text_in';
|
||||||
|
const CLASSIFICATIONS_STREAM = 'classifications_out';
|
||||||
|
const TEXT_CLASSIFIER_GRAPH =
|
||||||
|
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/** Predicts the language of an input text. */
|
||||||
|
export class LanguageDetector extends TaskRunner {
|
||||||
|
private result: LanguageDetectorResult = {languages: []};
|
||||||
|
private readonly options = new TextClassifierGraphOptions();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new language detector from the
|
||||||
|
* provided options.
|
||||||
|
* @param wasmFileset A configuration object that provides the location of the
|
||||||
|
* Wasm binary and its loader.
|
||||||
|
* @param textClassifierOptions The options for the language detector. Note
|
||||||
|
* that either a path to the TFLite model or the model itself needs to be
|
||||||
|
* provided (via `baseOptions`).
|
||||||
|
*/
|
||||||
|
static createFromOptions(
|
||||||
|
wasmFileset: WasmFileset, textClassifierOptions: LanguageDetectorOptions):
|
||||||
|
Promise<LanguageDetector> {
|
||||||
|
return TaskRunner.createInstance(
|
||||||
|
LanguageDetector, /* canvas= */ null, wasmFileset,
|
||||||
|
textClassifierOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new language detector based on
|
||||||
|
* the provided model asset buffer.
|
||||||
|
* @param wasmFileset A configuration object that provides the location of the
|
||||||
|
* Wasm binary and its loader.
|
||||||
|
* @param modelAssetBuffer A binary representation of the model.
|
||||||
|
*/
|
||||||
|
static createFromModelBuffer(
|
||||||
|
wasmFileset: WasmFileset,
|
||||||
|
modelAssetBuffer: Uint8Array): Promise<LanguageDetector> {
|
||||||
|
return TaskRunner.createInstance(
|
||||||
|
LanguageDetector, /* canvas= */ null, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new language detector based on
|
||||||
|
* the path to the model asset.
|
||||||
|
* @param wasmFileset A configuration object that provides the location of the
|
||||||
|
* Wasm binary and its loader.
|
||||||
|
* @param modelAssetPath The path to the model asset.
|
||||||
|
*/
|
||||||
|
static createFromModelPath(
|
||||||
|
wasmFileset: WasmFileset,
|
||||||
|
modelAssetPath: string): Promise<LanguageDetector> {
|
||||||
|
return TaskRunner.createInstance(
|
||||||
|
LanguageDetector, /* canvas= */ null, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetPath}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @hideconstructor */
|
||||||
|
constructor(
|
||||||
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(new CachedGraphRunner(wasmModule, glCanvas));
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets new options for the language detector.
|
||||||
|
*
|
||||||
|
* Calling `setOptions()` with a subset of options only affects those options.
|
||||||
|
* You can reset an option back to its default value by explicitly setting it
|
||||||
|
* to `undefined`.
|
||||||
|
*
|
||||||
|
* @param options The options for the language detector.
|
||||||
|
*/
|
||||||
|
override setOptions(options: LanguageDetectorOptions): Promise<void> {
|
||||||
|
this.options.setClassifierOptions(convertClassifierOptionsToProto(
|
||||||
|
options, this.options.getClassifierOptions()));
|
||||||
|
return this.applyOptions(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
|
this.options.setBaseOptions(proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Predicts the language of the input text.
|
||||||
|
*
|
||||||
|
* @param text The text to process.
|
||||||
|
* @return The languages detected in the input text.
|
||||||
|
*/
|
||||||
|
detect(text: string): LanguageDetectorResult {
|
||||||
|
this.result = {languages: []};
|
||||||
|
this.graphRunner.addStringToStream(
|
||||||
|
text, INPUT_STREAM, this.getSynctheticTimestamp());
|
||||||
|
this.finishProcessing();
|
||||||
|
return this.result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
protected override refreshGraph(): void {
|
||||||
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
|
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||||
|
|
||||||
|
const calculatorOptions = new CalculatorOptions();
|
||||||
|
calculatorOptions.setExtension(
|
||||||
|
TextClassifierGraphOptions.ext, this.options);
|
||||||
|
|
||||||
|
const classifierNode = new CalculatorGraphConfig.Node();
|
||||||
|
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
|
||||||
|
classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
|
||||||
|
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
|
||||||
|
classifierNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
|
graphConfig.addNode(classifierNode);
|
||||||
|
|
||||||
|
this.graphRunner.attachProtoListener(
|
||||||
|
CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => {
|
||||||
|
const {classifications} = convertFromClassificationResultProto(
|
||||||
|
ClassificationResult.deserializeBinary(binaryProto));
|
||||||
|
if (classifications.length !== 1) {
|
||||||
|
throw new Error(`Expected 1 classification head, got ${
|
||||||
|
classifications.length}`);
|
||||||
|
}
|
||||||
|
this.result.languages = classifications[0].categories.map(c => {
|
||||||
|
return {languageCode: c.categoryName, probability: c.score};
|
||||||
|
});
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
|
CLASSIFICATIONS_STREAM, timestamp => {
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
|
||||||
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
22
mediapipe/tasks/web/text/language_detector/language_detector_options.d.ts
vendored
Normal file
22
mediapipe/tasks/web/text/language_detector/language_detector_options.d.ts
vendored
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
|
||||||
|
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||||
|
|
||||||
|
/** Options to configure the MediaPipe Language Detector Task */
|
||||||
|
export declare interface LanguageDetectorOptions extends ClassifierOptions,
|
||||||
|
TaskRunnerOptions {}
|
33
mediapipe/tasks/web/text/language_detector/language_detector_result.d.ts
vendored
Normal file
33
mediapipe/tasks/web/text/language_detector/language_detector_result.d.ts
vendored
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** A language code and its probability. */
|
||||||
|
export declare interface LanguageDetectorPrediction {
|
||||||
|
/**
|
||||||
|
* An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
|
||||||
|
* "ja"-Latn for Japanese (romaji).
|
||||||
|
*/
|
||||||
|
languageCode: string;
|
||||||
|
|
||||||
|
/** The probability */
|
||||||
|
probability: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The result of language detection. */
|
||||||
|
export declare interface LanguageDetectorResult {
|
||||||
|
/** A list of language predictions. */
|
||||||
|
languages: LanguageDetectorPrediction[];
|
||||||
|
}
|
|
@ -0,0 +1,169 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
|
||||||
|
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
|
||||||
|
import {LanguageDetector} from './language_detector';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
class LanguageDetectorFake extends LanguageDetector implements
|
||||||
|
MediapipeTasksFake {
|
||||||
|
calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
|
||||||
|
attachListenerSpies: jasmine.Spy[] = [];
|
||||||
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
fakeWasmModule: SpyWasmModule;
|
||||||
|
protoListener:
|
||||||
|
((binaryProto: Uint8Array, timestamp: number) => void)|undefined;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
this.fakeWasmModule =
|
||||||
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
this.attachListenerSpies[0] =
|
||||||
|
spyOn(this.graphRunner, 'attachProtoListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('classifications_out');
|
||||||
|
this.protoListener = listener;
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('LanguageDetector', () => {
|
||||||
|
let languageDetector: LanguageDetectorFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
languageDetector = new LanguageDetectorFake();
|
||||||
|
await languageDetector.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(languageDetector);
|
||||||
|
verifyListenersRegistered(languageDetector);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
await languageDetector.setOptions({maxResults: 1});
|
||||||
|
verifyGraph(languageDetector, [['classifierOptions', 'maxResults'], 1]);
|
||||||
|
verifyListenersRegistered(languageDetector);
|
||||||
|
|
||||||
|
await languageDetector.setOptions({maxResults: 5});
|
||||||
|
verifyGraph(languageDetector, [['classifierOptions', 'maxResults'], 5]);
|
||||||
|
verifyListenersRegistered(languageDetector);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can use custom models', async () => {
|
||||||
|
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||||
|
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||||
|
await languageDetector.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: newModel,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
verifyGraph(
|
||||||
|
languageDetector,
|
||||||
|
/* expectedCalculatorOptions= */ undefined,
|
||||||
|
/* expectedBaseOptions= */
|
||||||
|
[
|
||||||
|
'modelAsset', {
|
||||||
|
fileContent: newModelBase64,
|
||||||
|
fileName: undefined,
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges options', async () => {
|
||||||
|
await languageDetector.setOptions({maxResults: 1});
|
||||||
|
await languageDetector.setOptions({displayNamesLocale: 'en'});
|
||||||
|
verifyGraph(languageDetector, [
|
||||||
|
'classifierOptions', {
|
||||||
|
maxResults: 1,
|
||||||
|
displayNamesLocale: 'en',
|
||||||
|
scoreThreshold: undefined,
|
||||||
|
categoryAllowlistList: [],
|
||||||
|
categoryDenylistList: []
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('transforms results', async () => {
|
||||||
|
const classificationResult = new ClassificationResult();
|
||||||
|
const classifcations = new Classifications();
|
||||||
|
classifcations.setHeadIndex(1);
|
||||||
|
classifcations.setHeadName('headName');
|
||||||
|
const classificationList = new ClassificationList();
|
||||||
|
const classification = new Classification();
|
||||||
|
classification.setIndex(1);
|
||||||
|
classification.setScore(0.9);
|
||||||
|
classification.setDisplayName('English');
|
||||||
|
classification.setLabel('en');
|
||||||
|
classificationList.addClassification(classification);
|
||||||
|
classifcations.setClassificationList(classificationList);
|
||||||
|
classificationResult.addClassifications(classifcations);
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
languageDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(languageDetector);
|
||||||
|
languageDetector.protoListener!
|
||||||
|
(classificationResult.serializeBinary(), 1337);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the language detector
|
||||||
|
const result = languageDetector.detect('Hello world!');
|
||||||
|
|
||||||
|
expect(languageDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(result).toEqual({
|
||||||
|
languages: [{
|
||||||
|
languageCode: 'en',
|
||||||
|
probability: 0.9,
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('validates that we get a single classification head', async () => {
|
||||||
|
const classificationResult = new ClassificationResult();
|
||||||
|
const classifcations = new Classifications();
|
||||||
|
classificationResult.addClassifications(classifcations);
|
||||||
|
classificationResult.addClassifications(classifcations);
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
languageDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(languageDetector);
|
||||||
|
languageDetector.protoListener!
|
||||||
|
(classificationResult.serializeBinary(), 1337);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Validate that we get an error with more than one classification head
|
||||||
|
expect(() => {
|
||||||
|
languageDetector.detect('Hello world!');
|
||||||
|
}).toThrowError('Expected 1 classification head, got 2');
|
||||||
|
});
|
||||||
|
});
|
|
@ -15,5 +15,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export * from '../../../tasks/web/core/fileset_resolver';
|
export * from '../../../tasks/web/core/fileset_resolver';
|
||||||
|
export * from '../../../tasks/web/text/language_detector/language_detector';
|
||||||
export * from '../../../tasks/web/text/text_classifier/text_classifier';
|
export * from '../../../tasks/web/text/text_classifier/text_classifier';
|
||||||
export * from '../../../tasks/web/text/text_embedder/text_embedder';
|
export * from '../../../tasks/web/text/text_embedder/text_embedder';
|
||||||
|
|
Loading…
Reference in New Issue
Block a user