From f997c0ab1a8bc69d0ef8760061a515313144af8c Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 13 Jan 2023 09:52:07 -0800 Subject: [PATCH] Reject RegionOfInterest in not supported tasks PiperOrigin-RevId: 501872455 --- .../vision/core/vision_task_runner.test.ts | 41 +++++++++++++++---- .../web/vision/core/vision_task_runner.ts | 9 +++- .../gesture_recognizer/gesture_recognizer.ts | 2 +- .../gesture_recognizer_test.ts | 8 ++++ .../vision/hand_landmarker/hand_landmarker.ts | 2 +- .../hand_landmarker/hand_landmarker_test.ts | 8 ++++ .../image_classifier/image_classifier.ts | 2 +- .../vision/image_embedder/image_embedder.ts | 2 +- .../vision/object_detector/object_detector.ts | 2 +- .../object_detector/object_detector_test.ts | 8 ++++ 10 files changed, 70 insertions(+), 14 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 4567134b8..4eb51afdb 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -41,14 +41,14 @@ class VisionTaskRunnerFake extends VisionTaskRunner { expectedImageSource?: ImageSource; expectedNormalizedRect?: NormalizedRect; - constructor() { + constructor(roiAllowed = true) { super( jasmine.createSpyObj([ 'addProtoToStream', 'addGpuBufferAsImageToStream', 'setAutoRenderToScreen', 'registerModelResourcesGraphService', 'finishProcessing' ]), - IMAGE_STREAM, NORM_RECT_STREAM); + IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed); this.fakeGraphRunner = this.graphRunner as unknown as jasmine.SpyObj; @@ -71,6 +71,9 @@ class VisionTaskRunnerFake extends VisionTaskRunner { expect(timestamp).toBe(TIMESTAMP); expect(imageSource).toBe(this.expectedImageSource!); }); + + // SetOptions with a modelAssetBuffer runs synchonously + void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}}); } protected override refreshGraph(): void {} @@ -108,28 +111,26 @@ class VisionTaskRunnerFake extends VisionTaskRunner { } describe('VisionTaskRunner', () => { - let visionTaskRunner: VisionTaskRunnerFake; - - beforeEach(async () => { + beforeEach(() => { addJasmineCustomFloatEqualityTester(); - visionTaskRunner = new VisionTaskRunnerFake(); - await visionTaskRunner.setOptions( - {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(visionTaskRunner.baseOptions.toObject()) .toEqual(jasmine.objectContaining({useStreamMode: true})); }); it('can clear running mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); // Clear running mode @@ -140,6 +141,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process images with video mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); expect(() => { visionTaskRunner.processImageData( @@ -148,6 +150,7 @@ describe('VisionTaskRunner', () => { }); it('cannot process video with image mode', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); // Use default for `useStreamMode` expect(() => { visionTaskRunner.processVideoData( @@ -163,6 +166,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); @@ -172,6 +176,7 @@ describe('VisionTaskRunner', () => { }); it('sends packets to graph with image processing options', async () => { + const visionTaskRunner = new VisionTaskRunnerFake(); await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); visionTaskRunner.expectImage(IMAGE); @@ -184,6 +189,7 @@ describe('VisionTaskRunner', () => { describe('validates processing options', () => { it('with left > right', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -197,6 +203,7 @@ describe('VisionTaskRunner', () => { }); it('with top > bottom', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -210,6 +217,7 @@ describe('VisionTaskRunner', () => { }); it('with out of range values', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, { regionOfInterest: { @@ -222,7 +230,24 @@ describe('VisionTaskRunner', () => { }).toThrowError('Expected RectF values to be in [0,1].'); }); + + it('without region of interest support', () => { + const visionTaskRunner = + new VisionTaskRunnerFake(/* roiAllowed= */ false); + expect(() => { + visionTaskRunner.processImageData(IMAGE, { + regionOfInterest: { + left: 0.1, + right: 0.2, + top: 0.1, + bottom: 0.2, + } + }); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('with non-90 degree rotation', () => { + const visionTaskRunner = new VisionTaskRunnerFake(); expect(() => { visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42}); }).toThrowError('Expected rotation to be a multiple of 90°.'); diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 71cac920c..b3e8ed4db 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -42,13 +42,16 @@ export abstract class VisionTaskRunner extends TaskRunner { * @param normRectStreamName the name of the input normalized rect image * stream used to provide (mandatory) rotation and (optional) * region-of-interest. + * @param roiAllowed Whether this task supports Region-Of-Interest + * pre-processing * * @hideconstructor protected */ constructor( protected override readonly graphRunner: VisionGraphRunner, private readonly imageStreamName: string, - private readonly normRectStreamName: string) { + private readonly normRectStreamName: string, + private readonly roiAllowed: boolean) { super(graphRunner); } @@ -96,6 +99,10 @@ export abstract class VisionTaskRunner extends TaskRunner { const normalizedRect = new NormalizedRect(); if (imageProcessingOptions?.regionOfInterest) { + if (!this.roiAllowed) { + throw new Error('This task doesn\'t support region-of-interest.'); + } + const roi = imageProcessingOptions.regionOfInterest; if (roi.left >= roi.right || roi.top >= roi.bottom) { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 1b7201b9a..beea263ce 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -126,7 +126,7 @@ export class GestureRecognizer extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new GestureRecognizerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index dfc252eb6..b2a2c0d72 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -250,6 +250,14 @@ describe('GestureRecognizer', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + gestureRecognizer.recognize( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { // Pass the test data to our listener gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index b51fb6a52..cd0459372 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -116,7 +116,7 @@ export class HandLandmarker extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options = new HandLandmarkerGraphOptions(); this.options.setBaseOptions(new BaseOptionsProto()); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 0abd1df27..5fd493424 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -203,6 +203,14 @@ describe('HandLandmarker', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + handLandmarker.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index cb2849cd8..071513b19 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -101,7 +101,7 @@ export class ImageClassifier extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ true); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 788646e6d..fdeb92f3f 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -104,7 +104,7 @@ export class ImageEmbedder extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ true); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 5741a3a0c..5b581432d 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -100,7 +100,7 @@ export class ObjectDetector extends VisionTaskRunner { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { super( new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, - NORM_RECT_STREAM); + NORM_RECT_STREAM, /* roiAllowed= */ false); this.options.setBaseOptions(new BaseOptionsProto()); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index ceb96acb1..9dd64c0b6 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -170,6 +170,14 @@ describe('ObjectDetector', () => { } }); + it('doesn\'t support region of interest', () => { + expect(() => { + objectDetector.detect( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + it('transforms results', async () => { const detectionProtos: Uint8Array[] = [];