Add scribble support to InteractiveSegmenter Web API

PiperOrigin-RevId: 529594131
This commit is contained in:
Sebastian Schmidt 2023-05-04 20:39:54 -07:00 committed by Copybara-Service
parent 61cfe2ca9b
commit 18d893c697
5 changed files with 59 additions and 17 deletions

View File

@ -19,7 +19,10 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
/**
 * A Region-Of-Interest (ROI) to represent a region within an image.
 *
 * The ROI is given either as a single keypoint or as a scribble. The
 * InteractiveSegmenter throws an error if both `keypoint` and `scribble` are
 * provided, or if neither is provided.
 */
export declare interface RegionOfInterest {
/** The ROI in keypoint format. Cannot be combined with `scribble`. */
keypoint: NormalizedKeypoint;
keypoint?: NormalizedKeypoint;
/**
 * The ROI as scribbles over the object that the user wants to segment.
 * Cannot be combined with `keypoint`.
 */
scribble?: NormalizedKeypoint[];
}
/** A connection between two landmarks. */

View File

@ -263,7 +263,7 @@ describe('ImageSegmenter', () => {
});
});
it('invokes listener once masks are avaiblae', async () => {
it('invokes listener once masks are available', async () => {
const categoryMask = new Uint8ClampedArray([1]);
const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false;

View File

@ -338,16 +338,31 @@ export class InteractiveSegmenter extends VisionTaskRunner {
const renderData = new RenderDataProto();
const renderAnnotation = new RenderAnnotationProto();
const color = new ColorProto();
color.setR(255);
renderAnnotation.setColor(color);
if (roi.keypoint && roi.scribble) {
throw new Error('Cannot provide both keypoint and scribble.');
} else if (roi.keypoint) {
const point = new RenderAnnotationProto.Point();
point.setNormalized(true);
point.setX(roi.keypoint.x);
point.setY(roi.keypoint.y);
renderAnnotation.setPoint(point);
} else if (roi.scribble) {
const scribble = new RenderAnnotationProto.Scribble();
for (const coord of roi.scribble) {
const point = new RenderAnnotationProto.Point();
point.setNormalized(true);
point.setX(coord.x);
point.setY(coord.y);
scribble.addPoint(point);
}
renderAnnotation.setScribble(scribble);
} else {
throw new Error('Must provide either a keypoint or a scribble.');
}
renderData.addRenderAnnotations(renderAnnotation);

View File

@ -26,10 +26,14 @@ import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';
const ROI: RegionOfInterest = {
const KEYPOINT: RegionOfInterest = {
keypoint: {x: 0.1, y: 0.2}
};
/** Test fixture: a region of interest given as a scribble of two points. */
const SCRIBBLE: RegionOfInterest = {
scribble: [{x: 0.1, y: 0.2}, {x: 0.3, y: 0.4}]
};
class InteractiveSegmenterFake extends InteractiveSegmenter implements
MediapipeTasksFake {
calculatorName =
@ -134,22 +138,42 @@ describe('InteractiveSegmenter', () => {
it('doesn\'t support region of interest', () => {
expect(() => {
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI,
{} as HTMLImageElement, KEYPOINT,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('sends region-of-interest', (done) => {
it('sends region-of-interest with keypoint', (done) => {
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(interactiveSegmenter.lastRoi).toBeDefined();
expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0])
.toEqual(jasmine.objectContaining({
color: {r: 255, b: undefined, g: undefined},
point: {x: 0.1, y: 0.2, normalized: true},
}));
done();
});
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {});
});
// Checks that a scribble ROI reaches the graph as a render annotation whose
// scribble holds one normalized point per input coordinate.
it('sends region-of-interest with scribble', (done) => {
// The assertions run inside the fake's `_waitUntilIdle`, which is replaced
// here so that the recorded ROI can be inspected when it is invoked.
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(interactiveSegmenter.lastRoi).toBeDefined();
// Only the first render annotation is inspected; `objectContaining` ignores
// any additional fields on it.
expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0])
.toEqual(jasmine.objectContaining({
color: {r: 255, b: undefined, g: undefined},
scribble: {
pointList: [
{x: 0.1, y: 0.2, normalized: true},
{x: 0.3, y: 0.4, normalized: true}
]
},
}));
// Signals completion of the asynchronous spec.
done();
});
interactiveSegmenter.segment({} as HTMLImageElement, SCRIBBLE, () => {});
});
it('supports category mask', async () => {
@ -168,7 +192,7 @@ describe('InteractiveSegmenter', () => {
// Invoke the image segmenter
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage);
@ -199,7 +223,7 @@ describe('InteractiveSegmenter', () => {
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined();
@ -239,7 +263,7 @@ describe('InteractiveSegmenter', () => {
return new Promise<void>(resolve => {
// Invoke the image segmenter
interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, result => {
{} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage);
@ -276,7 +300,7 @@ describe('InteractiveSegmenter', () => {
});
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
listenerCalled = true;
resolve();
});

View File

@ -298,7 +298,7 @@ describe('PoseLandmarker', () => {
});
});
it('invokes listener once masks are avaiblae', (done) => {
it('invokes listener once masks are available', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const masks = [