Reject RegionOfInterest in unsupported tasks

PiperOrigin-RevId: 501872455
This commit is contained in:
Sebastian Schmidt 2023-01-13 09:52:07 -08:00 committed by Copybara-Service
parent 69757d7924
commit f997c0ab1a
10 changed files with 70 additions and 14 deletions

View File

@@ -41,14 +41,14 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
expectedImageSource?: ImageSource;
expectedNormalizedRect?: NormalizedRect;
constructor() {
constructor(roiAllowed = true) {
super(
jasmine.createSpyObj<VisionGraphRunner>([
'addProtoToStream', 'addGpuBufferAsImageToStream',
'setAutoRenderToScreen', 'registerModelResourcesGraphService',
'finishProcessing'
]),
IMAGE_STREAM, NORM_RECT_STREAM);
IMAGE_STREAM, NORM_RECT_STREAM, roiAllowed);
this.fakeGraphRunner =
this.graphRunner as unknown as jasmine.SpyObj<VisionGraphRunner>;
@@ -71,6 +71,9 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
expect(timestamp).toBe(TIMESTAMP);
expect(imageSource).toBe(this.expectedImageSource!);
});
// SetOptions with a modelAssetBuffer runs synchronously
void this.setOptions({baseOptions: {modelAssetBuffer: new Uint8Array([])}});
}
protected override refreshGraph(): void {}
@@ -108,28 +111,26 @@ class VisionTaskRunnerFake extends VisionTaskRunner {
}
describe('VisionTaskRunner', () => {
let visionTaskRunner: VisionTaskRunnerFake;
beforeEach(async () => {
beforeEach(() => {
addJasmineCustomFloatEqualityTester();
visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('can enable image mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'IMAGE'});
expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: false}));
});
it('can enable video mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
expect(visionTaskRunner.baseOptions.toObject())
.toEqual(jasmine.objectContaining({useStreamMode: true}));
});
it('can clear running mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
// Clear running mode
@@ -140,6 +141,7 @@ describe('VisionTaskRunner', () => {
});
it('cannot process images with video mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
expect(() => {
visionTaskRunner.processImageData(
@@ -148,6 +150,7 @@ describe('VisionTaskRunner', () => {
});
it('cannot process video with image mode', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
// Use default for `useStreamMode`
expect(() => {
visionTaskRunner.processVideoData(
@@ -163,6 +166,7 @@ describe('VisionTaskRunner', () => {
});
it('sends packets to graph', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
visionTaskRunner.expectImage(IMAGE);
@@ -172,6 +176,7 @@ describe('VisionTaskRunner', () => {
});
it('sends packets to graph with image processing options', async () => {
const visionTaskRunner = new VisionTaskRunnerFake();
await visionTaskRunner.setOptions({runningMode: 'VIDEO'});
visionTaskRunner.expectImage(IMAGE);
@@ -184,6 +189,7 @@ describe('VisionTaskRunner', () => {
describe('validates processing options', () => {
it('with left > right', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => {
visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: {
@@ -197,6 +203,7 @@ describe('VisionTaskRunner', () => {
});
it('with top > bottom', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => {
visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: {
@@ -210,6 +217,7 @@ describe('VisionTaskRunner', () => {
});
it('with out of range values', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => {
visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: {
@@ -222,7 +230,24 @@ describe('VisionTaskRunner', () => {
}).toThrowError('Expected RectF values to be in [0,1].');
});
it('without region of interest support', () => {
const visionTaskRunner =
new VisionTaskRunnerFake(/* roiAllowed= */ false);
expect(() => {
visionTaskRunner.processImageData(IMAGE, {
regionOfInterest: {
left: 0.1,
right: 0.2,
top: 0.1,
bottom: 0.2,
}
});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('with non-90 degree rotation', () => {
const visionTaskRunner = new VisionTaskRunnerFake();
expect(() => {
visionTaskRunner.processImageData(IMAGE, {rotationDegrees: 42});
}).toThrowError('Expected rotation to be a multiple of 90°.');

View File

@@ -42,13 +42,16 @@ export abstract class VisionTaskRunner extends TaskRunner {
* @param normRectStreamName the name of the input normalized rect image
* stream used to provide (mandatory) rotation and (optional)
* region-of-interest.
* @param roiAllowed Whether this task supports Region-Of-Interest
* pre-processing
*
* @hideconstructor protected
*/
constructor(
protected override readonly graphRunner: VisionGraphRunner,
private readonly imageStreamName: string,
private readonly normRectStreamName: string) {
private readonly normRectStreamName: string,
private readonly roiAllowed: boolean) {
super(graphRunner);
}
@@ -96,6 +99,10 @@ export abstract class VisionTaskRunner extends TaskRunner {
const normalizedRect = new NormalizedRect();
if (imageProcessingOptions?.regionOfInterest) {
if (!this.roiAllowed) {
throw new Error('This task doesn\'t support region-of-interest.');
}
const roi = imageProcessingOptions.regionOfInterest;
if (roi.left >= roi.right || roi.top >= roi.bottom) {

View File

@@ -126,7 +126,7 @@ export class GestureRecognizer extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM);
NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options = new GestureRecognizerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());

View File

@@ -250,6 +250,14 @@ describe('GestureRecognizer', () => {
}
});
it('doesn\'t support region of interest', () => {
expect(() => {
gestureRecognizer.recognize(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('transforms results', async () => {
// Pass the test data to our listener
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {

View File

@@ -116,7 +116,7 @@ export class HandLandmarker extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM);
NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options = new HandLandmarkerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());

View File

@@ -203,6 +203,14 @@ describe('HandLandmarker', () => {
}
});
it('doesn\'t support region of interest', () => {
expect(() => {
handLandmarker.detect(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('transforms results', async () => {
// Pass the test data to our listener
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {

View File

@@ -101,7 +101,7 @@ export class ImageClassifier extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM);
NORM_RECT_STREAM, /* roiAllowed= */ true);
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@@ -104,7 +104,7 @@ export class ImageEmbedder extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM);
NORM_RECT_STREAM, /* roiAllowed= */ true);
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@@ -100,7 +100,7 @@ export class ObjectDetector extends VisionTaskRunner {
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM);
NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@@ -170,6 +170,14 @@ describe('ObjectDetector', () => {
}
});
it('doesn\'t support region of interest', () => {
expect(() => {
objectDetector.detect(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('transforms results', async () => {
const detectionProtos: Uint8Array[] = [];