Reject RegionOfInterest in not supported tasks

PiperOrigin-RevId: 501872455
This commit is contained in:
Sebastian Schmidt 2023-01-13 09:52:07 -08:00 committed by Copybara-Service
parent 69757d7924
commit f997c0ab1a
10 changed files with 70 additions and 14 deletions

View File

@ -41,14 +41,14 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
expectedImageSource?: ImageSource; expectedImageSource?: ImageSource;
expectedNormalizedRect?: NormalizedRect; expectedNormalizedRect?: NormalizedRect;
constructor() { constructor(roiAllowed = true) {
super( super(
jasmine.createSpyObj<VisionGraphRunner>([ jasmine.createSpyObj<VisionGraphRunner>([
'addProtoToStream', 'addGpuBufferAsImageToStream', 'addProtoToStream', 'addGpuBufferAsImageToStream',
'setAutoRenderToScreen', 'registerModelResourcesGraphService', 'setAutoRenderToScreen', 'registerModelResourcesGraphService',
'finishProcessing' 'finishProcessing'
]), ]),
IMAGE_STREAM, NORM_RECT_STREAM); IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed);
this.fakeGraphRunner = this.fakeGraphRunner =
this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>; this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>;
@ -71,6 +71,9 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
expect(timestamp).toBe(TIMESTAMP); expect(timestamp).toBe(TIMESTAMP);
expect(imageSource).toBe(this.expectedImageSource!); expect(imageSource).toBe(this.expectedImageSource!);
}); });
// SetOptions with a modelAssetBuffer runs synchonously
void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}});
} }
protected override refreshGraph(): void {} protected override refreshGraph(): void {}
@ -108,28 +111,26 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
} }
describe('VisionTaskRunner', () => { describe('VisionTaskRunner', () => {
let visionTaskRunner: VisionTaskRunnerFake; beforeEach(() => {
beforeEach(async () => {
addJasmineCustomFloatEqualityTester(); addJasmineCustomFloatEqualityTester();
visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}); });
it('can enable image mode', async () => { it('can enable image mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'IMAGE'}); await visionTaskRunner.setOptions({runningMode: 'IMAGE'});
expect(visionTaskRunner.baseOptions.toObject()) expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false})); .toEqual(jasmine.objectContaining({useStreamMode: false}));
}); });
it('can enable video mode', async () => { it('can enable video mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
expect(visionTaskRunner.baseOptions.toObject()) expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: true})); .toEqual(jasmine.objectContaining({useStreamMode: true}));
}); });
it('can clear running mode', async () => { it('can clear running mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
// Clear running mode // Clear running mode
@ -140,6 +141,7 @@ describe('VisionTaskRunner', () => {
}); });
it('cannot process images with video mode', async () => { it('cannot process images with video mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
expect(() => { expect(() => {
visionTaskRunner.processImageData( visionTaskRunner.processImageData(
@ -148,6 +150,7 @@ describe('VisionTaskRunner', () => {
}); });
it('cannot process video with image mode', async () => { it('cannot process video with image mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
// Use default for `useStreamMode` // Use default for `useStreamMode`
expect(() => { expect(() => {
visionTaskRunner.processVideoData( visionTaskRunner.processVideoData(
@ -163,6 +166,7 @@ describe('VisionTaskRunner', () => {
}); });
it('sends packets to graph', async () => { it('sends packets to graph', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
visionTaskRunner.expectImage(IMAGE); visionTaskRunner.expectImage(IMAGE);
@ -172,6 +176,7 @@ describe('VisionTaskRunner', () => {
}); });
it('sends packets to graph with image processing options', async () => { it('sends packets to graph with image processing options', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'}); await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
visionTaskRunner.expectImage(IMAGE); visionTaskRunner.expectImage(IMAGE);
@ -184,6 +189,7 @@ describe('VisionTaskRunner', () => {
describe('validates processing options', () => { describe('validates processing options', () => {
it('with left > right', () => { it('with left > right', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => { expect(() => {
visionTaskRunner.processImageData(IMAGE, { visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: { regionOfInterest: {
@ -197,6 +203,7 @@ describe('VisionTaskRunner', () => {
}); });
it('with top > bottom', () => { it('with top > bottom', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => { expect(() => {
visionTaskRunner.processImageData(IMAGE, { visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: { regionOfInterest: {
@ -210,6 +217,7 @@ describe('VisionTaskRunner', () => {
}); });
it('with out of range values', () => { it('with out of range values', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => { expect(() => {
visionTaskRunner.processImageData(IMAGE, { visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: { regionOfInterest: {
@ -222,7 +230,24 @@ describe('VisionTaskRunner', () => {
}).toThrowError('Expected RectF values to be in [0,1].'); }).toThrowError('Expected RectF values to be in [0,1].');
}); });
it('without region of interest support', () => {
const visionTaskRunner =
new VisionTaskRunnerFake(/* roiAllowed= */ false);
expect(() => {
visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: {
left: 0.1,
right: 0.2,
top: 0.1,
bottom: 0.2,
}
});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('with non-90 degree rotation', () => { it('with non-90 degree rotation', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => { expect(() => {
visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42}); visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42});
}).toThrowError('Expected rotation to be a multiple of 90°.'); }).toThrowError('Expected rotation to be a multiple of 90°.');

View File

@ -42,13 +42,16 @@ export abstract class VisionTaskRunner extends TaskRunner {
* @param normRectStreamName the name of the input normalized rect image * @param normRectStreamName the name of the input normalized rect image
* stream used to provide (mandatory) rotation and (optional) * stream used to provide (mandatory) rotation and (optional)
* region-of-interest. * region-of-interest.
* @param roiAllowed Whether this task supports Region-Of-Interest
* pre-processing
* *
* @hideconstructor protected * @hideconstructor protected
*/ */
constructor( constructor(
protected override readonly graphRunner: VisionGraphRunner, protected override readonly graphRunner: VisionGraphRunner,
private readonly imageStreamName: string, private readonly imageStreamName: string,
private readonly normRectStreamName: string) { private readonly normRectStreamName: string,
private readonly roiAllowed: boolean) {
super(graphRunner); super(graphRunner);
} }
@ -96,6 +99,10 @@ export abstract class VisionTaskRunner extends TaskRunner {
const normalizedRect = new NormalizedRect(); const normalizedRect = new NormalizedRect();
if (imageProcessingOptions?.regionOfInterest) { if (imageProcessingOptions?.regionOfInterest) {
if (!this.roiAllowed) {
throw new Error('This task doesn\'t support region-of-interest.');
}
const roi = imageProcessingOptions.regionOfInterest; const roi = imageProcessingOptions.regionOfInterest;
if (roi.left >= roi.right || roi.top >= roi.bottom) { if (roi.left >= roi.right || roi.top >= roi.bottom) {

View File

@ -126,7 +126,7 @@ export class GestureRecognizer extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super( super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM); NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options = new GestureRecognizerGraphOptions(); this.options = new GestureRecognizerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto()); this.options.setBaseOptions(new BaseOptionsProto());

View File

@ -250,6 +250,14 @@ describe('GestureRecognizer', () => {
} }
}); });
it('doesn\'t support region of interest', () => {
expect(() => {
gestureRecognizer.recognize(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('transforms results', async () => { it('transforms results', async () => {
// Pass the test data to our listener // Pass the test data to our listener
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {

View File

@ -116,7 +116,7 @@ export class HandLandmarker extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super( super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM); NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options = new HandLandmarkerGraphOptions(); this.options = new HandLandmarkerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto()); this.options.setBaseOptions(new BaseOptionsProto());

View File

@ -203,6 +203,14 @@ describe('HandLandmarker', () => {
} }
}); });
it('doesn\'t support region of interest', () => {
expect(() => {
handLandmarker.detect(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('transforms results', async () => { it('transforms results', async () => {
// Pass the test data to our listener // Pass the test data to our listener
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {

View File

@ -101,7 +101,7 @@ export class ImageClassifier extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super( super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM); NORM_RECT_STREAM, /* roiAllowed= */ true);
this.options.setBaseOptions(new BaseOptionsProto()); this.options.setBaseOptions(new BaseOptionsProto());
} }

View File

@ -104,7 +104,7 @@ export class ImageEmbedder extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super( super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM); NORM_RECT_STREAM, /* roiAllowed= */ true);
this.options.setBaseOptions(new BaseOptionsProto()); this.options.setBaseOptions(new BaseOptionsProto());
} }

View File

@ -100,7 +100,7 @@ export class ObjectDetector extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super( super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM); NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options.setBaseOptions(new BaseOptionsProto()); this.options.setBaseOptions(new BaseOptionsProto());
} }

View File

@ -170,6 +170,14 @@ describe('ObjectDetector', () => {
} }
}); });
it('doesn\'t support region of interest', () => {
expect(() => {
objectDetector.detect(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('transforms results', async () => { it('transforms results', async () => {
const detectionProtos: Uint8Array[] = []; const detectionProtos: Uint8Array[] = [];