diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD new file mode 100644 index 000000000..bc3048df1 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -0,0 +1,33 @@ +# This contains the MediaPipe Audio Classifier Task. +# +# This task takes audio data and outputs the classification result. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_classifier", + srcs = [ + "audio_classifier.ts", + "audio_classifier_options.ts", + "audio_classifier_result.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts new file mode 100644 index 000000000..fd79487a4 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -0,0 +1,215 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; +import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {AudioClassifierOptions} from './audio_classifier_options'; +import {Classifications} from './audio_classifier_result'; + +const MEDIAPIPE_GRAPH = + 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; + +// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and +// cannot be changed +// TODO: Change this to `audio_in` to match the name in the CC 
+// implementation +const AUDIO_STREAM = 'input_audio'; +const SAMPLE_RATE_STREAM = 'sample_rate'; +const CLASSIFICATION_RESULT_STREAM = 'classification_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs audio classification. */ +export class AudioClassifier extends TaskRunner { + private classifications: Classifications[] = []; + private defaultSampleRate = 48000; + private readonly options = new AudioClassifierGraphOptions(); + + /** + * Initializes the Wasm runtime and creates a new audio classifier from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param audioClassifierOptions The options for the audio classifier. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + audioClassifierOptions: AudioClassifierOptions): + Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const classifier = await createMediaPipeLib( + AudioClassifier, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await classifier.setOptions(audioClassifierOptions); + return classifier; + } + + /** + * Initializes the Wasm runtime and creates a new audio classifier based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. 
+ */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return AudioClassifier.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new audio classifier based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return AudioClassifier.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the audio classifier. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the audio classifier. + */ + async setOptions(options: AudioClassifierOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setClassifierOptions(convertClassifierOptionsToProto( + options, this.options.getClassifierOptions())); + this.refreshGraph(); + } + + /** + * Sets the sample rate for all calls to `classify()` that omit an explicit + * sample rate. `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** + * Performs audio classification on the provided audio data and waits + * synchronously for the response. 
+ * + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param sampleRate The sample rate in Hz of the provided audio data. If not + * set, defaults to the sample rate set via `setDefaultSampleRate()` or + * `48000` if no custom default was set. + * @return The classification result of the audio datas + */ + classify(audioData: Float32Array, sampleRate?: number): Classifications[] { + sampleRate = sampleRate ?? this.defaultSampleRate; + + // Configures the number of samples in the WASM layer. We re-configure the + // number of samples and the sample rate for every frame, but ignore other + // side effects of this function (such as sending the input side packet and + // the input stream header). + this.configureAudio( + /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); + + const timestamp = performance.now(); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); + this.addAudioToStream(audioData, timestamp); + + this.classifications = []; + this.finishProcessing(); + return [...this.classifications]; + } + + /** + * Internal function for converting raw data into a classification, and + * adding it to our classfications list. + **/ + private addJsAudioClassification(binaryProto: Uint8Array): void { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + this.classifications.push( + ...convertFromClassificationResultProto(classificationResult)); + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioClassifierGraphOptions.ext, this.options); + + // Perform audio classification. 
Pre-processing and results post-processing + // are built-in. + const classifierNode = new CalculatorGraphConfig.Node(); + classifierNode.setCalculator(MEDIAPIPE_GRAPH); + classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); + classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + classifierNode.addOutputStream( + 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.setOptions(calculatorOptions); + + graphConfig.addNode(classifierNode); + + this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { + this.addJsAudioClassification(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts new file mode 100644 index 000000000..93bd9927e --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts new file mode 100644 index 000000000..0a51dee04 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Category} from '../../../../tasks/web/components/containers/category'; +export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; diff --git a/mediapipe/tasks/web/components/processors/classifier_options.ts b/mediapipe/tasks/web/components/processors/classifier_options.ts index 8e01dd373..5b8ae796e 100644 --- a/mediapipe/tasks/web/components/processors/classifier_options.ts +++ b/mediapipe/tasks/web/components/processors/classifier_options.ts @@ -29,31 +29,31 @@ export function convertClassifierOptionsToProto( baseOptions?: ClassifierOptionsProto): ClassifierOptionsProto { const classifierOptions = baseOptions ? 
baseOptions.clone() : new ClassifierOptionsProto(); - if (options.displayNamesLocale) { + if (options.displayNamesLocale !== undefined) { classifierOptions.setDisplayNamesLocale(options.displayNamesLocale); - } else if (options.displayNamesLocale === undefined) { + } else if ('displayNamesLocale' in options) { // Check for undefined classifierOptions.clearDisplayNamesLocale(); } - if (options.maxResults) { + if (options.maxResults !== undefined) { classifierOptions.setMaxResults(options.maxResults); } else if ('maxResults' in options) { // Check for undefined classifierOptions.clearMaxResults(); } - if (options.scoreThreshold) { + if (options.scoreThreshold !== undefined) { classifierOptions.setScoreThreshold(options.scoreThreshold); } else if ('scoreThreshold' in options) { // Check for undefined classifierOptions.clearScoreThreshold(); } - if (options.categoryAllowlist) { + if (options.categoryAllowlist !== undefined) { classifierOptions.setCategoryAllowlistList(options.categoryAllowlist); } else if ('categoryAllowlist' in options) { // Check for undefined classifierOptions.clearCategoryAllowlistList(); } - if (options.categoryDenylist) { + if (options.categoryDenylist !== undefined) { classifierOptions.setCategoryDenylistList(options.categoryDenylist); } else if ('categoryDenylist' in options) { // Check for undefined classifierOptions.clearCategoryDenylistList(); diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index a5547ad6e..4fb57d6c3 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -12,6 +12,18 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "task_runner", + srcs = [ + "task_runner.ts", + ], + deps = [ + "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", + "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) + mediapipe_ts_library( name = "classifier_options", srcs = [ diff --git a/mediapipe/tasks/web/core/task_runner.ts
b/mediapipe/tasks/web/core/task_runner.ts new file mode 100644 index 000000000..c948930fc --- /dev/null +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -0,0 +1,83 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; +import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; +import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; + +// tslint:disable-next-line:enforce-name-casing +const WasmMediaPipeImageLib = + SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); + +/** Base class for all MediaPipe Tasks. */ +export abstract class TaskRunner extends WasmMediaPipeImageLib { + private processingErrors: Error[] = []; + + constructor(wasmModule: WasmModule) { + super(wasmModule); + + // Disables the automatic render-to-screen code, which allows for pure + // CPU processing. + this.setAutoRenderToScreen(false); + + // Enables use of our model resource caching graph service. + this.registerModelResourcesGraphService(); + } + + /** + * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run + * over the video stream. Will replace the previously running MediaPipe graph, + * if there is one. 
+ * @param graphData The raw MediaPipe graph data, either in binary + * protobuffer format (.binarypb), or else in raw text format (.pbtxt or + * .textproto). + * @param isBinary This should be set to true if the graph is in + * binary format, and false if it is in human-readable text format. + */ + override setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.attachErrorListener((code, message) => { + this.processingErrors.push(new Error(message)); + }); + super.setGraph(graphData, isBinary); + this.handleErrors(); + } + + /** + * Forces all queued-up packets to be pushed through the MediaPipe graph as + * far as possible, performing all processing until no more processing can be + * done. + */ + override finishProcessing(): void { + super.finishProcessing(); + this.handleErrors(); + } + + /** Throws the error from the error listener if an error was raised. */ + private handleErrors() { + const errors = this.processingErrors; + this.processingErrors = []; + if (errors.length === 1) { + // Re-throw error to get a more meaningful stacktrace + throw new Error(errors[0].message); + } else if (errors.length > 1) { + throw new Error( + 'Encountered multiple errors: ' + + errors.map(e => e.message).join(', ')); + } + } +} + + diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD new file mode 100644 index 000000000..e984a9554 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -0,0 +1,34 @@ +# This contains the MediaPipe Text Classifier Task. +# +# This task takes text input and performs Natural Language classification (including +# BERT-based text classification).
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "text_classifier", + srcs = [ + "text_classifier.ts", + "text_classifier_options.d.ts", + "text_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts new file mode 100644 index 000000000..ff36bb9e0 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -0,0 +1,180 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {TextClassifierOptions} from './text_classifier_options'; +import {Classifications} from './text_classifier_result'; + +const INPUT_STREAM = 'text_in'; +const CLASSIFICATION_RESULT_STREAM = 'classification_result_out'; +const TEXT_CLASSIFIER_GRAPH = + 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs Natural Language classification. */ +export class TextClassifier extends TaskRunner { + private classifications: Classifications[] = []; + private readonly options = new TextClassifierGraphOptions(); + + /** + * Initializes the Wasm runtime and creates a new text classifier from the + * provided options. 
+ * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param textClassifierOptions The options for the text classifier. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + textClassifierOptions: TextClassifierOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const classifier = await createMediaPipeLib( + TextClassifier, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await classifier.setOptions(textClassifierOptions); + return classifier; + } + + /** + * Initializes the Wasm runtime and creates a new text classifier based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return TextClassifier.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new text classifier based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. 
*/ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return TextClassifier.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the text classifier. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the text classifier. + */ + async setOptions(options: TextClassifierOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setClassifierOptions(convertClassifierOptionsToProto( + options, this.options.getClassifierOptions())); + this.refreshGraph(); + } + + + /** + * Performs Natural Language classification on the provided text and waits + * synchronously for the response. + * + * @param text The text to process. + * @return The classification result of the text + */ + classify(text: string): Classifications[] { + // Get classification classes by running our MediaPipe graph. + this.classifications = []; + this.addStringToStream( + text, INPUT_STREAM, /* timestamp= */ performance.now()); + this.finishProcessing(); + return [...this.classifications]; + } + + // Internal function for converting raw data into a classification, and + // adding it to our classifications list. + private addJsTextClassification(binaryProto: Uint8Array): void { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + this.classifications.push( + ...convertFromClassificationResultProto( + classificationResult)); + } + + /** Updates the MediaPipe graph configuration.
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + TextClassifierGraphOptions.ext, this.options); + + const classifierNode = new CalculatorGraphConfig.Node(); + classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH); + classifierNode.addInputStream('TEXT:' + INPUT_STREAM); + classifierNode.addOutputStream( + 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.setOptions(calculatorOptions); + + graphConfig.addNode(classifierNode); + + this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { + this.addJsTextClassification(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts new file mode 100644 index 000000000..51b2b3947 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts new file mode 100644 index 000000000..0a51dee04 --- /dev/null +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Category} from '../../../../tasks/web/components/containers/category'; +export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD new file mode 100644 index 000000000..8988c4794 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -0,0 +1,40 @@ +# This contains the MediaPipe Gesture Recognizer Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more gesture categories, using Gesture Recognizer. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "gesture_recognizer", + srcs = [ + "gesture_recognizer.ts", + "gesture_recognizer_options.d.ts", + "gesture_recognizer_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:classification_jspb_proto", + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts new file mode 100644 index 000000000..ad8db1477 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -0,0 +1,374 @@ +/** + * Copyright 2022 The 
MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationList} from '../../../../framework/formats/classification_pb'; +import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; +import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; +import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb'; +import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; +import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; +import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark} from '../../../../tasks/web/components/containers/landmark'; 
+import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {GestureRecognizerOptions} from './gesture_recognizer_options'; +import {GestureRecognitionResult} from './gesture_recognizer_result'; + +export {ImageSource}; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const HAND_GESTURES_STREAM = 'hand_gestures'; +const LANDMARKS_STREAM = 'hand_landmarks'; +const WORLD_LANDMARKS_STREAM = 'world_hand_landmarks'; +const HANDEDNESS_STREAM = 'handedness'; +const GESTURE_RECOGNIZER_GRAPH = + 'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph'; + +const DEFAULT_NUM_HANDS = 1; +const DEFAULT_SCORE_THRESHOLD = 0.5; +const DEFAULT_CATEGORY_INDEX = -1; + +const FULL_IMAGE_RECT = new NormalizedRect(); +FULL_IMAGE_RECT.setXCenter(0.5); +FULL_IMAGE_RECT.setYCenter(0.5); +FULL_IMAGE_RECT.setWidth(1); +FULL_IMAGE_RECT.setHeight(1); + +/** Performs hand gesture recognition on images. 
*/ +export class GestureRecognizer extends TaskRunner { + private gestures: Category[][] = []; + private landmarks: Landmark[][] = []; + private worldLandmarks: Landmark[][] = []; + private handednesses: Category[][] = []; + + private readonly options: GestureRecognizerGraphOptions; + private readonly handLandmarkerGraphOptions: HandLandmarkerGraphOptions; + private readonly handLandmarksDetectorGraphOptions: + HandLandmarksDetectorGraphOptions; + private readonly handDetectorGraphOptions: HandDetectorGraphOptions; + private readonly handGestureRecognizerGraphOptions: + HandGestureRecognizerGraphOptions; + + /** + * Initializes the Wasm runtime and creates a new gesture recognizer from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param gestureRecognizerOptions The options for the gesture recognizer. + * Note that either a path to the model asset or a model buffer needs to + * be provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + gestureRecognizerOptions: GestureRecognizerOptions): + Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load via this mechanism is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const recognizer = await createMediaPipeLib( + GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await recognizer.setOptions(gestureRecognizerOptions); + return recognizer; + } + + /** + * Initializes the Wasm runtime and creates a new gesture recognizer based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. 
+ */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return GestureRecognizer.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new gesture recognizer based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return GestureRecognizer.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + constructor(wasmModule: WasmModule) { + super(wasmModule); + + this.options = new GestureRecognizerGraphOptions(); + this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); + this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); + this.handLandmarksDetectorGraphOptions = + new HandLandmarksDetectorGraphOptions(); + this.handLandmarkerGraphOptions.setHandLandmarksDetectorGraphOptions( + this.handLandmarksDetectorGraphOptions); + this.handDetectorGraphOptions = new HandDetectorGraphOptions(); + this.handLandmarkerGraphOptions.setHandDetectorGraphOptions( + this.handDetectorGraphOptions); + this.handGestureRecognizerGraphOptions = + new HandGestureRecognizerGraphOptions(); + this.options.setHandGestureRecognizerGraphOptions( + this.handGestureRecognizerGraphOptions); + + this.initDefaults(); + + // Disables the automatic render-to-screen code, which allows for pure + // CPU processing. + this.setAutoRenderToScreen(false); + } + + /** + * Sets new options for the gesture recognizer. + * + * Calling `setOptions()` with a subset of options only affects those options. 
+ * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the gesture recognizer. + */ + async setOptions(options: GestureRecognizerOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + if ('numHands' in options) { + this.handDetectorGraphOptions.setNumHands( + options.numHands ?? DEFAULT_NUM_HANDS); + } + if ('minHandDetectionConfidence' in options) { + this.handDetectorGraphOptions.setMinDetectionConfidence( + options.minHandDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + if ('minHandPresenceConfidence' in options) { + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + if ('minTrackingConfidence' in options) { + this.handLandmarkerGraphOptions.setMinTrackingConfidence( + options.minTrackingConfidence ?? DEFAULT_SCORE_THRESHOLD); + } + + if (options.cannedGesturesClassifierOptions) { + // Note that we have to support both JSPB and ProtobufJS and cannot + // use JSPB's getMutableX() APIs. 
+ const graphOptions = new GestureClassifierGraphOptions(); + graphOptions.setClassifierOptions(convertClassifierOptionsToProto( + options.cannedGesturesClassifierOptions, + this.handGestureRecognizerGraphOptions + .getCannedGestureClassifierGraphOptions() + ?.getClassifierOptions())); + this.handGestureRecognizerGraphOptions + .setCannedGestureClassifierGraphOptions(graphOptions); + } else if (options.cannedGesturesClassifierOptions === undefined) { + this.handGestureRecognizerGraphOptions + .getCannedGestureClassifierGraphOptions() + ?.clearClassifierOptions(); + } + + if (options.customGesturesClassifierOptions) { + const graphOptions = new GestureClassifierGraphOptions(); + graphOptions.setClassifierOptions(convertClassifierOptionsToProto( + options.customGesturesClassifierOptions, + this.handGestureRecognizerGraphOptions + .getCustomGestureClassifierGraphOptions() + ?.getClassifierOptions())); + this.handGestureRecognizerGraphOptions + .setCustomGestureClassifierGraphOptions(graphOptions); + } else if (options.customGesturesClassifierOptions === undefined) { + this.handGestureRecognizerGraphOptions + .getCustomGestureClassifierGraphOptions() + ?.clearClassifierOptions(); + } + + this.refreshGraph(); + } + + /** + * Performs gesture recognition on the provided single image and waits + * synchronously for the response. + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. + * @return The detected gestures. 
+ */ + recognize(imageSource: ImageSource, timestamp: number = performance.now()): + GestureRecognitionResult { + this.gestures = []; + this.landmarks = []; + this.worldLandmarks = []; + this.handednesses = []; + + this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); + this.addProtoToStream( + FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', + NORM_RECT_STREAM, timestamp); + this.finishProcessing(); + + return { + gestures: this.gestures, + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } + + /** Sets the default values for the graph. */ + private initDefaults(): void { + this.handDetectorGraphOptions.setNumHands(DEFAULT_NUM_HANDS); + this.handDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence( + DEFAULT_SCORE_THRESHOLD); + this.handLandmarkerGraphOptions.setMinTrackingConfidence( + DEFAULT_SCORE_THRESHOLD); + } + + /** Converts the proto data to a Category[][] structure. */ + private toJsCategories(data: Uint8Array[]): Category[][] { + const result: Category[][] = []; + for (const binaryProto of data) { + const inputList = ClassificationList.deserializeBinary(binaryProto); + const outputList: Category[] = []; + for (const classification of inputList.getClassificationList()) { + outputList.push({ + score: classification.getScore() ?? 0, + index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + categoryName: classification.getLabel() ?? '', + displayName: classification.getDisplayName() ?? '', + }); + } + result.push(outputList); + } + return result; + } + + /** Converts raw data into a landmark, and adds it to our landmarks list. 
*/ + private addJsLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handLandmarksProto = + NormalizedLandmarkList.deserializeBinary(binaryProto); + const landmarks: Landmark[] = []; + for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { + landmarks.push({ + x: handLandmarkProto.getX() ?? 0, + y: handLandmarkProto.getY() ?? 0, + z: handLandmarkProto.getZ() ?? 0, + normalized: true + }); + } + this.landmarks.push(landmarks); + } + } + + /** + * Converts raw data into a landmark, and adds it to our worldLandmarks + * list. + */ + private adddJsWorldLandmarks(data: Uint8Array[]): void { + for (const binaryProto of data) { + const handWorldLandmarksProto = + LandmarkList.deserializeBinary(binaryProto); + const worldLandmarks: Landmark[] = []; + for (const handWorldLandmarkProto of + handWorldLandmarksProto.getLandmarkList()) { + worldLandmarks.push({ + x: handWorldLandmarkProto.getX() ?? 0, + y: handWorldLandmarkProto.getY() ?? 0, + z: handWorldLandmarkProto.getZ() ?? 0, + normalized: false + }); + } + this.worldLandmarks.push(worldLandmarks); + } + } + + /** Updates the MediaPipe graph configuration. 
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(HAND_GESTURES_STREAM); + graphConfig.addOutputStream(LANDMARKS_STREAM); + graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM); + graphConfig.addOutputStream(HANDEDNESS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + GestureRecognizerGraphOptions.ext, this.options); + + const recognizerNode = new CalculatorGraphConfig.Node(); + recognizerNode.setCalculator(GESTURE_RECOGNIZER_GRAPH); + recognizerNode.addInputStream('IMAGE:' + IMAGE_STREAM); + recognizerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + recognizerNode.addOutputStream('HAND_GESTURES:' + HAND_GESTURES_STREAM); + recognizerNode.addOutputStream('LANDMARKS:' + LANDMARKS_STREAM); + recognizerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM); + recognizerNode.addOutputStream('HANDEDNESS:' + HANDEDNESS_STREAM); + recognizerNode.setOptions(calculatorOptions); + + graphConfig.addNode(recognizerNode); + + this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts new file mode 100644 index 
000000000..16169a93f --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -0,0 +1,65 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +/** Options to configure the MediaPipe Gesture Recognizer Task */ +export interface GestureRecognizerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The maximum number of hands can be detected by the GestureRecognizer. + * Defaults to 1. + */ + numHands?: number|undefined; + + /** + * The minimum confidence score for the hand detection to be considered + * successful. Defaults to 0.5. + */ + minHandDetectionConfidence?: number|undefined; + + /** + * The minimum confidence score of hand presence score in the hand landmark + * detection. Defaults to 0.5. + */ + minHandPresenceConfidence?: number|undefined; + + /** + * The minimum confidence score for the hand tracking to be considered + * successful. Defaults to 0.5. + */ + minTrackingConfidence?: number|undefined; + + /** + * Sets the optional `ClassifierOptions` controling the canned gestures + * classifier, such as score threshold, allow list and deny list of gestures. 
+ * The categories for canned gesture + * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", + * "Thumb_Down", "Thumb_Up", "Victory", "ILoveYou"] + */ + // TODO: Note this option is subject to change + cannedGesturesClassifierOptions?: ClassifierOptions|undefined; + + /** + * Options for configuring the custom gestures classifier, such as score + * threshold, allow list and deny list of gestures. + */ + // TODO b/251816640): Note this option is subject to change. + customGesturesClassifierOptions?: ClassifierOptions|undefined; +} diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts new file mode 100644 index 000000000..cccdfaf68 --- /dev/null +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -0,0 +1,35 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Category} from '../../../../tasks/web/components/containers/category'; +import {Landmark} from '../../../../tasks/web/components/containers/landmark'; + +/** + * Represents the gesture recognition results generated by `GestureRecognizer`. + */ +export interface GestureRecognitionResult { + /** Hand landmarks of detected hands. */ + landmarks: Landmark[][]; + + /** Hand landmarks in world coordniates of detected hands. 
*/ + worldLandmarks: Landmark[][]; + + /** Handedness of detected hands. */ + handednesses: Category[][]; + + /** Recognized hand gestures of detected hands */ + gestures: Category[][]; +} diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD new file mode 100644 index 000000000..6937dc4f3 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -0,0 +1,33 @@ +# This contains the MediaPipe Image Classifier Task. +# +# This task takes video or image frames and outputs the classification result. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_classifier", + srcs = [ + "image_classifier.ts", + "image_classifier_options.ts", + "image_classifier_result.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/components/processors:classifier_options", + "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts new file mode 100644 index 000000000..39674e85c --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -0,0 +1,186 
@@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; +import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {ImageClassifierOptions} from './image_classifier_options'; +import {Classifications} from './image_classifier_result'; + +const IMAGE_CLASSIFIER_GRAPH = + 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; +const INPUT_STREAM = 'input_image'; +const CLASSIFICATION_RESULT_STREAM = 
'classification_result'; + +export {ImageSource}; // Used in the public API + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs classification on images. */ +export class ImageClassifier extends TaskRunner { + private classifications: Classifications[] = []; + private readonly options = new ImageClassifierGraphOptions(); + + /** + * Initializes the Wasm runtime and creates a new image classifier from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param imageClassifierOptions The options for the image classifier. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + imageClassifierOptions: ImageClassifierOptions): + Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const classifier = await createMediaPipeLib( + ImageClassifier, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await classifier.setOptions(imageClassifierOptions); + return classifier; + } + + /** + * Initializes the Wasm runtime and creates a new image classifier based on + * the provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. 
+ */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return ImageClassifier.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new image classifier based on + * the path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return ImageClassifier.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the image classifier. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the image classifier. + */ + async setOptions(options: ImageClassifierOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + this.options.setClassifierOptions(convertClassifierOptionsToProto( + options, this.options.getClassifierOptions())); + this.refreshGraph(); + } + + /** + * Performs image classification on the provided image and waits synchronously + * for the response. + * + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. 
+ * @return The classification result of the image + */ + classify(imageSource: ImageSource, timestamp?: number): Classifications[] { + // Get classification classes by running our MediaPipe graph. + this.classifications = []; + this.addGpuBufferAsImageToStream( + imageSource, INPUT_STREAM, timestamp ?? performance.now()); + this.finishProcessing(); + return [...this.classifications]; + } + + /** + * Internal function for converting raw data into a classification, and + * adding it to our classfications list. + **/ + private addJsImageClassification(binaryProto: Uint8Array): void { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + this.classifications.push( + ...convertFromClassificationResultProto(classificationResult)); + } + + /** Updates the MediaPipe graph configuration. */ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(INPUT_STREAM); + graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + ImageClassifierGraphOptions.ext, this.options); + + // Perform image classification. Pre-processing and results post-processing + // are built-in. 
+ const classifierNode = new CalculatorGraphConfig.Node(); + classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH); + classifierNode.addInputStream('IMAGE:' + INPUT_STREAM); + classifierNode.addOutputStream( + 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.setOptions(calculatorOptions); + + graphConfig.addNode(classifierNode); + + this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { + this.addJsImageClassification(binaryProto); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts new file mode 100644 index 000000000..a5f5c2386 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options'; diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts new file mode 100644 index 000000000..0a51dee04 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Category} from '../../../../tasks/web/components/containers/category'; +export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD new file mode 100644 index 000000000..888537bd1 --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -0,0 +1,30 @@ +# This contains the MediaPipe Object Detector Task. +# +# This task takes video frames and outputs synchronized frames along with +# the detection results for one or more object categories, using Object Detector. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "object_detector", + srcs = [ + "object_detector.ts", + "object_detector_options.d.ts", + "object_detector_result.d.ts", + ], + deps = [ + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], +) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts new file mode 100644 index 000000000..c3bb21baa --- /dev/null +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -0,0 +1,233 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +// Placeholder for internal dependency on trusted resource url + +import {ObjectDetectorOptions} from './object_detector_options'; +import {Detection} from './object_detector_result'; + +const INPUT_STREAM = 'input_frame_gpu'; +const DETECTIONS_STREAM = 'detections'; +const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph'; + +const DEFAULT_CATEGORY_INDEX = -1; + +export {ImageSource}; // Used in the public API + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Performs object detection on images. */ +export class ObjectDetector extends TaskRunner { + private detections: Detection[] = []; + private readonly options = new ObjectDetectorOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new object detector from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param objectDetectorOptions The options for the Object Detector. Note that + * either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). 
+ */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + objectDetectorOptions: ObjectDetectorOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const detector = await createMediaPipeLib( + ObjectDetector, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await detector.setOptions(objectDetectorOptions); + return detector; + } + + /** + * Initializes the Wasm runtime and creates a new object detector based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise { + return ObjectDetector.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new object detector based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return ObjectDetector.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + /** + * Sets new options for the object detector. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. 
+ * + * @param options The options for the object detector. + */ + async setOptions(options: ObjectDetectorOptions): Promise { + if (options.baseOptions) { + const baseOptionsProto = + await convertBaseOptionsToProto(options.baseOptions); + this.options.setBaseOptions(baseOptionsProto); + } + + // Note that we have to support both JSPB and ProtobufJS, hence we + // have to expliclity clear the values instead of setting them to + // `undefined`. + if (options.displayNamesLocale !== undefined) { + this.options.setDisplayNamesLocale(options.displayNamesLocale); + } else if ('displayNamesLocale' in options) { // Check for undefined + this.options.clearDisplayNamesLocale(); + } + + if (options.maxResults !== undefined) { + this.options.setMaxResults(options.maxResults); + } else if ('maxResults' in options) { // Check for undefined + this.options.clearMaxResults(); + } + + if (options.scoreThreshold !== undefined) { + this.options.setScoreThreshold(options.scoreThreshold); + } else if ('scoreThreshold' in options) { // Check for undefined + this.options.clearScoreThreshold(); + } + + if (options.categoryAllowlist !== undefined) { + this.options.setCategoryAllowlistList(options.categoryAllowlist); + } else if ('categoryAllowlist' in options) { // Check for undefined + this.options.clearCategoryAllowlistList(); + } + + if (options.categoryDenylist !== undefined) { + this.options.setCategoryDenylistList(options.categoryDenylist); + } else if ('categoryDenylist' in options) { // Check for undefined + this.options.clearCategoryDenylistList(); + } + + this.refreshGraph(); + } + + /** + * Performs object detection on the provided single image and waits + * synchronously for the response. + * @param imageSource An image source to process. + * @param timestamp The timestamp of the current frame, in ms. If not + * provided, defaults to `performance.now()`. 
   * @return The list of detected objects
   */
  detect(imageSource: ImageSource, timestamp?: number): Detection[] {
    // Get detections by running our MediaPipe graph.
    this.detections = [];
    this.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
    this.finishProcessing();
    // Return a copy so later detect() calls cannot mutate the caller's result.
    return [...this.detections];
  }

  /** Converts raw data into a Detection, and adds it to our detection list. */
  private addJsObjectDetections(data: Uint8Array[]): void {
    for (const binaryProto of data) {
      const detectionProto = DetectionProto.deserializeBinary(binaryProto);
      const scores = detectionProto.getScoreList();
      const indexes = detectionProto.getLabelIdList();
      const labels = detectionProto.getLabelList();
      const displayNames = detectionProto.getDisplayNameList();

      // One category per score; the id/label/display-name lists may be shorter
      // than the score list, hence the `??` fallbacks below.
      const detection: Detection = {categories: []};
      for (let i = 0; i < scores.length; i++) {
        detection.categories.push({
          score: scores[i],
          index: indexes[i] ?? DEFAULT_CATEGORY_INDEX,
          categoryName: labels[i] ?? '',
          displayName: displayNames[i] ?? '',
        });
      }

      const boundingBox = detectionProto.getLocationData()?.getBoundingBox();
      if (boundingBox) {
        detection.boundingBox = {
          originX: boundingBox.getXmin() ?? 0,
          originY: boundingBox.getYmin() ?? 0,
          width: boundingBox.getWidth() ?? 0,
          height: boundingBox.getHeight() ?? 0
        };
      }

      this.detections.push(detection);
    }
  }

  /** Updates the MediaPipe graph configuration.
 */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(DETECTIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(
        ObjectDetectorOptionsProto.ext, this.options);

    const detectorNode = new CalculatorGraphConfig.Node();
    detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH);
    detectorNode.addInputStream('IMAGE:' + INPUT_STREAM);
    detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
    detectorNode.setOptions(calculatorOptions);

    graphConfig.addNode(detectorNode);

    // NOTE(review): this listener is re-attached on every refreshGraph() call
    // (i.e. on every setOptions()); presumably attachProtoVectorListener
    // replaces any previous listener for the stream — confirm in TaskRunner.
    this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
      this.addJsObjectDetections(binaryProto);
    });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
  }
}


diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts
new file mode 100644
index 000000000..eec12cf17
--- /dev/null
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts
@@ -0,0 +1,52 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {BaseOptions} from '../../../../tasks/web/core/base_options';

/**
 * Options to configure the MediaPipe Object Detector Task. All fields are
 * optional; when passed to `ObjectDetector.setOptions()`, omitted fields keep
 * their previous values while fields explicitly set to `undefined` are reset.
 */
export interface ObjectDetectorOptions {
  /** Options to configure the loading of the model assets. */
  baseOptions?: BaseOptions;

  /**
   * The locale to use for display names specified through the TFLite Model
   * Metadata, if any. Defaults to English.
   */
  displayNamesLocale?: string|undefined;

  /** The maximum number of top-scored detection results to return. */
  maxResults?: number|undefined;

  /**
   * Overrides the value provided in the model metadata. Results below this
   * value are rejected.
   */
  scoreThreshold?: number|undefined;

  /**
   * Allowlist of category names. If non-empty, detection results whose category
   * name is not in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryDenylist`.
   */
  categoryAllowlist?: string[]|undefined;

  /**
   * Denylist of category names. If non-empty, detection results whose category
   * name is in this set will be filtered out. Duplicate or unknown category
   * names are ignored. Mutually exclusive with `categoryAllowlist`.
   */
  categoryDenylist?: string[]|undefined;
}

diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
new file mode 100644
index 000000000..7b2621134
--- /dev/null
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector_result.d.ts
@@ -0,0 +1,38 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {Category} from '../../../../tasks/web/components/containers/category';

/** An integer bounding box, axis aligned. */
export interface BoundingBox {
  /** The X coordinate of the top-left corner, in pixels. */
  originX: number;
  /** The Y coordinate of the top-left corner, in pixels. */
  originY: number;
  /** The width of the bounding box, in pixels. */
  width: number;
  /** The height of the bounding box, in pixels. */
  height: number;
}

/** Represents one object detected by the `ObjectDetector`. */
export interface Detection {
  /** A list of `Category` objects. */
  categories: Category[];

  /** The bounding box of the detected object, when the model provides one. */
  boundingBox?: BoundingBox;
}

diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD
new file mode 100644
index 000000000..dab6be50f
--- /dev/null
+++ b/mediapipe/web/graph_runner/BUILD
@@ -0,0 +1,41 @@
# The TypeScript graph runner used by all MediaPipe Web tasks.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")

package(default_visibility = [
    ":internal",
    "//mediapipe/tasks:internal",
])

package_group(
    name = "internal",
    packages = [
        "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...",
    ],
)

# Core graph-runner library; the two targets below are TS mixins layered on
# top of it.
mediapipe_ts_library(
    name = "wasm_mediapipe_lib_ts",
    srcs = [
        ":wasm_mediapipe_lib.ts",
    ],
    allow_unoptimized_namespaces = True,
)

mediapipe_ts_library(
    name = "wasm_mediapipe_image_lib_ts",
    srcs = [
        ":wasm_mediapipe_image_lib.ts",
    ],
    allow_unoptimized_namespaces = True,
    deps = [":wasm_mediapipe_lib_ts"],
)

mediapipe_ts_library(
    name = "register_model_resources_graph_service_ts",
    srcs = [
        ":register_model_resources_graph_service.ts",
    ],
    allow_unoptimized_namespaces = True,
    deps = [":wasm_mediapipe_lib_ts"],
)

diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts
new file mode 100644
index 000000000..e85d63b06
--- /dev/null
+++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts
@@ -0,0 +1,41 @@
import {WasmMediaPipeLib} from './wasm_mediapipe_lib';

/**
 * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has
 * access to the wasmModule, among other things. The `any` type is required for
 * mixin constructors.
 */
// tslint:disable-next-line:no-any
type LibConstructor = new (...args: any[]) => WasmMediaPipeLib;

/**
 * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
 * doesn't break our JS/C++ bridge.
 */
export declare interface WasmModuleRegisterModelResources {
  _registerModelResourcesGraphService: () => void;
}

/**
 * An implementation of WasmMediaPipeLib that supports registering model
 * resources to a cache, in the form of a GraphService C++-side. We implement as
 * a proper TS mixin, to allow for effective multiple inheritance.
Sample usage: + * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( + * WasmMediaPipeLib);` + */ +// tslint:disable:enforce-name-casing +export function SupportModelResourcesGraphService( + Base: TBase) { + return class extends Base { + // tslint:enable:enforce-name-casing + /** + * Instructs the graph runner to use the model resource caching graph + * service for both graph expansion/inintialization, as well as for graph + * run. + */ + registerModelResourcesGraphService(): void { + (this.wasmModule as unknown as WasmModuleRegisterModelResources) + ._registerModelResourcesGraphService(); + } + }; +} diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts new file mode 100644 index 000000000..3b45e8230 --- /dev/null +++ b/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts @@ -0,0 +1,52 @@ +import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib'; + +/** + * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * access to the wasmModule, among other things. The `any` type is required for + * mixin constructors. + */ +// tslint:disable-next-line:no-any +type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; + +/** + * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler + * doesn't break our JS/C++ bridge. + */ +export declare interface WasmImageModule { + _addBoundTextureAsImageToStream: + (streamNamePtr: number, width: number, height: number, + timestamp: number) => void; +} + +/** + * An implementation of WasmMediaPipeLib that supports binding GPU image data as + * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for + * effective multiple inheritance. 
Example usage: + * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);` + */ +// tslint:disable-next-line:enforce-name-casing +export function SupportImage(Base: TBase) { + return class extends Base { + /** + * Takes the relevant information from the HTML video or image element, and + * passes it into the WebGL-based graph for processing on the given stream + * at the given timestamp as a MediaPipe image. Processing will not occur + * until a blocking call (like processVideoGl or finishProcessing) is made. + * @param imageSource Reference to the video frame we wish to add into our + * graph. + * @param streamName The name of the MediaPipe graph stream to add the frame + * to. + * @param timestamp The timestamp of the input frame, in ms. + */ + addGpuBufferAsImageToStream( + imageSource: ImageSource, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + const [width, height] = + this.bindTextureToStream(imageSource, streamNamePtr); + (this.wasmModule as unknown as WasmImageModule) + ._addBoundTextureAsImageToStream( + streamNamePtr, width, height, timestamp); + }); + } + }; +} diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts new file mode 100644 index 000000000..714f42134 --- /dev/null +++ b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts @@ -0,0 +1,1044 @@ +// Placeholder for internal dependency on assertTruthy +// Placeholder for internal dependency on jsloader +// Placeholder for internal dependency on trusted resource url + +// This file can serve as a common interface for most simple TypeScript +// libraries-- additionally, it can hook automatically into wasm_mediapipe_demo +// to autogenerate simple TS APIs from demos for instantaneous 1P integrations. + +/** + * Simple interface for allowing users to set the directory where internal + * wasm-loading and asset-loading code looks (e.g. for .wasm and .data file + * locations). 
 */
export declare interface FileLocator {
  locateFile: (filename: string) => string;
}

/** Listener to be passed in by user for handling output audio data. */
export type AudioOutputListener = (output: Float32Array) => void;

/**
 * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
 * doesn't break our JS/C++ bridge.
 */
export declare interface WasmModule {
  canvas: HTMLCanvasElement|OffscreenCanvas|null;
  // Emscripten heap views, used to pass buffers across the JS/C++ boundary.
  HEAPU8: Uint8Array;
  HEAPU32: Uint32Array;
  HEAPF32: Float32Array;
  HEAPF64: Float64Array;
  errorListener?: ErrorListener;
  _bindTextureToCanvas: () => boolean;
  _changeBinaryGraph: (size: number, dataPtr: number) => void;
  _changeTextGraph: (size: number, dataPtr: number) => void;
  _configureAudio:
      (channels: number, samples: number, sampleRate: number) => void;
  _free: (ptr: number) => void;
  _malloc: (size: number) => number;
  _processAudio: (dataPtr: number, timestamp: number) => void;
  _processFrame: (width: number, height: number, timestamp: number) => void;
  _setAutoRenderToScreen: (enabled: boolean) => void;
  _waitUntilIdle: () => void;

  // Exposed so that clients of this lib can access this field
  dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
  // Wasm module will call us back at this function when given audio data.
  onAudioOutput?: AudioOutputListener;

  // Wasm Module multistream entrypoints. Require
  // gl_graph_runner_internal_multi_input as a build dependency.
  stringToNewUTF8: (data: string) => number;
  _bindTextureToStream: (streamNamePtr: number) => void;
  _addBoundTextureToStream:
      (streamNamePtr: number, width: number, height: number,
       timestamp: number) => void;
  _addBoolToInputStream:
      (data: boolean, streamNamePtr: number, timestamp: number) => void;
  _addDoubleToInputStream:
      (data: number, streamNamePtr: number, timestamp: number) => void;
  _addFloatToInputStream:
      (data: number, streamNamePtr: number, timestamp: number) => void;
  _addIntToInputStream:
      (data: number, streamNamePtr: number, timestamp: number) => void;
  _addStringToInputStream:
      (dataPtr: number, streamNamePtr: number, timestamp: number) => void;
  _addFlatHashMapToInputStream:
      (keysPtr: number, valuesPtr: number, count: number, streamNamePtr: number,
       timestamp: number) => void;
  _addProtoToInputStream:
      (dataPtr: number, dataSize: number, protoNamePtr: number,
       streamNamePtr: number, timestamp: number) => void;
  // Input side packets
  _addBoolToInputSidePacket: (data: boolean, streamNamePtr: number) => void;
  _addDoubleToInputSidePacket: (data: number, streamNamePtr: number) => void;
  _addFloatToInputSidePacket: (data: number, streamNamePtr: number) => void;
  _addIntToInputSidePacket: (data: number, streamNamePtr: number) => void;
  _addStringToInputSidePacket: (dataPtr: number, streamNamePtr: number) => void;
  _addProtoToInputSidePacket:
      (dataPtr: number, dataSize: number, protoNamePtr: number,
       streamNamePtr: number) => void;

  // Wasm Module output listener entrypoints. Also built as part of
  // gl_graph_runner_internal_multi_input.
+ simpleListeners?: {[outputStreamName: string]: (data: unknown) => void}; + vectorListeners?: { + [outputStreamName: string]: ( + data: unknown, index: number, length: number) => void + }; + _attachBoolListener: (streamNamePtr: number) => void; + _attachBoolVectorListener: (streamNamePtr: number) => void; + _attachDoubleListener: (streamNamePtr: number) => void; + _attachDoubleVectorListener: (streamNamePtr: number) => void; + _attachFloatListener: (streamNamePtr: number) => void; + _attachFloatVectorListener: (streamNamePtr: number) => void; + _attachIntListener: (streamNamePtr: number) => void; + _attachIntVectorListener: (streamNamePtr: number) => void; + _attachStringListener: (streamNamePtr: number) => void; + _attachStringVectorListener: (streamNamePtr: number) => void; + _attachProtoListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void; + _attachProtoVectorListener: + (streamNamePtr: number, makeDeepCopy?: boolean) => void; + + // Requires dependency ":gl_graph_runner_audio_out", and will register an + // audio output listening function which can be tapped into dynamically during + // graph running via onAudioOutput. This call must be made before graph is + // initialized, but after wasmModule is instantiated. + _attachAudioOutputListener: () => void; + + // TODO: Refactor to just use a few numbers (perhaps refactor away + // from gl_graph_runner_internal.cc entirely to use something a little more + // streamlined; new version is _processFrame above). + _processGl: (frameDataPtr: number) => number; +} + +// Global declarations, for tapping into Window for Wasm blob running +declare global { + interface Window { + // Created by us using wasm-runner script + Module?: WasmModule|FileLocator; + // Created by wasm-runner script + ModuleFactory?: (fileLocator: FileLocator) => Promise; + } +} + +/** + * Fetches each URL in urls, executes them one-by-one in the order they are + * passed, and then returns (or throws if something went amiss). 
+ */ +declare function importScripts(...urls: Array): void; + +/** + * Valid types of image sources which we can run our WasmMediaPipeLib over. + */ +export type ImageSource = + HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap; + + +/** A listener that will be invoked with an absl::StatusCode and message. */ +export type ErrorListener = (code: number, message: string) => void; + +// Internal type of constructors used for initializing WasmMediaPipeLib and +// subclasses. +type WasmMediaPipeConstructor = + (new ( + module: WasmModule, canvas?: HTMLCanvasElement|OffscreenCanvas|null) => + LibType); + +/** + * Simple class to run an arbitrary image-in/image-out MediaPipe graph (i.e. + * as created by wasm_mediapipe_demo BUILD macro), and either render results + * into canvas, or else return the output WebGLTexture. Takes a WebAssembly + * Module (must be instantiated to self.Module). + */ +export class WasmMediaPipeLib { + // TODO: These should be protected/private, but are left exposed for + // now so that we can use proper TS mixins with this class as a base. This + // should be somewhat fixed when we create our .d.ts files. + readonly wasmModule: WasmModule; + readonly hasMultiStreamSupport: boolean; + autoResizeCanvas: boolean = true; + audioPtr: number|null; + audioSize: number; + + /** + * Creates a new MediaPipe WASM module. Must be called *after* wasm Module has + * initialized. Note that we take control of the GL canvas from here on out, + * and will resize it to fit input. + * + * @param module The underlying Wasm Module to use. + * @param glCanvas The type of the GL canvas to use, or `null` if no GL + * canvas should be initialzed. Initializes an offscreen canvas if not + * provided. 
+ */ + constructor( + module: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + this.wasmModule = module; + this.audioPtr = null; + this.audioSize = 0; + this.hasMultiStreamSupport = + (typeof this.wasmModule._addIntToInputStream === 'function'); + + if (glCanvas !== undefined) { + this.wasmModule.canvas = glCanvas; + } else { + // If no canvas is provided, assume Chrome/Firefox and just make an + // OffscreenCanvas for GPU processing. + this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } + } + + /** + * Convenience helper to load a MediaPipe graph from a file and pass it to + * setGraph. + * @param graphFile The url of the MediaPipe graph file to load. + */ + async initializeGraph(graphFile: string): Promise { + // Fetch and set graph + const response = await fetch(graphFile); + const graphData = await response.arrayBuffer(); + const isBinary = + !(graphFile.endsWith('.pbtxt') || graphFile.endsWith('.textproto')); + this.setGraph(new Uint8Array(graphData), isBinary); + } + + /** + * Convenience helper for calling setGraph with a string representing a text + * proto config. + * @param graphConfig The text proto graph config, expected to be a string in + * default JavaScript UTF-16 format. + */ + setGraphFromString(graphConfig: string): void { + this.setGraph((new TextEncoder()).encode(graphConfig), false); + } + + /** + * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run + * over the video stream. Will replace the previously running MediaPipe graph, + * if there is one. + * @param graphData The raw MediaPipe graph data, either in binary + * protobuffer format (.binarypb), or else in raw text format (.pbtxt or + * .textproto). + * @param isBinary This should be set to true if the graph is in + * binary format, and false if it is in human-readable text format. 
 */
  setGraph(graphData: Uint8Array, isBinary: boolean): void {
    const size = graphData.length;
    // Copy the graph config into Wasm heap memory before handing it to C++,
    // then free the scratch allocation.
    const dataPtr = this.wasmModule._malloc(size);
    this.wasmModule.HEAPU8.set(graphData, dataPtr);
    if (isBinary) {
      this.wasmModule._changeBinaryGraph(size, dataPtr);
    } else {
      this.wasmModule._changeTextGraph(size, dataPtr);
    }
    this.wasmModule._free(dataPtr);
  }

  /**
   * Configures the current graph to handle audio in a certain way. Must be
   * called before the graph is set/started in order to use processAudio.
   * @param numChannels The number of channels of audio input. Only 1
   *     is supported for now.
   * @param numSamples The number of samples that are taken in each
   *     audio capture.
   * @param sampleRate The rate, in Hz, of the sampling.
   */
  configureAudio(numChannels: number, numSamples: number, sampleRate: number) {
    this.wasmModule._configureAudio(numChannels, numSamples, sampleRate);
    // Audio output support is an optional build dependency; only attach the
    // listener entrypoint when the Wasm build provides it.
    if (this.wasmModule._attachAudioOutputListener) {
      this.wasmModule._attachAudioOutputListener();
    }
  }

  /**
   * Allows disabling automatic canvas resizing, in case clients want to
   * control this.
   * @param resize True will re-enable automatic canvas resizing, while false
   *     will disable the feature.
   */
  setAutoResizeCanvas(resize: boolean): void {
    this.autoResizeCanvas = resize;
  }

  /**
   * Allows disabling the automatic render-to-screen code, in case clients don't
   * need/want this. In particular, this removes the requirement for pipelines
   * to have access to GPU resources, as well as the requirement for graphs to
   * have "input_frames_gpu" and "output_frames_gpu" streams defined, so pure
   * CPU pipelines and non-video pipelines can be created.
   * NOTE: This only affects future graph initializations (via setGraph or
   * initializeGraph), and does NOT affect the currently running graph, so
   * calls to this should be made *before* setGraph/initializeGraph for the
   * graph file being targeted.
   * @param enabled True will re-enable automatic render-to-screen code and
   *     cause GPU resources to once again be requested, while false will
   *     disable the feature.
   */
  setAutoRenderToScreen(enabled: boolean): void {
    this.wasmModule._setAutoRenderToScreen(enabled);
  }

  /**
   * Bind texture to our internal canvas, and upload image source to GPU.
   * Returns tuple [width, height] of texture. Intended for internal usage.
   */
  bindTextureToStream(imageSource: ImageSource, streamNamePtr?: number):
      [number, number] {
    if (!this.wasmModule.canvas) {
      throw new Error('No OpenGL canvas configured.');
    }

    if (!streamNamePtr) {
      // TODO: Remove this path once completely refactored away.
      console.assert(this.wasmModule._bindTextureToCanvas());
    } else {
      this.wasmModule._bindTextureToStream(streamNamePtr);
    }
    const gl: any =
        this.wasmModule.canvas.getContext('webgl2') ||
        this.wasmModule.canvas.getContext('webgl');
    console.assert(gl);
    gl.texImage2D(
        gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, imageSource);

    // Video elements report their pixel size via videoWidth/videoHeight; all
    // other image sources use width/height.
    let width, height;
    if ((imageSource as HTMLVideoElement).videoWidth) {
      width = (imageSource as HTMLVideoElement).videoWidth;
      height = (imageSource as HTMLVideoElement).videoHeight;
    } else {
      width = imageSource.width;
      height = imageSource.height;
    }

    if (this.autoResizeCanvas &&
        (width !== this.wasmModule.canvas.width ||
         height !== this.wasmModule.canvas.height)) {
      this.wasmModule.canvas.width = width;
      this.wasmModule.canvas.height = height;
    }

    return [width, height];
  }

  /**
   * Takes the raw data from a JS image source, and sends it to C++ to be
   * processed, waiting synchronously for the response. Note that we will resize
   * our GL canvas to fit the input, so input size should only change
   * infrequently.
   * @param imageSource An image source to process.
   * @param timestamp The timestamp of the current frame, in ms.
   * @return texture?
   *     The WebGL texture reference, if one was produced.
   */
  processGl(imageSource: ImageSource, timestamp: number): WebGLTexture
      |undefined {
    // Bind to default input stream
    const [width, height] = this.bindTextureToStream(imageSource);

    // 2 ints and a ll (timestamp)
    const frameDataPtr = this.wasmModule._malloc(16);
    this.wasmModule.HEAPU32[frameDataPtr / 4] = width;
    this.wasmModule.HEAPU32[(frameDataPtr / 4) + 1] = height;
    this.wasmModule.HEAPF64[(frameDataPtr / 8) + 1] = timestamp;
    // outputPtr points in HEAPF32-space to running mspf calculations, which we
    // don't use at the moment.
    // tslint:disable-next-line:no-unused-variable
    const outputPtr = this.wasmModule._processGl(frameDataPtr) / 4;
    this.wasmModule._free(frameDataPtr);

    // TODO: Hook up WebGLTexture output, when given.
    // TODO: Allow user to toggle whether or not to render output into canvas.
    return undefined;
  }

  /**
   * Converts JavaScript string input parameters into C++ c-string pointers.
   * See b/204830158 for more details. Intended for internal usage.
   */
  wrapStringPtr(stringData: string, stringPtrFunc: (ptr: number) => void):
      void {
    if (!this.hasMultiStreamSupport) {
      console.error(
          'No wasm multistream support detected: ensure dependency ' +
          'inclusion of :gl_graph_runner_internal_multi_input target');
    }
    // Allocate a C copy of the string, invoke the callback, then free it.
    const stringDataPtr = this.wasmModule.stringToNewUTF8(stringData);
    stringPtrFunc(stringDataPtr);
    this.wasmModule._free(stringDataPtr);
  }

  /**
   * Converts JavaScript string input parameters into C++ c-string pointers.
   * See b/204830158 for more details.
+ */ + wrapStringPtrPtr(stringData: string[], ptrFunc: (ptr: number) => void): void { + if (!this.hasMultiStreamSupport) { + console.error( + 'No wasm multistream support detected: ensure dependency ' + + 'inclusion of :gl_graph_runner_internal_multi_input target'); + } + const uint32Array = new Uint32Array(stringData.length); + for (let i = 0; i < stringData.length; i++) { + uint32Array[i] = this.wasmModule.stringToNewUTF8(stringData[i]); + } + const heapSpace = this.wasmModule._malloc(uint32Array.length * 4); + this.wasmModule.HEAPU32.set(uint32Array, heapSpace >> 2); + + ptrFunc(heapSpace); + for (const uint32ptr of uint32Array) { + this.wasmModule._free(uint32ptr); + } + this.wasmModule._free(heapSpace); + } + + /** + * Ensures existence of the simple listeners table and registers the callback. + * Intended for internal usage. + */ + setListener(outputStreamName: string, callbackFcn: (data: T) => void) { + this.wasmModule.simpleListeners = this.wasmModule.simpleListeners || {}; + this.wasmModule.simpleListeners[outputStreamName] = + callbackFcn as (data: unknown) => void; + } + + /** + * Ensures existence of the vector listeners table and registers the callback. + * Intended for internal usage. + */ + setVectorListener( + outputStreamName: string, callbackFcn: (data: T[]) => void) { + const buffer: T[] = []; + this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {}; + this.wasmModule.vectorListeners[outputStreamName] = + (data: unknown, index: number, length: number) => { + // The Wasm listener gets invoked once for each element. Once we + // receive all elements, we invoke the registered callback with the + // full array. + buffer[index] = data as T; + if (index === length - 1) { + // Invoke the user callback directly, as the Wasm layer may clean up + // the underlying data elements once we leave the scope of the + // listener. 
+ callbackFcn(buffer); + } + }; + } + + /** + * Attaches a listener that will be invoked when the MediaPipe framework + * returns an error. + */ + attachErrorListener(callbackFcn: (code: number, message: string) => void) { + this.wasmModule.errorListener = callbackFcn; + } + + /** + * Takes the raw data from a JS audio capture array, and sends it to C++ to be + * processed. + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param timestamp The timestamp of the current frame, in ms. + */ + addAudioToStream(audioData: Float32Array, timestamp: number) { + // 4 bytes for each F32 + const size = audioData.length * 4; + if (this.audioSize !== size) { + if (this.audioPtr) { + this.wasmModule._free(this.audioPtr); + } + this.audioPtr = this.wasmModule._malloc(size); + this.audioSize = size; + } + this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); + this.wasmModule._processAudio(this.audioPtr!, timestamp); + } + + /** + * Takes the relevant information from the HTML video or image element, and + * passes it into the WebGL-based graph for processing on the given stream at + * the given timestamp. Can be used for additional auxiliary GpuBuffer input + * streams. Processing will not occur until a blocking call (like + * processVideoGl or finishProcessing) is made. For use with + * 'gl_graph_runner_internal_multi_input'. + * @param imageSource Reference to the video frame we wish to add into our + * graph. + * @param streamName The name of the MediaPipe graph stream to add the frame + * to. + * @param timestamp The timestamp of the input frame, in ms. 
+ */ + addGpuBufferToStream( + imageSource: ImageSource, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + const [width, height] = + this.bindTextureToStream(imageSource, streamNamePtr); + this.wasmModule._addBoundTextureToStream( + streamNamePtr, width, height, timestamp); + }); + } + + /** + * Sends a boolean packet into the specified stream at the given timestamp. + * @param data The boolean data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addBoolToStream(data: boolean, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addBoolToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends a double packet into the specified stream at the given timestamp. + * @param data The double data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addDoubleToStream(data: number, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addDoubleToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends a float packet into the specified stream at the given timestamp. + * @param data The float data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addFloatToStream(data: number, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + // NOTE: _addFloatToStream and _addIntToStream are reserved for JS + // Calculators currently; we may want to revisit this naming scheme in the + // future. 
+ this.wasmModule._addFloatToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends an integer packet into the specified stream at the given timestamp. + * @param data The integer data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addIntToStream(data: number, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addIntToInputStream(data, streamNamePtr, timestamp); + }); + } + + /** + * Sends a string packet into the specified stream at the given timestamp. + * @param data The string data to send. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addStringToStream(data: string, streamName: string, timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wrapStringPtr(data, (dataPtr: number) => { + this.wasmModule._addStringToInputStream( + dataPtr, streamNamePtr, timestamp); + }); + }); + } + + /** + * Sends a Record packet into the specified stream at the + * given timestamp. + * @param data The records to send (will become a + * std::flat_hash_map, streamName: string, + timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wrapStringPtrPtr(Object.keys(data), (keyList: number) => { + this.wrapStringPtrPtr(Object.values(data), (valueList: number) => { + this.wasmModule._addFlatHashMapToInputStream( + keyList, valueList, Object.keys(data).length, streamNamePtr, + timestamp); + }); + }); + }); + } + + /** + * Sends a serialized protobuffer packet into the specified stream at the + * given timestamp, to be parsed into the specified protobuffer type. + * @param data The binary (serialized) raw protobuffer data. + * @param protoType The C++ namespaced type this protobuffer data corresponds + * to. 
It will be converted to this type when output as a packet into the + * graph. + * @param streamName The name of the graph input stream to send data into. + * @param timestamp The timestamp of the input data, in ms. + */ + addProtoToStream( + data: Uint8Array, protoType: string, streamName: string, + timestamp: number): void { + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wrapStringPtr(protoType, (protoTypePtr: number) => { + // Deep-copy proto data into Wasm heap + const dataPtr = this.wasmModule._malloc(data.length); + // TODO: Ensure this is the fastest way to copy this data. + this.wasmModule.HEAPU8.set(data, dataPtr); + this.wasmModule._addProtoToInputStream( + dataPtr, data.length, protoTypePtr, streamNamePtr, timestamp); + this.wasmModule._free(dataPtr); + }); + }); + } + + /** + * Attaches a boolean packet to the specified input_side_packet. + * @param data The boolean data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addBoolToInputSidePacket(data: boolean, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addBoolToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a double packet to the specified input_side_packet. + * @param data The double data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addDoubleToInputSidePacket(data: number, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addDoubleToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a float packet to the specified input_side_packet. + * @param data The float data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. 
+ */ + addFloatToInputSidePacket(data: number, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addFloatToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a integer packet to the specified input_side_packet. + * @param data The integer data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addIntToInputSidePacket(data: number, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wasmModule._addIntToInputSidePacket(data, sidePacketNamePtr); + }); + } + + /** + * Attaches a string packet to the specified input_side_packet. + * @param data The string data to send. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addStringToInputSidePacket(data: string, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wrapStringPtr(data, (dataPtr: number) => { + this.wasmModule._addStringToInputSidePacket(dataPtr, sidePacketNamePtr); + }); + }); + } + + /** + * Attaches a serialized proto packet to the specified input_side_packet. + * @param data The binary (serialized) raw protobuffer data. + * @param protoType The C++ namespaced type this protobuffer data corresponds + * to. It will be converted to this type for use in the graph. + * @param sidePacketName The name of the graph input side packet to send data + * into. + */ + addProtoToInputSidePacket( + data: Uint8Array, protoType: string, sidePacketName: string): void { + this.wrapStringPtr(sidePacketName, (sidePacketNamePtr: number) => { + this.wrapStringPtr(protoType, (protoTypePtr: number) => { + // Deep-copy proto data into Wasm heap + const dataPtr = this.wasmModule._malloc(data.length); + // TODO: Ensure this is the fastest way to copy this data. 
+ this.wasmModule.HEAPU8.set(data, dataPtr); + this.wasmModule._addProtoToInputSidePacket( + dataPtr, data.length, protoTypePtr, sidePacketNamePtr); + this.wasmModule._free(dataPtr); + }); + }); + } + + /** + * Attaches a boolean packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab boolean + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachBoolListener( + outputStreamName: string, callbackFcn: (data: boolean) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for bool packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachBoolListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a bool[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachBoolVectorListener( + outputStreamName: string, callbackFcn: (data: boolean[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. 
+ this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachBoolVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches an int packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab int + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachIntListener( + outputStreamName: string, callbackFcn: (data: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for int packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachIntListener(outputStreamNamePtr); + }); + } + + /** + * Attaches an int[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachIntVectorListener( + outputStreamName: string, callbackFcn: (data: number[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. 
+ this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachIntVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a double packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab double + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachDoubleListener( + outputStreamName: string, callbackFcn: (data: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for double packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachDoubleListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a double[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachDoubleVectorListener( + outputStreamName: string, callbackFcn: (data: number[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. 
+ this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachDoubleVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a float packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab float + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachFloatListener( + outputStreamName: string, callbackFcn: (data: number) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for float packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachFloatListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a float[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachFloatVectorListener( + outputStreamName: string, callbackFcn: (data: number[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. 
+ this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachFloatVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a string packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab string + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachStringListener( + outputStreamName: string, callbackFcn: (data: string) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for string packets on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachStringListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a string[] packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab + * std::vector data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. + */ + attachStringVectorListener( + outputStreamName: string, callbackFcn: (data: string[]) => void): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for std::vector packets on this stream. 
+ this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachStringVectorListener(outputStreamNamePtr); + }); + } + + /** + * Attaches a serialized proto packet listener to the specified output_stream. + * @param outputStreamName The name of the graph output stream to grab binary + * serialized proto data from (in Uint8Array format). + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that by default the data is only guaranteed to + * exist for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. If the proto data needs to be able to outlive the call, you + * may set the optional makeDeepCopy parameter to true, or can manually + * deep-copy the data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). + */ + attachProtoListener( + outputStreamName: string, callbackFcn: (data: Uint8Array) => void, + makeDeepCopy?: boolean): void { + // Set up our TS listener to receive any packets for this stream. + this.setListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for binary serialized proto data packets on this + // stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachProtoListener( + outputStreamNamePtr, makeDeepCopy || false); + }); + } + + /** + * Attaches a listener for an array of serialized proto packets to the + * specified output_stream. 
+ * @param outputStreamName The name of the graph output stream to grab a + * vector of binary serialized proto data from (in Uint8Array[] format). + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that by default the data is only guaranteed to + * exist for the duration of the callback, and the callback will be called + * inline, so it should not perform overly complicated (or any async) + * behavior. If the proto data needs to be able to outlive the call, you + * may set the optional makeDeepCopy parameter to true, or can manually + * deep-copy the data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). + */ + attachProtoVectorListener( + outputStreamName: string, callbackFcn: (data: Uint8Array[]) => void, + makeDeepCopy?: boolean): void { + // Set up our TS listener to receive any packets for this stream. + this.setVectorListener(outputStreamName, callbackFcn); + + // Tell our graph to listen for a vector of binary serialized proto packets + // on this stream. + this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => { + this.wasmModule._attachProtoVectorListener( + outputStreamNamePtr, makeDeepCopy || false); + }); + } + + /** + * Sets a listener to be called back with audio output packet data, as a + * Float32Array, when graph has finished processing it. + * @param audioOutputListener The caller's listener function. 
+ */ + setOnAudioOutput(audioOutputListener: AudioOutputListener) { + this.wasmModule.onAudioOutput = audioOutputListener; + if (!this.wasmModule._attachAudioOutputListener) { + console.warn( + 'Attempting to use AudioOutputListener without support for ' + + 'output audio. Is build dep ":gl_graph_runner_audio_out" missing?'); + } + } + + /** + * Forces all queued-up packets to be pushed through the MediaPipe graph as + * far as possible, performing all processing until no more processing can be + * done. + */ + finishProcessing(): void { + this.wasmModule._waitUntilIdle(); + } +} + +// Quick private helper to run the given script safely +async function runScript(scriptUrl: string) { + if (typeof importScripts === 'function') { + importScripts(scriptUrl.toString()); + } else { + await new Promise((resolve, reject) => { + fetch(scriptUrl).then(response => response.text()).then(text => Function(text)).then(resolve, reject); + }); + } +} + +/** + * Global function to initialize Wasm blob and load runtime assets for a + * specialized MediaPipe library. This allows us to create a requested + * subclass inheriting from WasmMediaPipeLib. + * @param constructorFcn The name of the class to instantiate via "new". + * @param wasmLoaderScript Url for the wasm-runner script; produced by the build + * process. + * @param assetLoaderScript Url for the asset-loading script; produced by the + * build process. + * @param fileLocator A function to override the file locations for assets + * loaded by the MediaPipe library. + * @return promise A promise which will resolve when initialization has + * completed successfully. 
+ */ +export async function createMediaPipeLib( + constructorFcn: WasmMediaPipeConstructor, + wasmLoaderScript?: string, + assetLoaderScript?: string, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + fileLocator?: FileLocator): Promise { + const scripts = []; + // Run wasm-loader script here + if (wasmLoaderScript) { + scripts.push(wasmLoaderScript); + } + // Run asset-loader script here + if (assetLoaderScript) { + scripts.push(assetLoaderScript); + } + // Load scripts in parallel, browser will execute them in sequence. + if (scripts.length) { + await Promise.all(scripts.map(runScript)); + } + if (!self.ModuleFactory) { + throw new Error('ModuleFactory not set.'); + } + // TODO: Ensure that fileLocator is passed in by all users + // and make it required + const module = + await self.ModuleFactory(fileLocator || self.Module as FileLocator); + // Don't reuse factory or module seed + self.ModuleFactory = self.Module = undefined; + return new constructorFcn(module, glCanvas); +} + +/** + * Global function to initialize Wasm blob and load runtime assets for a generic + * MediaPipe library. + * @param wasmLoaderScript Url for the wasm-runner script; produced by the build + * process. + * @param assetLoaderScript Url for the asset-loading script; produced by the + * build process. + * @param fileLocator A function to override the file locations for assets + * loaded by the MediaPipe library. + * @return promise A promise which will resolve when initialization has + * completed successfully. + */ +export async function createWasmMediaPipeLib( + wasmLoaderScript?: string, + assetLoaderScript?: string, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, + fileLocator?: FileLocator): Promise { + return createMediaPipeLib( + WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas, + fileLocator); +}