Add scribble support to InteractiveSegmenter Web API

PiperOrigin-RevId: 529594131
This commit is contained in:
Sebastian Schmidt 2023-05-04 20:39:54 -07:00 committed by Copybara-Service
parent 61cfe2ca9b
commit 18d893c697
5 changed files with 59 additions and 17 deletions

View File

@ -19,7 +19,10 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
/** A Region-Of-Interest (ROI) to represent a region within an image. */ /** A Region-Of-Interest (ROI) to represent a region within an image. */
export declare interface RegionOfInterest { export declare interface RegionOfInterest {
/** The ROI in keypoint format. */ /** The ROI in keypoint format. */
keypoint: NormalizedKeypoint; keypoint?: NormalizedKeypoint;
/** The ROI as scribbles over the object that the user wants to segment. */
scribble?: NormalizedKeypoint[];
} }
/** A connection between two landmarks. */ /** A connection between two landmarks. */

View File

@ -263,7 +263,7 @@ describe('ImageSegmenter', () => {
}); });
}); });
it('invokes listener once masks are avaiblae', async () => { it('invokes listener once masks are available', async () => {
const categoryMask = new Uint8ClampedArray([1]); const categoryMask = new Uint8ClampedArray([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false; let listenerCalled = false;

View File

@ -338,16 +338,31 @@ export class InteractiveSegmenter extends VisionTaskRunner {
const renderData = new RenderDataProto(); const renderData = new RenderDataProto();
const renderAnnotation = new RenderAnnotationProto(); const renderAnnotation = new RenderAnnotationProto();
const color = new ColorProto(); const color = new ColorProto();
color.setR(255); color.setR(255);
renderAnnotation.setColor(color); renderAnnotation.setColor(color);
const point = new RenderAnnotationProto.Point(); if (roi.keypoint && roi.scribble) {
point.setNormalized(true); throw new Error('Cannot provide both keypoint and scribble.');
point.setX(roi.keypoint.x); } else if (roi.keypoint) {
point.setY(roi.keypoint.y); const point = new RenderAnnotationProto.Point();
renderAnnotation.setPoint(point); point.setNormalized(true);
point.setX(roi.keypoint.x);
point.setY(roi.keypoint.y);
renderAnnotation.setPoint(point);
} else if (roi.scribble) {
const scribble = new RenderAnnotationProto.Scribble();
for (const coord of roi.scribble) {
const point = new RenderAnnotationProto.Point();
point.setNormalized(true);
point.setX(coord.x);
point.setY(coord.y);
scribble.addPoint(point);
}
renderAnnotation.setScribble(scribble);
} else {
throw new Error('Must provide either a keypoint or a scribble.');
}
renderData.addRenderAnnotations(renderAnnotation); renderData.addRenderAnnotations(renderAnnotation);

View File

@ -26,10 +26,14 @@ import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter'; import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';
const ROI: RegionOfInterest = { const KEYPOINT: RegionOfInterest = {
keypoint: {x: 0.1, y: 0.2} keypoint: {x: 0.1, y: 0.2}
}; };
/** Region-of-interest fixture that supplies a two-point scribble and no keypoint. */
const SCRIBBLE: RegionOfInterest = {
scribble: [{x: 0.1, y: 0.2}, {x: 0.3, y: 0.4}]
};
class InteractiveSegmenterFake extends InteractiveSegmenter implements class InteractiveSegmenterFake extends InteractiveSegmenter implements
MediapipeTasksFake { MediapipeTasksFake {
calculatorName = calculatorName =
@ -134,22 +138,42 @@ describe('InteractiveSegmenter', () => {
it('doesn\'t support region of interest', () => { it('doesn\'t support region of interest', () => {
expect(() => { expect(() => {
interactiveSegmenter.segment( interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, {} as HTMLImageElement, KEYPOINT,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
}).toThrowError('This task doesn\'t support region-of-interest.'); }).toThrowError('This task doesn\'t support region-of-interest.');
}); });
it('sends region-of-interest', (done) => { it('sends region-of-interest with keypoint', (done) => {
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(interactiveSegmenter.lastRoi).toBeDefined(); expect(interactiveSegmenter.lastRoi).toBeDefined();
expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0]) expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0])
.toEqual(jasmine.objectContaining({ .toEqual(jasmine.objectContaining({
color: {r: 255, b: undefined, g: undefined}, color: {r: 255, b: undefined, g: undefined},
point: {x: 0.1, y: 0.2, normalized: true},
})); }));
done(); done();
}); });
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {});
});
it('sends region-of-interest with scribble', (done) => {
  // Annotation expected for the scribble ROI: a red color plus one
  // normalized point per scribble coordinate, in the order supplied.
  const expectedAnnotation = {
    color: {r: 255, b: undefined, g: undefined},
    scribble: {
      pointList: [
        {x: 0.1, y: 0.2, normalized: true},
        {x: 0.3, y: 0.4, normalized: true}
      ]
    },
  };

  // The check runs from inside the fake `_waitUntilIdle`, so it has to be
  // installed before `segment()` is invoked below.
  interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
    const sentRoi = interactiveSegmenter.lastRoi;
    expect(sentRoi).toBeDefined();

    const firstAnnotation = sentRoi!.toObject().renderAnnotationsList![0];
    expect(firstAnnotation)
        .toEqual(jasmine.objectContaining(expectedAnnotation));
    done();
  });

  interactiveSegmenter.segment({} as HTMLImageElement, SCRIBBLE, () => {});
});
it('supports category mask', async () => { it('supports category mask', async () => {
@ -168,7 +192,7 @@ describe('InteractiveSegmenter', () => {
// Invoke the image segmenter // Invoke the image segmenter
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage); expect(result.categoryMask).toBeInstanceOf(MPImage);
@ -199,7 +223,7 @@ describe('InteractiveSegmenter', () => {
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
// Invoke the image segmenter // Invoke the image segmenter
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined(); expect(result.categoryMask).not.toBeDefined();
@ -239,7 +263,7 @@ describe('InteractiveSegmenter', () => {
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
// Invoke the image segmenter // Invoke the image segmenter
interactiveSegmenter.segment( interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, result => { {} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage); expect(result.categoryMask).toBeInstanceOf(MPImage);
@ -276,7 +300,7 @@ describe('InteractiveSegmenter', () => {
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => { interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
listenerCalled = true; listenerCalled = true;
resolve(); resolve();
}); });

View File

@ -298,7 +298,7 @@ describe('PoseLandmarker', () => {
}); });
}); });
it('invokes listener once masks are avaiblae', (done) => { it('invokes listener once masks are available', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()]; const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const masks = [ const masks = [