Allow users to pass canvas element

PiperOrigin-RevId: 518870611
Sebastian Schmidt 2023-03-23 08:44:23 -07:00 committed by Copybara-Service
parent 5998e96eed
commit 1c9e6894f3
19 changed files with 160 additions and 112 deletions
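In short: vision tasks gain a user-visible `canvas` option (an `HTMLCanvasElement` or `OffscreenCanvas`), and the internal `initializeCanvas` boolean is replaced by an explicit canvas parameter throughout. A minimal usage sketch follows; the `@mediapipe/tasks-vision` import, the `FilesetResolver` entry point, and the model path are assumptions for illustration, while the `canvas` option and the GPU-delegate requirement come from this commit:

// Sketch: hand a caller-owned canvas to a vision task (assumed imports/paths).
import {FilesetResolver, ImageClassifier} from '@mediapipe/tasks-vision';

const canvas = document.createElement('canvas');  // or: new OffscreenCanvas(1, 1)
const wasmFileset = await FilesetResolver.forVisionTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm');  // assumed URL
const classifier = await ImageClassifier.createFromOptions(wasmFileset, {
  baseOptions: {
    modelAssetPath: 'image_classifier.tflite',  // hypothetical model path
    delegate: 'GPU',  // GPU processing now requires a caller-supplied canvas
  },
  canvas,  // new in this commit: the task binds its textures to this canvas
});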

View File

@@ -60,9 +60,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   static createFromOptions(
       wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions):
       Promise<AudioClassifier> {
-    return AudioTaskRunner.createInstance(
-        AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
-        audioClassifierOptions);
+    return AudioTaskRunner.createAudioInstance(
+        AudioClassifier, wasmFileset, audioClassifierOptions);
   }

   /**
@@ -75,9 +74,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<AudioClassifier> {
-    return AudioTaskRunner.createInstance(
-        AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return AudioTaskRunner.createAudioInstance(
+        AudioClassifier, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -91,7 +89,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<AudioClassifier> {
     return AudioTaskRunner.createInstance(
-        AudioClassifier, /* initializeCanvas= */ false, wasmFileset,
+        AudioClassifier, /* canvas= */ null, wasmFileset,
         {baseOptions: {modelAssetPath}});
   }

View File

@@ -60,9 +60,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   static createFromOptions(
       wasmFileset: WasmFileset,
       audioEmbedderOptions: AudioEmbedderOptions): Promise<AudioEmbedder> {
-    return AudioTaskRunner.createInstance(
-        AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
-        audioEmbedderOptions);
+    return AudioTaskRunner.createAudioInstance(
+        AudioEmbedder, wasmFileset, audioEmbedderOptions);
   }

   /**
@@ -75,9 +74,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> {
-    return AudioTaskRunner.createInstance(
-        AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return AudioTaskRunner.createAudioInstance(
+        AudioEmbedder, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -90,9 +88,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<AudioEmbedder> {
-    return AudioTaskRunner.createInstance(
-        AudioEmbedder, /* initializeCanvas= */ false, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return AudioTaskRunner.createAudioInstance(
+        AudioEmbedder, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -7,5 +7,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 mediapipe_ts_library(
     name = "audio_task_runner",
     srcs = ["audio_task_runner.ts"],
-    deps = ["//mediapipe/tasks/web/core:task_runner"],
+    deps = [
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner",
+        "//mediapipe/web/graph_runner:graph_runner_ts",
+    ],
 )

View File

@@ -15,11 +15,22 @@
  */

 import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';

 /** Base class for all MediaPipe Audio Tasks. */
 export abstract class AudioTaskRunner<T> extends TaskRunner {
   private defaultSampleRate = 48000;

+  protected static async createAudioInstance<T, I extends AudioTaskRunner<T>>(
+      type: WasmMediaPipeConstructor<I>, fileset: WasmFileset,
+      options: TaskRunnerOptions): Promise<I> {
+    return TaskRunner.createInstance(
+        type, /* canvas= */ null, fileset, options);
+  }
+
   /**
    * Sets the sample rate for API calls that omit an explicit sample rate.
    * `48000` is used as a default if this method is not called.

View File

@@ -30,6 +30,7 @@ const NO_ASSETS = undefined;

 // tslint:disable-next-line:enforce-name-casing
 const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);

 /**
  * An implementation of the GraphRunner that exposes the resource graph
  * service.
@@ -42,7 +43,8 @@ export class CachedGraphRunner extends CachedGraphRunnerType {}
  * @return A fully instantiated instance of `T`.
  */
 export async function createTaskRunner<T extends TaskRunner>(
-    type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
+    type: WasmMediaPipeConstructor<T>,
+    canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined,
     fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
   const fileLocator: FileLocator = {
     locateFile() {
@@ -51,12 +53,6 @@ export async function createTaskRunner<T extends TaskRunner>(
     }
   };

-  // Initialize a canvas if requested. If OffscreenCanvas is available, we
-  // let the graph runner initialize it by passing `undefined`.
-  const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
-                                         document.createElement('canvas') :
-                                         undefined) :
-                                    null;
   const instance = await createMediaPipeLib(
       type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
   await instance.setOptions(options);
@@ -75,9 +71,10 @@ export abstract class TaskRunner {
    * @return A fully instantiated instance of `T`.
    */
   protected static async createInstance<T extends TaskRunner>(
-      type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
+      type: WasmMediaPipeConstructor<T>,
+      canvas: HTMLCanvasElement|OffscreenCanvas|null|undefined,
       fileset: WasmFileset, options: TaskRunnerOptions): Promise<T> {
-    return createTaskRunner(type, initializeCanvas, fileset, options);
+    return createTaskRunner(type, canvas, fileset, options);
   }

   /** @hideconstructor protected */

View File

@@ -58,7 +58,7 @@ export class TextClassifier extends TaskRunner {
       wasmFileset: WasmFileset,
       textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> {
     return TaskRunner.createInstance(
-        TextClassifier, /* initializeCanvas= */ false, wasmFileset,
+        TextClassifier, /* canvas= */ null, wasmFileset,
         textClassifierOptions);
   }

@@ -73,7 +73,7 @@ export class TextClassifier extends TaskRunner {
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<TextClassifier> {
     return TaskRunner.createInstance(
-        TextClassifier, /* initializeCanvas= */ false, wasmFileset,
+        TextClassifier, /* canvas= */ null, wasmFileset,
         {baseOptions: {modelAssetBuffer}});
   }

@@ -88,7 +88,7 @@ export class TextClassifier extends TaskRunner {
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<TextClassifier> {
     return TaskRunner.createInstance(
-        TextClassifier, /* initializeCanvas= */ false, wasmFileset,
+        TextClassifier, /* canvas= */ null, wasmFileset,
         {baseOptions: {modelAssetPath}});
   }

View File

@@ -62,8 +62,7 @@ export class TextEmbedder extends TaskRunner {
       wasmFileset: WasmFileset,
       textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> {
     return TaskRunner.createInstance(
-        TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
-        textEmbedderOptions);
+        TextEmbedder, /* canvas= */ null, wasmFileset, textEmbedderOptions);
   }

   /**
@@ -77,7 +76,7 @@ export class TextEmbedder extends TaskRunner {
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<TextEmbedder> {
     return TaskRunner.createInstance(
-        TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
+        TextEmbedder, /* canvas= */ null, wasmFileset,
         {baseOptions: {modelAssetBuffer}});
   }

@@ -92,7 +91,7 @@ export class TextEmbedder extends TaskRunner {
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<TextEmbedder> {
     return TaskRunner.createInstance(
-        TextEmbedder, /* initializeCanvas= */ false, wasmFileset,
+        TextEmbedder, /* canvas= */ null, wasmFileset,
         {baseOptions: {modelAssetPath}});
   }

View File

@@ -25,6 +25,14 @@ export type RunningMode = 'IMAGE'|'VIDEO';

 /** The options for configuring a MediaPipe vision task. */
 export declare interface VisionTaskOptions extends TaskRunnerOptions {
+  /**
+   * The canvas element to bind textures to. This has to be set for GPU
+   * processing. The task will initialize a WebGL context and throw an error
+   * if this fails (e.g. if you have already initialized a different type of
+   * context).
+   */
+  canvas?: HTMLCanvasElement|OffscreenCanvas;
+
   /**
    * The running mode of the task. Default to the image mode.
    * Vision tasks have two running modes:
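The "has to be set for GPU processing" requirement above is enforced at creation time by `VisionTaskRunner.createVisionInstance` (see the diff below). A sketch of the failure mode, with a hypothetical task and model name; the error text is taken from this commit:

// Sketch: requesting the GPU delegate without a canvas now fails fast.
try {
  await ImageSegmenter.createFromOptions(wasmFileset, {
    baseOptions: {modelAssetPath: 'segmenter.tflite', delegate: 'GPU'},
    // note: no `canvas` supplied
  });
} catch (e) {
  console.error(e);  // Error: You must specify a canvas for GPU processing.
}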

View File

@@ -36,6 +36,8 @@ const IMAGE = {} as unknown as HTMLImageElement;
 const TIMESTAMP = 42;

 class VisionTaskRunnerFake extends VisionTaskRunner {
+  override graphRunner!: VisionGraphRunner;
+
   baseOptions = new BaseOptionsProto();
   fakeGraphRunner: jasmine.SpyObj<VisionGraphRunner>;
   expectedImageSource?: ImageSource;
@@ -46,7 +48,7 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
         jasmine.createSpyObj<VisionGraphRunner>([
           'addProtoToStream', 'addGpuBufferAsImageToStream',
           'setAutoRenderToScreen', 'registerModelResourcesGraphService',
-          'finishProcessing'
+          'finishProcessing', 'wasmModule'
         ]),
         IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed);
@@ -72,7 +74,7 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
       expect(imageSource).toBe(this.expectedImageSource!);
     });

-    // SetOptions with a modelAssetBuffer runs synchonously
+    // SetOptions with a modelAssetBuffer runs synchronously
     void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}});
   }
@@ -165,6 +167,24 @@ describe('VisionTaskRunner', () => {
     }).toThrowError(/Task is not initialized with video mode./);
   });

+  it('validates that the canvas cannot be changed', async () => {
+    if (typeof OffscreenCanvas === 'undefined') {
+      console.log('Test is not supported under Node.');
+      return;
+    }
+
+    const visionTaskRunner = new VisionTaskRunnerFake();
+    const canvas = new OffscreenCanvas(1, 1);
+    visionTaskRunner.graphRunner.wasmModule.canvas = canvas;
+    expect(() => {
+      visionTaskRunner.setOptions({canvas});
+    }).not.toThrow();
+
+    expect(() => {
+      visionTaskRunner.setOptions({canvas: new OffscreenCanvas(2, 2)});
+    }).toThrowError(/You must create a new task to reset the canvas./);
+  });
+
   it('sends packets to graph', async () => {
     const visionTaskRunner = new VisionTaskRunnerFake();
     await visionTaskRunner.setOptions({runningMode: 'VIDEO'});

View File

@@ -16,8 +16,9 @@

 import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
-import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner';
+import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
 import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service';
@@ -32,8 +33,33 @@ export class VisionGraphRunner extends GraphRunnerVisionType {}

 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern

+/**
+ * Creates a canvas for a MediaPipe vision task. Returns `undefined` if the
+ * GraphRunner should create its own canvas.
+ */
+function createCanvas(): HTMLCanvasElement|OffscreenCanvas|undefined {
+  // Returns an HTML canvas or `undefined` if OffscreenCanvas is available
+  // (since the graph runner can initialize its own OffscreenCanvas).
+  return typeof OffscreenCanvas === 'undefined' ?
+      document.createElement('canvas') :
+      undefined;
+}
+
 /** Base class for all MediaPipe Vision Tasks. */
 export abstract class VisionTaskRunner extends TaskRunner {
+  protected static async createVisionInstance<T extends VisionTaskRunner>(
+      type: WasmMediaPipeConstructor<T>, fileset: WasmFileset,
+      options: VisionTaskOptions): Promise<T> {
+    if (options.baseOptions?.delegate === 'GPU') {
+      if (!options.canvas) {
+        throw new Error('You must specify a canvas for GPU processing.');
+      }
+    }
+    const canvas = options.canvas ?? createCanvas();
+    return TaskRunner.createInstance(type, canvas, fileset, options);
+  }
+
   /**
    * Constructor to initialize a `VisionTaskRunner`.
    *
@@ -62,6 +88,13 @@ export abstract class VisionTaskRunner extends TaskRunner {
         !!options.runningMode && options.runningMode !== 'IMAGE';
     this.baseOptions.setUseStreamMode(useStreamMode);
   }
+
+    if ('canvas' in options) {
+      if (this.graphRunner.wasmModule.canvas !== options.canvas) {
+        throw new Error('You must create a new task to reset the canvas.');
+      }
+    }
+
     return super.applyOptions(options);
   }
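Two behaviors in the new code are easy to miss: a caller-supplied canvas always takes precedence over the fallback, and `applyOptions` pins the canvas for the lifetime of the task. The effective selection logic, condensed into a standalone sketch (this restates `createVisionInstance`/`createCanvas` above, it is not new behavior):

// Sketch: canvas resolution order for vision tasks.
function resolveCanvas(options: VisionTaskOptions):
    HTMLCanvasElement|OffscreenCanvas|undefined {
  if (options.canvas) {
    return options.canvas;                    // caller-supplied canvas wins
  }
  if (typeof OffscreenCanvas === 'undefined') {
    return document.createElement('canvas');  // DOM fallback, no OffscreenCanvas
  }
  return undefined;  // let the GraphRunner create its own OffscreenCanvas
}

Changing the canvas later via `setOptions({canvas: ...})` throws `You must create a new task to reset the canvas.`, as exercised by the new test above.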

View File

@@ -58,9 +58,8 @@ export class FaceStylizer extends VisionTaskRunner {
   static createFromOptions(
       wasmFileset: WasmFileset,
       faceStylizerOptions: FaceStylizerOptions): Promise<FaceStylizer> {
-    return VisionTaskRunner.createInstance(
-        FaceStylizer, /* initializeCanvas= */ true, wasmFileset,
-        faceStylizerOptions);
+    return VisionTaskRunner.createVisionInstance(
+        FaceStylizer, wasmFileset, faceStylizerOptions);
   }

   /**
@@ -73,9 +72,8 @@ export class FaceStylizer extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<FaceStylizer> {
-    return VisionTaskRunner.createInstance(
-        FaceStylizer, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        FaceStylizer, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -88,9 +86,8 @@ export class FaceStylizer extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<FaceStylizer> {
-    return VisionTaskRunner.createInstance(
-        FaceStylizer, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        FaceStylizer, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -85,9 +85,8 @@ export class GestureRecognizer extends VisionTaskRunner {
       wasmFileset: WasmFileset,
       gestureRecognizerOptions: GestureRecognizerOptions):
       Promise<GestureRecognizer> {
-    return VisionTaskRunner.createInstance(
-        GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
-        gestureRecognizerOptions);
+    return VisionTaskRunner.createVisionInstance(
+        GestureRecognizer, wasmFileset, gestureRecognizerOptions);
   }

   /**
@@ -100,9 +99,8 @@ export class GestureRecognizer extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> {
-    return VisionTaskRunner.createInstance(
-        GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        GestureRecognizer, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -115,9 +113,8 @@ export class GestureRecognizer extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<GestureRecognizer> {
-    return VisionTaskRunner.createInstance(
-        GestureRecognizer, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        GestureRecognizer, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -75,9 +75,8 @@ export class HandLandmarker extends VisionTaskRunner {
   static createFromOptions(
       wasmFileset: WasmFileset,
       handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> {
-    return VisionTaskRunner.createInstance(
-        HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
-        handLandmarkerOptions);
+    return VisionTaskRunner.createVisionInstance(
+        HandLandmarker, wasmFileset, handLandmarkerOptions);
   }

   /**
@@ -90,9 +89,8 @@ export class HandLandmarker extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<HandLandmarker> {
-    return VisionTaskRunner.createInstance(
-        HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        HandLandmarker, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -105,9 +103,8 @@ export class HandLandmarker extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<HandLandmarker> {
-    return VisionTaskRunner.createInstance(
-        HandLandmarker, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        HandLandmarker, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -60,9 +60,8 @@ export class ImageClassifier extends VisionTaskRunner {
   static createFromOptions(
       wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions):
       Promise<ImageClassifier> {
-    return VisionTaskRunner.createInstance(
-        ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
-        imageClassifierOptions);
+    return VisionTaskRunner.createVisionInstance(
+        ImageClassifier, wasmFileset, imageClassifierOptions);
   }

   /**
@@ -75,9 +74,8 @@ export class ImageClassifier extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ImageClassifier> {
-    return VisionTaskRunner.createInstance(
-        ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        ImageClassifier, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -90,9 +88,8 @@ export class ImageClassifier extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ImageClassifier> {
-    return VisionTaskRunner.createInstance(
-        ImageClassifier, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        ImageClassifier, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -63,9 +63,8 @@ export class ImageEmbedder extends VisionTaskRunner {
   static createFromOptions(
       wasmFileset: WasmFileset,
       imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> {
-    return VisionTaskRunner.createInstance(
-        ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
-        imageEmbedderOptions);
+    return VisionTaskRunner.createVisionInstance(
+        ImageEmbedder, wasmFileset, imageEmbedderOptions);
   }

   /**
@@ -78,9 +77,8 @@ export class ImageEmbedder extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> {
-    return VisionTaskRunner.createInstance(
-        ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        ImageEmbedder, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -93,9 +91,8 @@ export class ImageEmbedder extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ImageEmbedder> {
-    return VisionTaskRunner.createInstance(
-        ImageEmbedder, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        ImageEmbedder, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -64,9 +64,8 @@ export class ImageSegmenter extends VisionTaskRunner {
   static createFromOptions(
       wasmFileset: WasmFileset,
       imageSegmenterOptions: ImageSegmenterOptions): Promise<ImageSegmenter> {
-    return VisionTaskRunner.createInstance(
-        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
-        imageSegmenterOptions);
+    return VisionTaskRunner.createVisionInstance(
+        ImageSegmenter, wasmFileset, imageSegmenterOptions);
   }

   /**
@@ -79,9 +78,8 @@ export class ImageSegmenter extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
-    return VisionTaskRunner.createInstance(
-        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        ImageSegmenter, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -94,9 +92,8 @@ export class ImageSegmenter extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ImageSegmenter> {
-    return VisionTaskRunner.createInstance(
-        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        ImageSegmenter, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -87,9 +87,8 @@ export class InteractiveSegmenter extends VisionTaskRunner {
       wasmFileset: WasmFileset,
       interactiveSegmenterOptions: InteractiveSegmenterOptions):
       Promise<InteractiveSegmenter> {
-    return VisionTaskRunner.createInstance(
-        InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
-        interactiveSegmenterOptions);
+    return VisionTaskRunner.createVisionInstance(
+        InteractiveSegmenter, wasmFileset, interactiveSegmenterOptions);
   }

   /**
@@ -103,9 +102,8 @@ export class InteractiveSegmenter extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<InteractiveSegmenter> {
-    return VisionTaskRunner.createInstance(
-        InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        InteractiveSegmenter, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -119,9 +117,8 @@ export class InteractiveSegmenter extends VisionTaskRunner {
   static createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<InteractiveSegmenter> {
-    return VisionTaskRunner.createInstance(
-        InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        InteractiveSegmenter, wasmFileset, {baseOptions: {modelAssetPath}});
   }

   /** @hideconstructor */

View File

@@ -59,9 +59,8 @@ export class ObjectDetector extends VisionTaskRunner {
   static createFromOptions(
       wasmFileset: WasmFileset,
       objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> {
-    return VisionTaskRunner.createInstance(
-        ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
-        objectDetectorOptions);
+    return VisionTaskRunner.createVisionInstance(
+        ObjectDetector, wasmFileset, objectDetectorOptions);
   }

   /**
@@ -74,9 +73,8 @@ export class ObjectDetector extends VisionTaskRunner {
   static createFromModelBuffer(
       wasmFileset: WasmFileset,
       modelAssetBuffer: Uint8Array): Promise<ObjectDetector> {
-    return VisionTaskRunner.createInstance(
-        ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetBuffer}});
+    return VisionTaskRunner.createVisionInstance(
+        ObjectDetector, wasmFileset, {baseOptions: {modelAssetBuffer}});
   }

   /**
@@ -89,9 +87,8 @@ export class ObjectDetector extends VisionTaskRunner {
   static async createFromModelPath(
       wasmFileset: WasmFileset,
       modelAssetPath: string): Promise<ObjectDetector> {
-    return VisionTaskRunner.createInstance(
-        ObjectDetector, /* initializeCanvas= */ true, wasmFileset,
-        {baseOptions: {modelAssetPath}});
+    return VisionTaskRunner.createVisionInstance(
+        ObjectDetector, wasmFileset, {baseOptions: {modelAssetPath}});
  }

   /** @hideconstructor */

View File

@@ -352,10 +352,15 @@ export class GraphRunner {
     } else {
       this.wasmModule._bindTextureToStream(streamNamePtr);
     }
-    const gl: any =
-        this.wasmModule.canvas.getContext('webgl2') ||
-        this.wasmModule.canvas.getContext('webgl');
-    console.assert(gl);
+    const gl =
+        (this.wasmModule.canvas.getContext('webgl2') ||
+         this.wasmModule.canvas.getContext('webgl')) as WebGL2RenderingContext |
+        WebGLRenderingContext | null;
+    if (!gl) {
+      throw new Error(
+          'Failed to obtain WebGL context from the provided canvas. ' +
+          '`getContext()` should only be invoked with `webgl` or `webgl2`.');
+    }
     gl.texImage2D(
         gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, imageSource);
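The thrown error replaces a silent `console.assert`. The common way to hit it is a canvas that already holds a non-WebGL context: a canvas supports only one context type, so subsequent `getContext('webgl2')`/`getContext('webgl')` calls return `null`. A standalone sketch (not part of the commit):

// Sketch: why getContext() can return null on a user-supplied canvas.
const canvas = document.createElement('canvas');
canvas.getContext('2d');  // claims the canvas for 2D rendering
console.assert(canvas.getContext('webgl2') === null);  // now null
console.assert(canvas.getContext('webgl') === null);   // now null
// Handing such a canvas to a GPU vision task would surface as:
//   Error: Failed to obtain WebGL context from the provided canvas. ...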