diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index 5699126b9..344d4db85 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke */ export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture; -/** - * A callback that receives the computed masks from the segmentation tasks. The - * callback either receives a single element array with a category mask (as a - * `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`). - * The returned data is only valid for the duration of the callback. If - * asynchronous processing is needed, all data needs to be copied before the - * callback returns. - */ -export type SegmentationMaskCallback = - (masks: SegmentationMask[], width: number, height: number) => void; - /** * A callback that receives an `ImageData` object from a Vision task. The * lifetime of the underlying data is limited to the duration of the callback. 
diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD index a4a3f27c9..ead85d38a 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD @@ -30,7 +30,10 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "interactive_segmenter_types", - srcs = ["interactive_segmenter_options.d.ts"], + srcs = [ + "interactive_segmenter_options.d.ts", + "interactive_segmenter_result.d.ts", + ], deps = [ "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index ddcc7e592..16841bd7f 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; -import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {Color as ColorProto} from '../../../../util/color_pb'; import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; @@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner // Placeholder for 
internal dependency on trusted resource url import {InteractiveSegmenterOptions} from './interactive_segmenter_options'; +import {InteractiveSegmenterResult} from './interactive_segmenter_result'; export * from './interactive_segmenter_options'; -export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest}; +export * from './interactive_segmenter_result'; +export {SegmentationMask, RegionOfInterest}; export {ImageSource}; const IMAGE_IN_STREAM = 'image_in'; const NORM_RECT_IN_STREAM = 'norm_rect_in'; const ROI_IN_STREAM = 'roi_in'; -const IMAGE_OUT_STREAM = 'image_out'; +const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; +const CATEGORY_MASK_STREAM = 'category_mask'; const IMAGEA_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; +const DEFAULT_OUTPUT_CATEGORY_MASK = false; +const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern +/** + * A callback that receives the computed masks from the interactive segmenter. + * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. + */ +export type InteractiveSegmenterCallback = + (result: InteractiveSegmenterResult) => void; + /** * Performs interactive segmentation on images. 
* @@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH = * - batch is always 1 */ export class InteractiveSegmenter extends VisionTaskRunner { - private userCallback: SegmentationMaskCallback = () => {}; + private result: InteractiveSegmenterResult = {width: 0, height: 0}; + private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; + private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner { * @return A Promise that resolves when the settings have been applied. */ override setOptions(options: InteractiveSegmenterOptions): Promise { - if (options.outputType === 'CONFIDENCE_MASK') { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); - } else { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CATEGORY_MASK); + if ('outputCategoryMask' in options) { + this.outputCategoryMask = + options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK; + } + + if ('outputConfidenceMasks' in options) { + this.outputConfidenceMasks = + options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS; } return super.applyOptions(options); @@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { */ segment( image: ImageSource, roi: RegionOfInterest, - callback: SegmentationMaskCallback): void; + callback: InteractiveSegmenterCallback): void; /** * Performs interactive segmentation on the provided single image and invokes * the callback with the response. 
The `roi` parameter is used to represent a @@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner { segment( image: ImageSource, roi: RegionOfInterest, imageProcessingOptions: ImageProcessingOptions, - callback: SegmentationMaskCallback): void; + callback: InteractiveSegmenterCallback): void; segment( image: ImageSource, roi: RegionOfInterest, imageProcessingOptionsOrCallback: ImageProcessingOptions| - SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { + InteractiveSegmenterCallback, + callback?: InteractiveSegmenterCallback): void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; - - this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + const userCallback = + typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : callback!; + this.reset(); this.processRenderData(roi, this.getSynctheticTimestamp()); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + userCallback(this.result); + } + + private reset(): void { + this.result = {width: 0, height: 0}; } /** Updates the MediaPipe graph configuration. 
*/ @@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner { graphConfig.addInputStream(IMAGE_IN_STREAM); graphConfig.addInputStream(ROI_IN_STREAM); graphConfig.addInputStream(NORM_RECT_IN_STREAM); - graphConfig.addOutputStream(IMAGE_OUT_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -250,24 +272,48 @@ export class InteractiveSegmenter extends VisionTaskRunner { segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM); segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM); - segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM); segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); - this.graphRunner.attachImageVectorListener( - IMAGE_OUT_STREAM, (masks, timestamp) => { - if (masks.length === 0) { - this.userCallback([], 0, 0); - } else { - this.userCallback( - masks.map(m => m.data), masks[0].width, masks[0].height); - } - this.setLatestOutputTimestamp(timestamp); - }); - this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => { - this.setLatestOutputTimestamp(timestamp); - }); + if (this.outputConfidenceMasks) { + graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); + segmenterNode.addOutputStream( + 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + + this.graphRunner.attachImageVectorListener( + CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { + this.result.confidenceMasks = masks.map(m => m.data); + // An empty vector has no first mask; keep the 0x0 size set by reset(). + if (masks.length > 0) { + this.result.width = masks[0].width; + this.result.height = masks[0].height; + } + + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CONFIDENCE_MASKS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + if (this.outputCategoryMask) { + graphConfig.addOutputStream(CATEGORY_MASK_STREAM); + segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + + 
this.graphRunner.attachImageListener( + CATEGORY_MASK_STREAM, (mask, timestamp) => { + this.result.categoryMask = mask.data; + this.result.width = mask.width; + this.result.height = mask.height; + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CATEGORY_MASK_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts index beb43cd81..269403d97 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts @@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options' /** Options to configure the MediaPipe Interactive Segmenter Task */ export interface InteractiveSegmenterOptions extends TaskRunnerOptions { - /** - * The output type of segmentation results. - * - * The two supported modes are: - * - Category Mask: Gives a single output mask where each pixel represents - * the class which the pixel in the original image was - * predicted to belong to. - * - Confidence Mask: Gives a list of output masks (one for each class). For - * each mask, the pixel represents the prediction - * confidence, usually in the [0.0, 0.1] range. - * - * Defaults to `CATEGORY_MASK`. - */ - outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; + /** Whether to output confidence masks. Defaults to true. */ + outputConfidenceMasks?: boolean|undefined; + + /** Whether to output the category masks. Defaults to false. 
*/ + outputCategoryMask?: boolean|undefined; } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts new file mode 100644 index 000000000..f7e1f3a19 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The output result of InteractiveSegmenter. */ +export declare interface InteractiveSegmenterResult { + /** + * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each + * pixel represents the prediction confidence, usually in the [0, 1] range. + */ + confidenceMasks?: Float32Array[]|WebGLTexture[]; + + /** + * A category mask as a Uint8ClampedArray or WebGLTexture where each + * pixel represents the class which the pixel in the original image was + * predicted to belong to. + */ + categoryMask?: Uint8ClampedArray|WebGLTexture; + + /** The width of the masks. */ + width: number; + + /** The height of the masks. 
*/ + height: number; +} diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index d6e3a97a5..884be032d 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -18,7 +18,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; -import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; @@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - imageVectorListener: + categoryMaskListener: + ((images: WasmImage, timestamp: number) => void)|undefined; + confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; lastRoi?: RenderDataProto; @@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements this.fakeWasmModule = this.graphRunner.wasmModule as unknown as SpyWasmModule; - this.attachListenerSpies[0] = + this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('category_mask'); + this.categoryMaskListener = listener; + }); + this.attachListenerSpies[1] = spyOn(this.graphRunner, 'attachImageVectorListener') 
.and.callFake((stream, listener) => { - expect(stream).toEqual('image_out'); - this.imageVectorListener = listener; + expect(stream).toEqual('confidence_masks'); + this.confidenceMasksListener = listener; }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => { it('initializes graph', async () => { verifyGraph(interactiveSegmenter); - verifyListenersRegistered(interactiveSegmenter); + + // Verify default options + expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); }); it('reloads graph when settings are changed', async () => { - await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); - verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]); - verifyListenersRegistered(interactiveSegmenter); + await interactiveSegmenter.setOptions( + {outputConfidenceMasks: true, outputCategoryMask: false}); + expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]); - verifyListenersRegistered(interactiveSegmenter); + await interactiveSegmenter.setOptions( + {outputConfidenceMasks: false, outputCategoryMask: true}); + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); }); it('can use custom models', async () => { @@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => { ]); }); - - describe('setOptions()', () => { - const fieldPath = ['segmenterOptions', 'outputType']; - - it(`can set outputType`, async () => { - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [fieldPath, 2]); - }); - - it(`can clear outputType`, async 
() => { - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [fieldPath, 2]); - await interactiveSegmenter.setOptions({outputType: undefined}); - verifyGraph(interactiveSegmenter, [fieldPath, 1]); - }); - }); - it('doesn\'t support region of interest', () => { expect(() => { interactiveSegmenter.segment( @@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => { interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); }); - it('supports category masks', (done) => { + it('supports category mask', async () => { const mask = new Uint8ClampedArray([1, 2, 3, 4]); + await interactiveSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: false}); + // Pass the test data to our listener interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(interactiveSegmenter); - interactiveSegmenter.imageVectorListener!( - [ - {data: mask, width: 2, height: 2}, - ], - /* timestamp= */ 1337); + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); + interactiveSegmenter.categoryMaskListener! 
+ ({data: mask, width: 2, height: 2}, + /* timestamp= */ 1337); }); // Invoke the image segmenter - interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, (masks, width, height) => { - expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) - .toHaveBeenCalled(); - expect(masks).toHaveSize(1); - expect(masks[0]).toEqual(mask); - expect(width).toEqual(2); - expect(height).toEqual(2); - done(); - }); + return new Promise(resolve => { + interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(result.categoryMask).toEqual(mask); + expect(result.confidenceMasks).not.toBeDefined(); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); + }); }); it('supports confidence masks', async () => { const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + await interactiveSegmenter.setOptions( + {outputCategoryMask: false, outputConfidenceMasks: true}); // Pass the test data to our listener interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(interactiveSegmenter); - interactiveSegmenter.imageVectorListener!( + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); + interactiveSegmenter.confidenceMasksListener!( [ {data: mask1, width: 2, height: 2}, {data: mask2, width: 2, height: 2}, ], 1337); }); + return new Promise(resolve => { + // Invoke the image segmenter + interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(result.categoryMask).not.toBeDefined(); + expect(result.confidenceMasks).toEqual([mask1, mask2]); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); + }); + }); + + it('supports combined 
category and confidence masks', async () => { + const categoryMask = new Uint8ClampedArray([1, 0]); + const confidenceMask1 = new Float32Array([0.0, 1.0]); + const confidenceMask2 = new Float32Array([1.0, 0.0]); + + await interactiveSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: true}); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); + interactiveSegmenter.categoryMaskListener! + ({data: categoryMask, width: 1, height: 1}, 1337); + interactiveSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask1, width: 1, height: 1}, + {data: confidenceMask2, width: 1, height: 1}, + ], + 1337); + }); return new Promise(resolve => { // Invoke the image segmenter interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, (masks, width, height) => { + {} as HTMLImageElement, ROI, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); - expect(masks).toHaveSize(2); - expect(masks[0]).toEqual(mask1); - expect(masks[1]).toEqual(mask2); - expect(width).toEqual(2); - expect(height).toEqual(2); + expect(result.categoryMask).toEqual(categoryMask); + expect(result.confidenceMasks).toEqual([ + confidenceMask1, confidenceMask2 + ]); + expect(result.width).toEqual(1); + expect(result.height).toEqual(1); resolve(); }); });