diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h
index 1983272fc..e2647be71 100644
--- a/mediapipe/tasks/cc/vision/core/image_processing_options.h
+++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h
@@ -28,7 +28,7 @@ namespace core {
 // Options for image processing.
 //
 // If both region-or-interest and rotation are specified, the crop around the
-// region-of-interest is extracted first, the the specified rotation is applied
+// region-of-interest is extracted first, then the specified rotation is applied
 // to the crop.
 struct ImageProcessingOptions {
   // The optional region-of-interest to crop from the image. If not specified,
diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD
index fb0fdff16..a0db59d0b 100644
--- a/mediapipe/tasks/web/components/containers/BUILD
+++ b/mediapipe/tasks/web/components/containers/BUILD
@@ -24,3 +24,8 @@ mediapipe_ts_declaration(
     name = "embedding_result",
     srcs = ["embedding_result.d.ts"],
 )
+
+mediapipe_ts_declaration(
+    name = "rect",
+    srcs = ["rect.d.ts"],
+)
diff --git a/mediapipe/tasks/web/components/containers/rect.d.ts b/mediapipe/tasks/web/components/containers/rect.d.ts
new file mode 100644
index 000000000..9afece9ca
--- /dev/null
+++ b/mediapipe/tasks/web/components/containers/rect.d.ts
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Defines a rectangle, used e.g. as part of detection results or as input
+ * region-of-interest.
+ */
+export declare interface Rect {
+  left: number;
+  top: number;
+  right: number;
+  bottom: number;
+}
+
+/**
+ * Defines a rectangle, used e.g. as part of detection results or as input
+ * region-of-interest.
+ *
+ * The coordinates are normalized with respect to the image dimensions, i.e.
+ * generally in [0,1] but they may exceed these bounds if describing a region
+ * overlapping the image. The origin is on the top-left corner of the image.
+ */
+export declare interface RectF {
+  left: number;
+  top: number;
+  right: number;
+  bottom: number;
+}
diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts
index 838b3f585..62dd0463a 100644
--- a/mediapipe/tasks/web/core/task_runner_test_utils.ts
+++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts
@@ -32,12 +32,14 @@ export declare type SpyWasmModule = jasmine.SpyObj<WasmModule>;
  * in pure JS/TS (and optionally spy on the calls).
  */
 export function createSpyWasmModule(): SpyWasmModule {
-  return jasmine.createSpyObj([
+  const spyWasmModule = jasmine.createSpyObj([
     '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
     '_attachProtoVectorListener', '_free', '_waitUntilIdle',
     '_addStringToInputStream', '_registerModelResourcesGraphService',
-    '_configureAudio'
+    '_configureAudio', '_malloc', '_addProtoToInputStream'
   ]);
+  spyWasmModule.HEAPU8 = jasmine.createSpyObj(['set']);
+  return spyWasmModule;
 }
 
 /**
diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD
index 3574483df..a0a008122 100644
--- a/mediapipe/tasks/web/vision/core/BUILD
+++ b/mediapipe/tasks/web/vision/core/BUILD
@@ -5,6 +5,14 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration",
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+mediapipe_ts_declaration(
+    name = "image_processing_options",
+    srcs = ["image_processing_options.d.ts"],
+    deps = [
+        "//mediapipe/tasks/web/components/containers:rect",
+    ],
+)
+
 mediapipe_ts_declaration(
     name = "vision_task_options",
     srcs = ["vision_task_options.d.ts"],
@@ -17,7 +25,9 @@ mediapipe_ts_library(
     name = "vision_task_runner",
     srcs = ["vision_task_runner.ts"],
     deps = [
+        ":image_processing_options",
         ":vision_task_options",
+        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:task_runner",
         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
@@ -31,8 +41,10 @@ mediapipe_ts_library(
     testonly = True,
     srcs = ["vision_task_runner.test.ts"],
     deps = [
+        ":image_processing_options",
         ":vision_task_options",
         ":vision_task_runner",
+        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/web/core:task_runner_test_utils",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/core/image_processing_options.d.ts b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts
new file mode 100644
index 000000000..b76731546
--- /dev/null
+++ b/mediapipe/tasks/web/vision/core/image_processing_options.d.ts
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {RectF} from '../../../../tasks/web/components/containers/rect';
+
+/**
+ * Options for image processing.
+ *
+ * If both region-of-interest and rotation are specified, the crop around the
+ * region-of-interest is extracted first, then the specified rotation is applied
+ * to the crop.
+ */
+export declare interface ImageProcessingOptions {
+  /**
+   * The optional region-of-interest to crop from the image. If not specified,
+   * the full image is used.
+   *
+   * Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+   */
+  regionOfInterest?: RectF;
+
+  /**
+   * The rotation to apply to the image (or cropped region-of-interest), in
+   * degrees clockwise.
+   *
+   * The rotation must be a multiple (positive or negative) of 90°.
+   */
+  rotationDegrees?: number;
+}
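Note: a minimal usage sketch of the `ImageProcessingOptions` interface declared above (the option values and the relative import path are illustrative, not part of this patch):

```ts
import {ImageProcessingOptions} from './image_processing_options';

// Crop to the right half of the image, then rotate the crop 90° clockwise.
// Coordinates are normalized to [0,1] with the origin at the top-left corner.
const options: ImageProcessingOptions = {
  regionOfInterest: {left: 0.5, top: 0, right: 1, bottom: 1},
  rotationDegrees: 90,  // must be a multiple of 90, positive or negative
};
```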
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
index f3f25070e..a48381038 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts
@@ -16,21 +16,62 @@
 
 import 'jasmine';
 
+import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
-import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
+import {addJasmineCustomFloatEqualityTester} from '../../../../tasks/web/core/task_runner_test_utils';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {ImageSource} from '../../../../web/graph_runner/graph_runner';
 
 import {VisionTaskOptions} from './vision_task_options';
 import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner';
 
-class VisionTaskRunnerFake extends VisionTaskRunner<void> {
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
+
+const IMAGE = {} as unknown as HTMLImageElement;
+const TIMESTAMP = 42;
+
+class VisionTaskRunnerFake extends VisionTaskRunner {
   baseOptions = new BaseOptionsProto();
+  fakeGraphRunner: jasmine.SpyObj<VisionGraphRunner>;
+  expectedImageSource?: ImageSource;
+  expectedNormalizedRect?: NormalizedRect;
 
   constructor() {
-    super(new VisionGraphRunner(createSpyWasmModule(), /* glCanvas= */ null));
-  }
+    super(
+        jasmine.createSpyObj([
+          'addProtoToStream', 'addGpuBufferAsImageToStream',
+          'setAutoRenderToScreen', 'registerModelResourcesGraphService',
+          'finishProcessing'
+        ]),
+        IMAGE_STREAM, NORM_RECT_STREAM);
 
-  protected override process(): void {}
+    this.fakeGraphRunner =
+        this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>;
+
+    (this.graphRunner.addProtoToStream as jasmine.Spy)
+        .and.callFake((serializedData, type, streamName, timestamp) => {
+          expect(type).toBe('mediapipe.NormalizedRect');
+          expect(streamName).toBe(NORM_RECT_STREAM);
+          expect(timestamp).toBe(TIMESTAMP);
+
+          const actualNormalizedRect =
+              NormalizedRect.deserializeBinary(serializedData);
+          expect(actualNormalizedRect.toObject())
+              .toEqual(this.expectedNormalizedRect!.toObject());
+        });
+
+    (this.graphRunner.addGpuBufferAsImageToStream as jasmine.Spy)
+        .and.callFake((imageSource, streamName, timestamp) => {
+          expect(streamName).toBe(IMAGE_STREAM);
+          expect(timestamp).toBe(TIMESTAMP);
+          expect(imageSource).toBe(this.expectedImageSource!);
+        });
+  }
 
   protected override refreshGraph(): void {}
 
@@ -38,12 +79,31 @@ class VisionTaskRunnerFake extends VisionTaskRunner<void> {
     return this.applyOptions(options);
   }
 
-  override processImageData(image: ImageSource): void {
-    super.processImageData(image);
+  override processImageData(
+      image: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined): void {
+    super.processImageData(image, imageProcessingOptions);
   }
 
-  override processVideoData(imageFrame: ImageSource, timestamp: number): void {
-    super.processVideoData(imageFrame, timestamp);
+  override processVideoData(
+      imageFrame: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
+    super.processVideoData(imageFrame, imageProcessingOptions, timestamp);
+  }
+
+  expectNormalizedRect(
+      xCenter: number, yCenter: number, width: number, height: number): void {
+    const rect = new NormalizedRect();
+    rect.setXCenter(xCenter);
+    rect.setYCenter(yCenter);
+    rect.setWidth(width);
+    rect.setHeight(height);
+    this.expectedNormalizedRect = rect;
+  }
+
+  expectImage(imageSource: ImageSource): void {
+    this.expectedImageSource = imageSource;
+  }
 }
@@ -51,6 +111,7 @@ describe('VisionTaskRunner', () => {
   let visionTaskRunner: VisionTaskRunnerFake;
 
   beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
    visionTaskRunner = new VisionTaskRunnerFake();
     await visionTaskRunner.setOptions(
         {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
@@ -72,7 +133,8 @@
     await visionTaskRunner.setOptions({runningMode: 'video'});
 
     // Clear running mode
-    await visionTaskRunner.setOptions({runningMode: undefined});
+    await visionTaskRunner.setOptions(
+        {runningMode: undefined});
     expect(visionTaskRunner.baseOptions.toObject())
         .toEqual(jasmine.objectContaining({useStreamMode: false}));
   });
@@ -80,20 +142,90 @@
   it('cannot process images with video mode', async () => {
     await visionTaskRunner.setOptions({runningMode: 'video'});
     expect(() => {
-      visionTaskRunner.processImageData({} as HTMLImageElement);
+      visionTaskRunner.processImageData(
+          IMAGE, /* imageProcessingOptions= */ undefined);
     }).toThrowError(/Task is not initialized with image mode./);
   });
 
   it('cannot process video with image mode', async () => {
     // Use default for `useStreamMode`
     expect(() => {
-      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
+      visionTaskRunner.processVideoData(
+          IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP);
     }).toThrowError(/Task is not initialized with video mode./);
 
     // Explicitly set to image mode
     await visionTaskRunner.setOptions({runningMode: 'image'});
     expect(() => {
-      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
+      visionTaskRunner.processVideoData(
+          IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP);
     }).toThrowError(/Task is not initialized with video mode./);
   });
+
+  it('sends packets to graph', async () => {
+    await visionTaskRunner.setOptions({runningMode: 'video'});
+
+    visionTaskRunner.expectImage(IMAGE);
+    visionTaskRunner.expectNormalizedRect(0.5, 0.5, 1, 1);
+    visionTaskRunner.processVideoData(
+        IMAGE, /* imageProcessingOptions= */ undefined, TIMESTAMP);
+  });
+
+  it('sends packets to graph with image processing options', async () => {
+    await visionTaskRunner.setOptions({runningMode: 'video'});
+
+    visionTaskRunner.expectImage(IMAGE);
+    visionTaskRunner.expectNormalizedRect(0.3, 0.6, 0.2, 0.4);
+    visionTaskRunner.processVideoData(
+        IMAGE,
+        {regionOfInterest: {left: 0.2, right: 0.4, top: 0.4, bottom: 0.8}},
+        TIMESTAMP);
+  });
+
+  describe('validates processing options', () => {
+    it('with left > right', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {
+          regionOfInterest: {
+            left: 0.2,
+            right: 0.1,
+            top: 0.1,
+            bottom: 0.2,
+          }
+        });
+      }).toThrowError('Expected RectF with left < right and top < bottom.');
+    });
+
+    it('with top > bottom', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {
+          regionOfInterest: {
+            left: 0.1,
+            right: 0.2,
+            top: 0.2,
+            bottom: 0.1,
+          }
+        });
+      }).toThrowError('Expected RectF with left < right and top < bottom.');
+    });
+
+    it('with out of range values', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {
+          regionOfInterest: {
+            left: 0.1,
+            right: 1.1,
+            top: 0.1,
+            bottom: 0.2,
+          }
+        });
+      }).toThrowError('Expected RectF values to be in [0,1].');
+    });
+
+    it('with non-90 degree rotation', () => {
+      expect(() => {
+        visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42});
+      }).toThrowError('Expected rotation to be a multiple of 90°.');
+    });
+  });
 });
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index c3e0d3c7e..9adc810fc 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -14,7 +14,9 @@
  * limitations under the License.
  */
 
+import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service';
@@ -27,10 +29,26 @@ const GraphRunnerVisionType =
 /** An implementation of the GraphRunner that supports image operations */
 export class VisionGraphRunner extends GraphRunnerVisionType {}
 
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
 /** Base class for all MediaPipe Vision Tasks. */
-export abstract class VisionTaskRunner<T> extends TaskRunner {
-  /** @hideconstructor protected */
-  constructor(protected override readonly graphRunner: VisionGraphRunner) {
+export abstract class VisionTaskRunner extends TaskRunner {
+  /**
+   * Constructor to initialize a `VisionTaskRunner`.
+   *
+   * @param graphRunner the graph runner for this task.
+   * @param imageStreamName the name of the input image stream.
+   * @param normRectStreamName the name of the input normalized rect image
+   *     stream used to provide (mandatory) rotation and (optional)
+   *     region-of-interest.
+   *
+   * @hideconstructor protected
+   */
+  constructor(
+      protected override readonly graphRunner: VisionGraphRunner,
+      private readonly imageStreamName: string,
+      private readonly normRectStreamName: string) {
     super(graphRunner);
   }
 
@@ -44,27 +62,84 @@ export abstract class VisionTaskRunner<T> extends TaskRunner {
     return super.applyOptions(options);
   }
 
-  /** Sends an image packet to the graph and awaits results. */
-  protected abstract process(input: ImageSource, timestamp: number): T;
-
   /** Sends a single image to the graph and awaits results. */
-  protected processImageData(image: ImageSource): T {
+  protected processImageData(
+      image: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined): void {
     if (!!this.baseOptions?.getUseStreamMode()) {
       throw new Error(
           'Task is not initialized with image mode. ' +
          '\'runningMode\' must be set to \'image\'.');
     }
-    return this.process(image, performance.now());
+    this.process(image, imageProcessingOptions, performance.now());
   }
 
   /** Sends a single video frame to the graph and awaits results. */
-  protected processVideoData(imageFrame: ImageSource, timestamp: number): T {
+  protected processVideoData(
+      imageFrame: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
     if (!this.baseOptions?.getUseStreamMode()) {
       throw new Error(
          'Task is not initialized with video mode. ' +
          '\'runningMode\' must be set to \'video\'.');
     }
-    return this.process(imageFrame, timestamp);
+    this.process(imageFrame, imageProcessingOptions, timestamp);
+  }
+
+  private convertToNormalizedRect(imageProcessingOptions?:
+                                      ImageProcessingOptions): NormalizedRect {
+    const normalizedRect = new NormalizedRect();
+
+    if (imageProcessingOptions?.regionOfInterest) {
+      const roi = imageProcessingOptions.regionOfInterest;
+
+      if (roi.left >= roi.right || roi.top >= roi.bottom) {
+        throw new Error('Expected RectF with left < right and top < bottom.');
+      }
+      if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
+        throw new Error('Expected RectF values to be in [0,1].');
+      }
+
+      normalizedRect.setXCenter((roi.left + roi.right) / 2.0);
+      normalizedRect.setYCenter((roi.top + roi.bottom) / 2.0);
+      normalizedRect.setWidth(roi.right - roi.left);
+      normalizedRect.setHeight(roi.bottom - roi.top);
+    } else {
+      normalizedRect.setXCenter(0.5);
+      normalizedRect.setYCenter(0.5);
+      normalizedRect.setWidth(1);
+      normalizedRect.setHeight(1);
+    }
+
+    if (imageProcessingOptions?.rotationDegrees) {
+      if (imageProcessingOptions.rotationDegrees % 90 !== 0) {
+        throw new Error('Expected rotation to be a multiple of 90°.');
+      }
+
+      // Convert to radians anti-clockwise.
+      normalizedRect.setRotation(
+          -Math.PI * imageProcessingOptions.rotationDegrees / 180.0);
+    }
+
+    return normalizedRect;
+  }
+
+  /** Runs the graph and blocks on the response. */
+  private process(
+      imageSource: ImageSource,
+      imageProcessingOptions: ImageProcessingOptions|undefined,
+      timestamp: number): void {
+    const normalizedRect = this.convertToNormalizedRect(imageProcessingOptions);
+    this.graphRunner.addProtoToStream(
+        normalizedRect.serializeBinary(), 'mediapipe.NormalizedRect',
+        this.normRectStreamName, timestamp);
+    this.graphRunner.addGpuBufferAsImageToStream(
+        imageSource, this.imageStreamName, timestamp);
+    this.finishProcessing();
+  }
 }
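For reference, `convertToNormalizedRect()` above maps the corner-based `RectF` to the center/size representation used by `NormalizedRect`, and converts degrees clockwise into radians counter-clockwise. A worked example using the same values exercised by the test earlier in this patch:

```ts
// regionOfInterest {left: 0.2, right: 0.4, top: 0.4, bottom: 0.8} yields:
//   xCenter = (0.2 + 0.4) / 2 = 0.3
//   yCenter = (0.4 + 0.8) / 2 = 0.6
//   width   = 0.4 - 0.2       = 0.2
//   height  = 0.8 - 0.4       = 0.4
// rotationDegrees: 90 yields rotation = -Math.PI * 90 / 180 ≈ -1.5708 rad.
```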
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
index 5fdf9b43e..9156e89b7 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD
@@ -20,7 +20,6 @@ mediapipe_ts_library(
         "//mediapipe/framework:calculator_options_jspb_proto",
         "//mediapipe/framework/formats:classification_jspb_proto",
         "//mediapipe/framework/formats:landmark_jspb_proto",
-        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto",
@@ -33,6 +32,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:classifier_options",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
index 8d36ed89c..e0c6affcb 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts
@@ -18,7 +18,6 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {ClassificationList} from '../../../../framework/formats/classification_pb';
 import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
-import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb';
 import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb';
@@ -30,6 +29,7 @@ import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -57,15 +57,8 @@ const DEFAULT_NUM_HANDS = 1;
 const DEFAULT_SCORE_THRESHOLD = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
 
-const FULL_IMAGE_RECT = new NormalizedRect();
-FULL_IMAGE_RECT.setXCenter(0.5);
-FULL_IMAGE_RECT.setYCenter(0.5);
-FULL_IMAGE_RECT.setWidth(1);
-FULL_IMAGE_RECT.setHeight(1);
-
 /** Performs hand gesture recognition on images. */
-export class GestureRecognizer extends
-    VisionTaskRunner<GestureRecognizerResult> {
+export class GestureRecognizer extends VisionTaskRunner {
   private gestures: Category[][] = [];
   private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
@@ -131,7 +124,9 @@ export class GestureRecognizer extends
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
 
     this.options = new GestureRecognizerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
@@ -228,10 +223,16 @@
    * GestureRecognizer is created with running mode `image`.
    *
    * @param image A single image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected gestures.
    */
-  recognize(image: ImageSource): GestureRecognizerResult {
-    return this.processImageData(image);
+  recognize(
+      image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      GestureRecognizerResult {
+    this.resetResults();
+    this.processImageData(image, imageProcessingOptions);
+    return this.processResults();
   }
 
   /**
@@ -241,28 +242,27 @@
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected gestures.
    */
-  recognizeForVideo(videoFrame: ImageSource, timestamp: number):
+  recognizeForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions):
       GestureRecognizerResult {
-    return this.processVideoData(videoFrame, timestamp);
+    this.resetResults();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.processResults();
   }
 
-  /** Runs the gesture recognition and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      GestureRecognizerResult {
+  private resetResults(): void {
     this.gestures = [];
     this.landmarks = [];
     this.worldLandmarks = [];
     this.handednesses = [];
+  }
 
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, IMAGE_STREAM, timestamp);
-    this.graphRunner.addProtoToStream(
-        FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
-        NORM_RECT_STREAM, timestamp);
-    this.finishProcessing();
-
+  private processResults(): GestureRecognizerResult {
     if (this.gestures.length === 0) {
       // If no gestures are detected in the image, just return an empty list
       return {
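A usage sketch for the updated `recognize()` signature (the recognizer setup and asset URL are assumptions for illustration, not part of this diff):

```ts
// Assumes an initialized recognizer, e.g. created elsewhere via
// GestureRecognizer.createFromModelPath(visionFileset, modelUrl).
declare const gestureRecognizer: GestureRecognizer;

const image = document.querySelector('img')!;

// Recognize gestures on the left half of the image, rotated 90° clockwise.
const result = gestureRecognizer.recognize(image, {
  regionOfInterest: {left: 0, top: 0, right: 0.5, bottom: 1},
  rotationDegrees: 90,
});
```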
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
index e7083a050..c5687ee2f 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD
+++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD
@@ -20,7 +20,6 @@ mediapipe_ts_library(
         "//mediapipe/framework:calculator_options_jspb_proto",
         "//mediapipe/framework/formats:classification_jspb_proto",
         "//mediapipe/framework/formats:landmark_jspb_proto",
-        "//mediapipe/framework/formats:rect_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto",
@@ -28,6 +27,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/containers:category",
         "//mediapipe/tasks/web/components/containers:landmark",
         "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
index 5db6d48f5..e238bc96f 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
@@ -18,7 +18,6 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {ClassificationList} from '../../../../framework/formats/classification_pb';
 import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
-import {NormalizedRect} from '../../../../framework/formats/rect_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb';
 import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb';
@@ -26,6 +25,7 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han
 import {Category} from '../../../../tasks/web/components/containers/category';
 import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -51,14 +51,9 @@ const HAND_LANDMARKER_GRAPH =
 const DEFAULT_NUM_HANDS = 1;
 const DEFAULT_SCORE_THRESHOLD = 0.5;
 const DEFAULT_CATEGORY_INDEX = -1;
-const FULL_IMAGE_RECT = new NormalizedRect();
-FULL_IMAGE_RECT.setXCenter(0.5);
-FULL_IMAGE_RECT.setYCenter(0.5);
-FULL_IMAGE_RECT.setWidth(1);
-FULL_IMAGE_RECT.setHeight(1);
 
 /** Performs hand landmarks detection on images. */
-export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
+export class HandLandmarker extends VisionTaskRunner {
   private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
   private handednesses: Category[][] = [];
@@ -119,7 +114,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
 
     this.options = new HandLandmarkerGraphOptions();
     this.options.setBaseOptions(new BaseOptionsProto());
@@ -180,10 +177,15 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
    * HandLandmarker is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected hand landmarks.
    */
-  detect(image: ImageSource): HandLandmarkerResult {
-    return this.processImageData(image);
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      HandLandmarkerResult {
+    this.resetResults();
+    this.processImageData(image, imageProcessingOptions);
+    return this.processResults();
   }
 
   /**
@@ -193,27 +195,25 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The detected hand landmarks.
    */
-  detectForVideo(videoFrame: ImageSource, timestamp: number):
-      HandLandmarkerResult {
-    return this.processVideoData(videoFrame, timestamp);
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): HandLandmarkerResult {
+    this.resetResults();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    return this.processResults();
   }
 
-  /** Runs the hand landmarker graph and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      HandLandmarkerResult {
+  private resetResults(): void {
     this.landmarks = [];
     this.worldLandmarks = [];
     this.handednesses = [];
+  }
 
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, IMAGE_STREAM, timestamp);
-    this.graphRunner.addProtoToStream(
-        FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
-        NORM_RECT_STREAM, timestamp);
-    this.finishProcessing();
-
+  private processResults(): HandLandmarkerResult {
     return {
       landmarks: this.landmarks,
       worldLandmarks: this.worldLandmarks,
diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD
index 310575964..86c7d8457 100644
--- a/mediapipe/tasks/web/vision/image_classifier/BUILD
+++ b/mediapipe/tasks/web/vision/image_classifier/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:classifier_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
index 4a2be5566..2ad4a821d 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts
@@ -22,6 +22,7 @@ import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_cla
 import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
 import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -31,7 +32,8 @@ import {ImageClassifierResult} from './image_classifier_result';
 
 const IMAGE_CLASSIFIER_GRAPH =
     'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
-const INPUT_STREAM = 'input_image';
+const IMAGE_STREAM = 'input_image';
+const NORM_RECT_STREAM = 'norm_rect';
 const CLASSIFICATIONS_STREAM = 'classifications';
 
 export * from './image_classifier_options';
@@ -42,7 +44,7 @@ export {ImageSource};  // Used in the public API
 // tslint:disable:jspb-use-builder-pattern
 
 /** Performs classification on images. */
-export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
+export class ImageClassifier extends VisionTaskRunner {
   private classificationResult: ImageClassifierResult = {classifications: []};
   private readonly options = new ImageClassifierGraphOptions();
 
@@ -97,7 +99,9 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
@@ -130,10 +134,15 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
    * ImageClassifier is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The classification result of the image
    */
-  classify(image: ImageSource): ImageClassifierResult {
-    return this.processImageData(image);
+  classify(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      ImageClassifierResult {
+    this.classificationResult = {classifications: []};
+    this.processImageData(image, imageProcessingOptions);
+    return this.classificationResult;
   }
 
   /**
@@ -143,28 +152,23 @@
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The classification result of the image
    */
-  classifyForVideo(videoFrame: ImageSource, timestamp: number):
-      ImageClassifierResult {
-    return this.processVideoData(videoFrame, timestamp);
-  }
-
-  /** Runs the image classification graph and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      ImageClassifierResult {
-    // Get classification result by running our MediaPipe graph.
+  classifyForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): ImageClassifierResult {
     this.classificationResult = {classifications: []};
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, INPUT_STREAM, timestamp ?? performance.now());
-    this.finishProcessing();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
     return this.classificationResult;
   }
 
   /** Updates the MediaPipe graph configuration. */
   protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
-    graphConfig.addInputStream(INPUT_STREAM);
+    graphConfig.addInputStream(IMAGE_STREAM);
+    graphConfig.addInputStream(NORM_RECT_STREAM);
     graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
@@ -175,7 +179,8 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
     // are built-in.
     const classifierNode = new CalculatorGraphConfig.Node();
     classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
-    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    classifierNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    classifierNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
     classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
     classifierNode.setOptions(calculatorOptions);
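For orientation, `refreshGraph()` above now declares two graph inputs and feeds both into the classifier node. The generated `CalculatorGraphConfig` is roughly the following (sketched as comments; derived from the stream constants above):

```ts
// input_stream: "input_image"        (IMAGE_STREAM)
// input_stream: "norm_rect"          (NORM_RECT_STREAM)
// output_stream: "classifications"   (CLASSIFICATIONS_STREAM)
// node {
//   calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
//   input_stream: "IMAGE:input_image"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "CLASSIFICATIONS:classifications"
// }
```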
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index de4785e6c..449cee9bb 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -26,6 +26,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/utils:cosine_similarity",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index 4651ae4ce..64a10f5f4 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -24,6 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
 import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -31,10 +32,12 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 import {ImageEmbedderOptions} from './image_embedder_options';
 import {ImageEmbedderResult} from './image_embedder_result';
 
+
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern
 
-const INPUT_STREAM = 'image_in';
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
 const EMBEDDINGS_STREAM = 'embeddings_out';
 const TEXT_EMBEDDER_CALCULATOR =
     'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph';
@@ -44,7 +47,7 @@ export * from './image_embedder_result';
 export {ImageSource};  // Used in the public API
 
 /** Performs embedding extraction on images. */
-export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
+export class ImageEmbedder extends VisionTaskRunner {
   private readonly options = new ImageEmbedderGraphOptions();
   private embeddings: ImageEmbedderResult = {embeddings: []};
 
@@ -99,7 +102,9 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
@@ -132,10 +137,14 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
    * ImageEmbedder is created with running mode `image`.
    *
    * @param image The image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The classification result of the image
    */
-  embed(image: ImageSource): ImageEmbedderResult {
-    return this.processImageData(image);
+  embed(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      ImageEmbedderResult {
+    this.processImageData(image, imageProcessingOptions);
+    return this.embeddings;
   }
 
   /**
@@ -145,11 +154,15 @@
    *
    * @param imageFrame The image frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The classification result of the image
    */
-  embedForVideo(imageFrame: ImageSource, timestamp: number):
-      ImageEmbedderResult {
-    return this.processVideoData(imageFrame, timestamp);
+  embedForVideo(
+      imageFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): ImageEmbedderResult {
+    this.processVideoData(imageFrame, imageProcessingOptions, timestamp);
+    return this.embeddings;
   }
 
   /**
@@ -165,16 +178,6 @@
     return computeCosineSimilarity(u, v);
   }
 
-  /** Runs the embedding extraction and blocks on the response. */
-  protected process(image: ImageSource, timestamp: number):
-      ImageEmbedderResult {
-    // Get embeddings by running our MediaPipe graph.
-    this.graphRunner.addGpuBufferAsImageToStream(
-        image, INPUT_STREAM, timestamp ?? performance.now());
-    this.finishProcessing();
-    return this.embeddings;
-  }
-
   /**
    * Internal function for converting raw data into an embedding, and setting it
    * as our embeddings result.
@@ -187,7 +190,8 @@
   /** Updates the MediaPipe graph configuration. */
   protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
-    graphConfig.addInputStream(INPUT_STREAM);
+    graphConfig.addInputStream(IMAGE_STREAM);
+    graphConfig.addInputStream(NORM_RECT_STREAM);
     graphConfig.addOutputStream(EMBEDDINGS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
@@ -195,7 +199,8 @@
     const embedderNode = new CalculatorGraphConfig.Node();
     embedderNode.setCalculator(TEXT_EMBEDDER_CALCULATOR);
-    embedderNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    embedderNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    embedderNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
     embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM);
     embedderNode.setOptions(calculatorOptions);
 
diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD
index fc206a2d7..76fa589c8 100644
--- a/mediapipe/tasks/web/vision/object_detector/BUILD
+++ b/mediapipe/tasks/web/vision/object_detector/BUILD
@@ -23,6 +23,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto",
         "//mediapipe/tasks/web/components/containers:category",
         "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
index ac489ec00..3a79c1b00 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
@@ -20,6 +20,7 @@ import {Detection as DetectionProto} from '../../../../framework/formats/detecti
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
+import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -27,7 +28,8 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 import {ObjectDetectorOptions} from './object_detector_options';
 import {Detection} from './object_detector_result';
 
-const INPUT_STREAM = 'input_frame_gpu';
+const IMAGE_STREAM = 'input_frame_gpu';
+const NORM_RECT_STREAM = 'norm_rect';
 const DETECTIONS_STREAM = 'detections';
 const OBJECT_DETECTOR_GRAPH = 'mediapipe.tasks.vision.ObjectDetectorGraph';
 
@@ -41,7 +43,7 @@ export {ImageSource};  // Used in the public API
 // tslint:disable:jspb-use-builder-pattern
 
 /** Performs object detection on images. */
-export class ObjectDetector extends VisionTaskRunner<Detection[]> {
+export class ObjectDetector extends VisionTaskRunner {
   private detections: Detection[] = [];
   private readonly options = new ObjectDetectorOptionsProto();
 
@@ -96,7 +98,9 @@
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(new VisionGraphRunner(wasmModule, glCanvas));
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM);
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
@@ -160,10 +164,15 @@
    * ObjectDetector is created with running mode `image`.
    *
    * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The list of detected objects
    */
-  detect(image: ImageSource): Detection[] {
-    return this.processImageData(image);
+  detect(image: ImageSource, imageProcessingOptions?: ImageProcessingOptions):
+      Detection[] {
+    this.detections = [];
+    this.processImageData(image, imageProcessingOptions);
+    return [...this.detections];
   }
 
   /**
@@ -173,20 +182,15 @@
    *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
    * @return The list of detected objects
    */
-  detectForVideo(videoFrame: ImageSource, timestamp: number): Detection[] {
-    return this.processVideoData(videoFrame, timestamp);
-  }
-
-  /** Runs the object detector graph and blocks on the response. */
-  protected override process(imageSource: ImageSource, timestamp: number):
-      Detection[] {
-    // Get detections by running our MediaPipe graph.
+  detectForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      imageProcessingOptions?: ImageProcessingOptions): Detection[] {
     this.detections = [];
-    this.graphRunner.addGpuBufferAsImageToStream(
-        imageSource, INPUT_STREAM, timestamp ?? performance.now());
-    this.finishProcessing();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
     return [...this.detections];
   }
 
@@ -226,7 +230,8 @@
   /** Updates the MediaPipe graph configuration. */
   protected override refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
-    graphConfig.addInputStream(INPUT_STREAM);
+    graphConfig.addInputStream(IMAGE_STREAM);
+    graphConfig.addInputStream(NORM_RECT_STREAM);
     graphConfig.addOutputStream(DETECTIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
@@ -235,7 +240,8 @@
 
     const detectorNode = new CalculatorGraphConfig.Node();
     detectorNode.setCalculator(OBJECT_DETECTOR_GRAPH);
-    detectorNode.addInputStream('IMAGE:' + INPUT_STREAM);
+    detectorNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    detectorNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
     detectorNode.addOutputStream('DETECTIONS:' + DETECTIONS_STREAM);
     detectorNode.setOptions(calculatorOptions);
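Finally, a sketch of the video path shared by the tasks above (the detector instance and frame callback are assumptions for illustration; note that the public `*ForVideo()` methods take the timestamp before the options):

```ts
// Assumes an ObjectDetector initialized with runningMode: 'video'.
declare const objectDetector: ObjectDetector;

function onVideoFrame(video: HTMLVideoElement): void {
  // Rotate each frame 90° counter-clockwise; without a regionOfInterest the
  // full frame is processed.
  const detections = objectDetector.detectForVideo(
      video, performance.now(), {rotationDegrees: -90});
  console.log(`Detected ${detections.length} object(s)`);
}
```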