From 18d893c6979cbaf3b90e2ec6bf4d03f93bc3197b Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 4 May 2023 20:39:54 -0700 Subject: [PATCH] Add scribble support to InteractiveSegmenter Web API PiperOrigin-RevId: 529594131 --- mediapipe/tasks/web/vision/core/types.d.ts | 5 ++- .../image_segmenter/image_segmenter_test.ts | 2 +- .../interactive_segmenter.ts | 27 ++++++++++--- .../interactive_segmenter_test.ts | 40 +++++++++++++++---- .../pose_landmarker/pose_landmarker_test.ts | 2 +- 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index c985a9f36..64d67bc30 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -19,7 +19,10 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke /** A Region-Of-Interest (ROI) to represent a region within an image. */ export declare interface RegionOfInterest { /** The ROI in keypoint format. */ - keypoint: NormalizedKeypoint; + keypoint?: NormalizedKeypoint; + + /** The ROI as scribbles over the object that the user wants to segment. */ + scribble?: NormalizedKeypoint[]; } /** A connection between two landmarks. */ diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts index c1ccd7997..8c8767ec7 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -263,7 +263,7 @@ describe('ImageSegmenter', () => { }); }); - it('invokes listener once masks are avaiblae', async () => { + it('invokes listener once masks are available', async () => { const categoryMask = new Uint8ClampedArray([1]); const confidenceMask = new Float32Array([0.0]); let listenerCalled = false; diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index 67d6ec3f6..60ec9e1c5 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -338,16 +338,31 @@ export class InteractiveSegmenter extends VisionTaskRunner { const renderData = new RenderDataProto(); const renderAnnotation = new RenderAnnotationProto(); - const color = new ColorProto(); color.setR(255); renderAnnotation.setColor(color); - const point = new RenderAnnotationProto.Point(); - point.setNormalized(true); - point.setX(roi.keypoint.x); - point.setY(roi.keypoint.y); - renderAnnotation.setPoint(point); + if (roi.keypoint && roi.scribble) { + throw new Error('Cannot provide both keypoint and scribble.'); + } else if (roi.keypoint) { + const point = new RenderAnnotationProto.Point(); + point.setNormalized(true); + point.setX(roi.keypoint.x); + point.setY(roi.keypoint.y); + renderAnnotation.setPoint(point); + } else if (roi.scribble) { + const scribble = new RenderAnnotationProto.Scribble(); + for (const coord of roi.scribble) { + const point = new RenderAnnotationProto.Point(); + point.setNormalized(true); + point.setX(coord.x); + point.setY(coord.y); + scribble.addPoint(point); + } + renderAnnotation.setScribble(scribble); + } else { + throw new Error('Must provide either a keypoint or a scribble.'); + } renderData.addRenderAnnotations(renderAnnotation); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index 84ecde00b..a361af5a1 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -26,10 +26,14 @@ import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter'; -const ROI: RegionOfInterest = { +const KEYPOINT: RegionOfInterest = { keypoint: {x: 0.1, y: 0.2} }; +const SCRIBBLE: RegionOfInterest = { + scribble: [{x: 0.1, y: 0.2}, {x: 0.3, y: 0.4}] +}; + class InteractiveSegmenterFake extends InteractiveSegmenter implements MediapipeTasksFake { calculatorName = @@ -134,22 +138,42 @@ describe('InteractiveSegmenter', () => { it('doesn\'t support region of interest', () => { expect(() => { interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, + {} as HTMLImageElement, KEYPOINT, {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); }).toThrowError('This task doesn\'t support region-of-interest.'); }); - it('sends region-of-interest', (done) => { + it('sends region-of-interest with keypoint', (done) => { interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { expect(interactiveSegmenter.lastRoi).toBeDefined(); expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0]) .toEqual(jasmine.objectContaining({ color: {r: 255, b: undefined, g: undefined}, + point: {x: 0.1, y: 0.2, normalized: true}, })); done(); }); - interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); + interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {}); + }); + + it('sends region-of-interest with scribble', (done) => { + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(interactiveSegmenter.lastRoi).toBeDefined(); + expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0]) + .toEqual(jasmine.objectContaining({ + color: {r: 255, b: undefined, g: undefined}, + scribble: { + pointList: [ + {x: 0.1, y: 0.2, normalized: true}, + {x: 0.3, y: 0.4, normalized: true} + ] + }, + })); + done(); + }); + + interactiveSegmenter.segment({} as HTMLImageElement, SCRIBBLE, () => {}); }); it('supports category mask', async () => { @@ -168,7 +192,7 @@ describe('InteractiveSegmenter', () => { // Invoke the image segmenter return new Promise(resolve => { - interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); expect(result.categoryMask).toBeInstanceOf(MPImage); @@ -199,7 +223,7 @@ describe('InteractiveSegmenter', () => { }); return new Promise(resolve => { // Invoke the image segmenter - interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); expect(result.categoryMask).not.toBeDefined(); @@ -239,7 +263,7 @@ describe('InteractiveSegmenter', () => { return new Promise(resolve => { // Invoke the image segmenter interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, result => { + {} as HTMLImageElement, KEYPOINT, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); expect(result.categoryMask).toBeInstanceOf(MPImage); @@ -276,7 +300,7 @@ describe('InteractiveSegmenter', () => { }); return new Promise(resolve => { - interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => { + interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => { listenerCalled = true; resolve(); }); diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts index c97b0d7b0..907cb16b3 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker_test.ts @@ -298,7 +298,7 @@ describe('PoseLandmarker', () => { }); }); - it('invokes listener once masks are avaiblae', (done) => { + it('invokes listener once masks are available', (done) => { const landmarksProto = [createLandmarks().serializeBinary()]; const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; const masks = [