diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index ebe3403b2..08cbb8672 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -19,6 +19,7 @@ mediapipe_files(srcs = [ TEXT_LIBS = [ "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/text/language_detector", "//mediapipe/tasks/web/text/text_classifier", "//mediapipe/tasks/web/text/text_embedder", ] diff --git a/mediapipe/tasks/web/text/README.md b/mediapipe/tasks/web/text/README.md index 247dc6d30..089894653 100644 --- a/mediapipe/tasks/web/text/README.md +++ b/mediapipe/tasks/web/text/README.md @@ -2,9 +2,23 @@ This package contains the text tasks for MediaPipe. +## Language Detection + +The MediaPipe Language Detector task predicts the language of an input text. + +``` +const text = await FilesetResolver.forTextTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" +); +const languageDetector = await LanguageDetector.createFromModelPath(text, + "model.tflite" +); +const result = languageDetector.detect(textData); +``` + ## Text Classification -MediaPipe Text Classifier task lets you classify text into a set of defined +The MediaPipe Text Classifier task lets you classify text into a set of defined categories, such as positive or negative sentiment. 
``` diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index cfa990e58..2fbdd548f 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -15,13 +15,15 @@ */ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; +import {LanguageDetector as LanguageDetectorImpl} from '../../../tasks/web/text/language_detector/language_detector'; import {TextClassifier as TextClassifierImpl} from '../../../tasks/web/text/text_classifier/text_classifier'; import {TextEmbedder as TextEmbedderImpl} from '../../../tasks/web/text/text_embedder/text_embedder'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const FilesetResolver = FilesetResolverImpl; +const LanguageDetector = LanguageDetectorImpl; const TextClassifier = TextClassifierImpl; const TextEmbedder = TextEmbedderImpl; -export {FilesetResolver, TextClassifier, TextEmbedder}; +export {LanguageDetector, FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/language_detector/BUILD b/mediapipe/tasks/web/text/language_detector/BUILD new file mode 100644 index 000000000..9fc870081 --- /dev/null +++ b/mediapipe/tasks/web/text/language_detector/BUILD @@ -0,0 +1,66 @@ +# This contains the MediaPipe Language Detector Task. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "language_detector", + srcs = ["language_detector.ts"], + visibility = ["//visibility:public"], + deps = [ + ":language_detector_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "language_detector_types", + srcs = [ + "language_detector_options.d.ts", + "language_detector_result.d.ts", + ], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) + +mediapipe_ts_library( + name = "language_detector_test_lib", + testonly = True, + srcs = [ + "language_detector_test.ts", + ], + deps = [ + ":language_detector", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + 
"//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + ], +) + +jasmine_node_test( + name = "language_detector_test", + deps = [":language_detector_test_lib"], +) diff --git a/mediapipe/tasks/web/text/language_detector/language_detector.ts b/mediapipe/tasks/web/text/language_detector/language_detector.ts new file mode 100644 index 000000000..13343fab3 --- /dev/null +++ b/mediapipe/tasks/web/text/language_detector/language_detector.ts @@ -0,0 +1,182 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
+import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
+import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
+import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
+import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {WasmModule} from '../../../../web/graph_runner/graph_runner';
+// Placeholder for internal dependency on trusted resource url
+
+import {LanguageDetectorOptions} from './language_detector_options';
+import {LanguageDetectorResult} from './language_detector_result';
+
+export * from './language_detector_options';
+export * from './language_detector_result';
+
+const INPUT_STREAM = 'text_in';
+const CLASSIFICATIONS_STREAM = 'classifications_out';
+const TEXT_CLASSIFIER_GRAPH =
+    'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+/** Predicts the language of an input text using the text classifier graph. */
+export class LanguageDetector extends TaskRunner {
+  private result: LanguageDetectorResult = {languages: []};
+  private readonly options = new TextClassifierGraphOptions();
+
+  /**
+   * Initializes the Wasm runtime and creates a new language detector from the
+   * provided options.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
+   * @param textClassifierOptions The options for the language detector. Note
+   *     that either a path to the TFLite model or the model itself needs to be
+   *     provided (via `baseOptions`).
+   */
+  static createFromOptions(
+      wasmFileset: WasmFileset, textClassifierOptions: LanguageDetectorOptions):
+      Promise<LanguageDetector> {
+    return TaskRunner.createInstance(
+        LanguageDetector, /* canvas= */ null, wasmFileset,
+        textClassifierOptions);
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new language detector based on
+   * the provided model asset buffer.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the model.
+   */
+  static createFromModelBuffer(
+      wasmFileset: WasmFileset,
+      modelAssetBuffer: Uint8Array): Promise<LanguageDetector> {
+    return TaskRunner.createInstance(
+        LanguageDetector, /* canvas= */ null, wasmFileset,
+        {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new language detector based on
+   * the path to the model asset.
+   * @param wasmFileset A configuration object that provides the location of the
+   *     Wasm binary and its loader.
+   * @param modelAssetPath The path to the model asset.
+   */
+  static createFromModelPath(
+      wasmFileset: WasmFileset,
+      modelAssetPath: string): Promise<LanguageDetector> {
+    return TaskRunner.createInstance(
+        LanguageDetector, /* canvas= */ null, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  /** @hideconstructor */
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(new CachedGraphRunner(wasmModule, glCanvas));
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+  /**
+   * Sets new options for the language detector.
+   *
+   * Calling `setOptions()` with a subset of options only affects those options.
+   * You can reset an option back to its default value by explicitly setting it
+   * to `undefined`.
+   *
+   * @param options The options for the language detector.
+   */
+  override setOptions(options: LanguageDetectorOptions): Promise<void> {
+    this.options.setClassifierOptions(convertClassifierOptionsToProto(
+        options, this.options.getClassifierOptions()));
+    return this.applyOptions(options);
+  }
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Predicts the language of the input text.
+   *
+   * @param text The text to process.
+   * @return The languages detected in the input text.
+   */
+  detect(text: string): LanguageDetectorResult {
+    this.result = {languages: []};  // Discard the previous call's predictions.
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
+    this.finishProcessing();
+    return this.result;  // Written by the listener attached in refreshGraph().
+  }
+
+  /** Rebuilds the graph config from `options` and re-attaches listeners. */
+  protected override refreshGraph(): void {
+    const graphConfig = new CalculatorGraphConfig();
+    graphConfig.addInputStream(INPUT_STREAM);
+    graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
+
+    const calculatorOptions = new CalculatorOptions();
+    calculatorOptions.setExtension(
+        TextClassifierGraphOptions.ext, this.options);
+
+    const classifierNode = new CalculatorGraphConfig.Node();
+    classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
+    classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
+    classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
+    classifierNode.setOptions(calculatorOptions);
+
+    graphConfig.addNode(classifierNode);
+
+    this.graphRunner.attachProtoListener(
+        CLASSIFICATIONS_STREAM, (binaryProto, timestamp) => {
+          const {classifications} = convertFromClassificationResultProto(
+              ClassificationResult.deserializeBinary(binaryProto));
+          if (classifications.length !== 1) {
+            throw new Error(`Expected 1 classification head, got ${
+                classifications.length}`);
+          }
+          this.result.languages =
+              classifications[0].categories.map(c => {
+                return {languageCode: c.categoryName, probability: c.score};
+              });
+          this.setLatestOutputTimestamp(timestamp);
+        });
+    this.graphRunner.attachEmptyPacketListener(
+        CLASSIFICATIONS_STREAM, timestamp => {
+          this.setLatestOutputTimestamp(timestamp);
+        });
+
+    const binaryGraph = graphConfig.serializeBinary();
+    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
+  }
+}
+
+
+
diff --git a/mediapipe/tasks/web/text/language_detector/language_detector_options.d.ts b/mediapipe/tasks/web/text/language_detector/language_detector_options.d.ts
new file mode 100644
index 000000000..54b735538
--- /dev/null
+++ b/mediapipe/tasks/web/text/language_detector/language_detector_options.d.ts
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
+import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
+
+/** Options to configure the MediaPipe Language Detector Task. */
+export declare interface LanguageDetectorOptions extends ClassifierOptions,
+    TaskRunnerOptions {}
diff --git a/mediapipe/tasks/web/text/language_detector/language_detector_result.d.ts b/mediapipe/tasks/web/text/language_detector/language_detector_result.d.ts
new file mode 100644
index 000000000..b21285ef4
--- /dev/null
+++ b/mediapipe/tasks/web/text/language_detector/language_detector_result.d.ts
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** A single predicted language code together with its probability. */
+export declare interface LanguageDetectorPrediction {
+  /**
+   * An i18n language / locale code, e.g. "en" for English, "uz" for Uzbek,
+   * "ja-Latn" for Japanese (romaji).
+   */
+  languageCode: string;
+
+  /** The probability that the input text is written in this language. */
+  probability: number;
+}
+
+/** The result of language detection. */
+export declare interface LanguageDetectorResult {
+  /** A list of language predictions.
+   */
+  languages: LanguageDetectorPrediction[];
+}
diff --git a/mediapipe/tasks/web/text/language_detector/language_detector_test.ts b/mediapipe/tasks/web/text/language_detector/language_detector_test.ts
new file mode 100644
index 000000000..6e91b6662
--- /dev/null
+++ b/mediapipe/tasks/web/text/language_detector/language_detector_test.ts
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import 'jasmine';
+
+// Placeholder for internal dependency on encodeByteArray
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
+import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+
+import {LanguageDetector} from './language_detector';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+/** Stream on which the task is expected to publish its classifications. */
+const EXPECTED_OUTPUT_STREAM = 'classifications_out';
+
+/** Signature of the listener that the task registers for proto packets. */
+type ProtoListener = (binaryProto: Uint8Array, timestamp: number) => void;
+
+/**
+ * Test double that captures the serialized graph and the proto listener
+ * installed by the task so that tests can inspect and drive them.
+ */
+class LanguageDetectorFake extends LanguageDetector implements
+    MediapipeTasksFake {
+  calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+  fakeWasmModule: SpyWasmModule;
+  protoListener: ProtoListener|undefined;
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+
+    const protoListenerSpy = spyOn(this.graphRunner, 'attachProtoListener');
+    protoListenerSpy.and.callFake((stream, listener) => {
+      expect(stream).toEqual(EXPECTED_OUTPUT_STREAM);
+      this.protoListener = listener;
+    });
+    this.attachListenerSpies[0] = protoListenerSpy;
+
+    const setGraphSpy = spyOn(this.graphRunner, 'setGraph');
+    setGraphSpy.and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+  }
+}
+
+describe('LanguageDetector', () => {
+  let languageDetector: LanguageDetectorFake;
+
+  /**
+   * Arranges for the fake Wasm module to pass `result` to the captured proto
+   * listener whenever `_waitUntilIdle` is called.
+   */
+  function emitWhenIdle(result: ClassificationResult): void {
+    languageDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(languageDetector);
+      languageDetector.protoListener!(result.serializeBinary(), 1337);
+    });
+  }
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
+    languageDetector = new LanguageDetectorFake();
+    const emptyModel = new Uint8Array([]);
+    await languageDetector.setOptions(
+        {baseOptions: {modelAssetBuffer: emptyModel}});
+  });
+
+  it('initializes graph', async () => {
+    verifyGraph(languageDetector);
+    verifyListenersRegistered(languageDetector);
+  });
+
+  it('reloads graph when settings are changed', async () => {
+    for (const maxResults of [1, 5]) {
+      await languageDetector.setOptions({maxResults});
+      verifyGraph(
+          languageDetector, [['classifierOptions', 'maxResults'], maxResults]);
+      verifyListenersRegistered(languageDetector);
+    }
+  });
+
+  it('can use custom models', async () => {
+    const customModel = new Uint8Array([0, 1, 2, 3, 4]);
+    const customModelBase64 = Buffer.from(customModel).toString('base64');
+    await languageDetector.setOptions(
+        {baseOptions: {modelAssetBuffer: customModel}});
+
+    const expectedModelAsset = {
+      fileContent: customModelBase64,
+      fileName: undefined,
+      fileDescriptorMeta: undefined,
+      filePointerMeta: undefined
+    };
+    verifyGraph(
+        languageDetector,
+        /* expectedCalculatorOptions= */ undefined,
+        /* expectedBaseOptions= */ ['modelAsset', expectedModelAsset]);
+  });
+
+  it('merges options', async () => {
+    await languageDetector.setOptions({maxResults: 1});
+    await languageDetector.setOptions({displayNamesLocale: 'en'});
+
+    const mergedClassifierOptions = {
+      maxResults: 1,
+      displayNamesLocale: 'en',
+      scoreThreshold: undefined,
+      categoryAllowlistList: [],
+      categoryDenylistList: []
+    };
+    verifyGraph(
+        languageDetector, ['classifierOptions', mergedClassifierOptions]);
+  });
+
+  it('transforms results', async () => {
+    const english = new Classification();
+    english.setIndex(1);
+    english.setScore(0.9);
+    english.setDisplayName('English');
+    english.setLabel('en');
+
+    const categories = new ClassificationList();
+    categories.addClassification(english);
+
+    const head = new Classifications();
+    head.setHeadIndex(1);
+    head.setHeadName('headName');
+    head.setClassificationList(categories);
+
+    const classificationResult = new ClassificationResult();
+    classificationResult.addClassifications(head);
+
+    // Pass the test data to our listener
+    emitWhenIdle(classificationResult);
+
+    // Invoke the language detector
+    const result = languageDetector.detect('Hello world!');
+
+    expect(languageDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+    const expected = {languages: [{languageCode: 'en', probability: 0.9}]};
+    expect(result).toEqual(expected);
+  });
+
+  it('validates that we get a single classification head', async () => {
+    const head = new Classifications();
+    const classificationResult = new ClassificationResult();
+    // Adding the same head twice makes the result report two heads.
+    classificationResult.addClassifications(head);
+    classificationResult.addClassifications(head);
+
+    // Pass the test data to our listener
+    emitWhenIdle(classificationResult);
+
+    // Validate that we get an error with more than one classification head
+    const detect = () => {
+      languageDetector.detect('Hello world!');
+    };
+    expect(detect).toThrowError('Expected 1 classification head, got 2');
+  });
+});
diff --git a/mediapipe/tasks/web/text/types.ts b/mediapipe/tasks/web/text/types.ts
index bd01b1c6f..9f9e48c5a 100644
--- a/mediapipe/tasks/web/text/types.ts
+++ b/mediapipe/tasks/web/text/types.ts
@@ -15,5 +15,6 @@
  */
 
 export * from '../../../tasks/web/core/fileset_resolver';
+export * from '../../../tasks/web/text/language_detector/language_detector';
 export * from '../../../tasks/web/text/text_classifier/text_classifier';
 export * from '../../../tasks/web/text/text_embedder/text_embedder';