diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts
index 879e23010..903d789f5 100644
--- a/mediapipe/tasks/web/vision/core/render_utils.ts
+++ b/mediapipe/tasks/web/vision/core/render_utils.ts
@@ -59,13 +59,12 @@ export function drawCategoryMask(
   const isFloatArray = image instanceof Float32Array;
   for (let i = 0; i < image.length; i++) {
     const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
-    const color = COLOR_MAP[colorIndex];
+    let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
 
-    // When we're given a confidence mask by accident, we just log and return.
-    // TODO: We should fix this.
     if (!color) {
+      // TODO: We should fix this.
       console.warn('No color for ', colorIndex);
-      return;
+      color = [0, 0, 0, 0];  // Fall back to a transparent pixel; re-reading the same missing entry would leave `color` undefined.
     }
 
     rgbaArray[4 * i] = color[0];
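Editor's note on the render_utils hunk above: the modulo keeps large category indices inside the palette, but a negative or NaN index (for example from a confidence mask handed in by accident) still misses, so the branch needs a real fallback rather than repeating the failed lookup. A minimal, self-contained sketch of that lookup behavior — the palette and fallback here are illustrative assumptions, not the actual COLOR_MAP from render_utils.ts:

```ts
// Illustrative palette; the real COLOR_MAP in render_utils.ts differs.
const COLOR_MAP: Array<[number, number, number, number]> = [
  [0, 0, 0, 255], [255, 0, 0, 255], [0, 255, 0, 255], [0, 0, 255, 255],
];

function categoryColor(
    value: number, isFloat: boolean): [number, number, number, number] {
  // Float masks are scaled to 0..255 first, matching the patched loop.
  const colorIndex = isFloat ? Math.round(value * 255) : value;
  // Wrap in-range indices; fall back to transparent for anything that still misses.
  return COLOR_MAP[colorIndex % COLOR_MAP.length] ?? [0, 0, 0, 0];
}

console.log(categoryColor(5, /* isFloat= */ false));  // wraps around to COLOR_MAP[1]
```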
diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD
index a4b9008dd..3db15641f 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/BUILD
+++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD
@@ -29,7 +29,10 @@ mediapipe_ts_library(
 
 mediapipe_ts_declaration(
     name = "image_segmenter_types",
-    srcs = ["image_segmenter_options.d.ts"],
+    srcs = [
+        "image_segmenter_options.d.ts",
+        "image_segmenter_result.d.ts",
+    ],
     deps = [
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
index 3690fd855..740047762 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
@@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
-import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
+import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {LabelMapItem} from '../../../../util/label_map_pb';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
 
 import {ImageSegmenterOptions} from './image_segmenter_options';
+import {ImageSegmenterResult} from './image_segmenter_result';
 
 export * from './image_segmenter_options';
-export {SegmentationMask, SegmentationMaskCallback};
+export * from './image_segmenter_result';
+export {SegmentationMask};
 export {ImageSource};  // Used in the public API
 
 const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
-const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
+const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
+const CATEGORY_MASK_STREAM = 'category_mask';
 const IMAGE_SEGMENTER_GRAPH =
     'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
     'mediapipe.tasks.TensorsToSegmentationCalculator';
+const DEFAULT_OUTPUT_CATEGORY_MASK = false;
+const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
 
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern
 
+/**
+ * A callback that receives the computed masks from the image segmenter. The
+ * returned data is only valid for the duration of the callback. If
+ * asynchronous processing is needed, all data needs to be copied before the
+ * callback returns.
+ */
+export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
+
 /** Performs image segmentation on images. */
 export class ImageSegmenter extends VisionTaskRunner {
-  private userCallback: SegmentationMaskCallback = () => {};
+  private result: ImageSegmenterResult = {width: 0, height: 0};
   private labels: string[] = [];
+  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
+  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
   private readonly options: ImageSegmenterGraphOptionsProto;
   private readonly segmenterOptions: SegmenterOptionsProto;
 
@@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
-
   protected override get baseOptions(): BaseOptionsProto {
     return this.options.getBaseOptions()!;
   }
@@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
       this.options.clearDisplayNamesLocale();
     }
 
-    if (options.outputType === 'CONFIDENCE_MASK') {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
-    } else {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    if ('outputCategoryMask' in options) {
+      this.outputCategoryMask =
+          options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
+    }
+
+    if ('outputConfidenceMasks' in options) {
+      this.outputConfidenceMasks =
+          options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
     }
 
     return super.applyOptions(options);
@@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
    * lifetime of the returned data is only guaranteed for the duration of the
    * callback.
    */
-  segment(image: ImageSource, callback: SegmentationMaskCallback): void;
+  segment(image: ImageSource, callback: ImageSegmenterCallback): void;
   /**
    * Performs image segmentation on the provided single image and invokes the
    * callback with the response. The method returns synchronously once the
@@ -208,22 +224,77 @@
    */
   segment(
       image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
-      callback: SegmentationMaskCallback): void;
+      callback: ImageSegmenterCallback): void;
   segment(
       image: ImageSource,
       imageProcessingOptionsOrCallback: ImageProcessingOptions|
-      SegmentationMaskCallback,
-      callback?: SegmentationMaskCallback): void {
+      ImageSegmenterCallback,
+      callback?: ImageSegmenterCallback): void {
     const imageProcessingOptions =
         typeof imageProcessingOptionsOrCallback !== 'function' ?
         imageProcessingOptionsOrCallback :
         {};
-
-    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
+    const userCallback =
+        typeof imageProcessingOptionsOrCallback === 'function' ?
        imageProcessingOptionsOrCallback :
        callback!;
+
+    this.reset();
     this.processImageData(image, imageProcessingOptions);
-    this.userCallback = () => {};
+    userCallback(this.result);
+  }
+
+  /**
+   * Performs image segmentation on the provided video frame and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segmentForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      callback: ImageSegmenterCallback): void;
+  /**
+   * Performs image segmentation on the provided video frame and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segmentForVideo(
+      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
+      timestamp: number, callback: ImageSegmenterCallback): void;
+  segmentForVideo(
+      videoFrame: ImageSource,
+      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
+      timestampOrCallback: number|ImageSegmenterCallback,
+      callback?: ImageSegmenterCallback): void {
+    const imageProcessingOptions =
+        typeof timestampOrImageProcessingOptions !== 'number' ?
+        timestampOrImageProcessingOptions :
+        {};
+    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
+        timestampOrImageProcessingOptions :
+        timestampOrCallback as number;
+    const userCallback = typeof timestampOrCallback === 'function' ?
+        timestampOrCallback :
+        callback!;
+
+    this.reset();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    userCallback(this.result);
   }
 
   /**
@@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
     return this.labels;
   }
 
-  /**
-   * Performs image segmentation on the provided video frame and invokes the
-   * callback with the response. The method returns synchronously once the
-   * callback returns. Only use this method when the ImageSegmenter is
-   * created with running mode `video`.
-   *
-   * @param videoFrame A video frame to process.
-   * @param timestamp The timestamp of the current frame, in ms.
-   * @param callback The callback that is invoked with the segmented masks. The
-   *    lifetime of the returned data is only guaranteed for the duration of the
-   *    callback.
-   */
-  segmentForVideo(
-      videoFrame: ImageSource, timestamp: number,
-      callback: SegmentationMaskCallback): void;
-  /**
-   * Performs image segmentation on the provided video frame and invokes the
-   * callback with the response. The method returns synchronously once the
-   * callback returns. Only use this method when the ImageSegmenter is
-   * created with running mode `video`.
-   *
-   * @param videoFrame A video frame to process.
-   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
-   *    to process the input image before running inference.
-   * @param timestamp The timestamp of the current frame, in ms.
-   * @param callback The callback that is invoked with the segmented masks. The
-   *    lifetime of the returned data is only guaranteed for the duration of the
-   *    callback.
-   */
-  segmentForVideo(
-      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
-      timestamp: number, callback: SegmentationMaskCallback): void;
-  segmentForVideo(
-      videoFrame: ImageSource,
-      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
-      timestampOrCallback: number|SegmentationMaskCallback,
-      callback?: SegmentationMaskCallback): void {
-    const imageProcessingOptions =
-        typeof timestampOrImageProcessingOptions !== 'number' ?
-        timestampOrImageProcessingOptions :
-        {};
-    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
-        timestampOrImageProcessingOptions :
-        timestampOrCallback as number;
-
-    this.userCallback = typeof timestampOrCallback === 'function' ?
-        timestampOrCallback :
-        callback!;
-    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
-    this.userCallback = () => {};
+  private reset(): void {
+    this.result = {width: 0, height: 0};
   }
 
   /** Updates the MediaPipe graph configuration. */
@@ -298,7 +321,6 @@
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
-    graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(
@@ -308,26 +330,47 @@
     segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
     segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
     segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
-    segmenterNode.addOutputStream(
-        'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
     segmenterNode.setOptions(calculatorOptions);
     graphConfig.addNode(segmenterNode);
 
-    this.graphRunner.attachImageVectorListener(
-        GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
-          if (masks.length === 0) {
-            this.userCallback([], 0, 0);
-          } else {
-            this.userCallback(
-                masks.map(m => m.data), masks[0].width, masks[0].height);
-          }
-          this.setLatestOutputTimestamp(timestamp);
-        });
-    this.graphRunner.attachEmptyPacketListener(
-        GROUPED_SEGMENTATIONS_STREAM, timestamp => {
-          this.setLatestOutputTimestamp(timestamp);
-        });
+    if (this.outputConfidenceMasks) {
+      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
+      segmenterNode.addOutputStream(
+          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
+
+      this.graphRunner.attachImageVectorListener(
+          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
+            this.result.confidenceMasks = masks.map(m => m.data);
+            if (masks.length > 0) {
+              this.result.width = masks[0].width;
+              this.result.height = masks[0].height;
+            }
+
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CONFIDENCE_MASKS_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }
+
+    if (this.outputCategoryMask) {
+      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
+      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
+
+      this.graphRunner.attachImageListener(
+          CATEGORY_MASK_STREAM, (mask, timestamp) => {
+            this.result.categoryMask = mask.data;
+            this.result.width = mask.width;
+            this.result.height = mask.height;
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CATEGORY_MASK_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
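A hedged usage sketch of the reworked API above: the callback now receives a single `ImageSegmenterResult`, and the boolean options replace the old `outputType` string. `FilesetResolver.forVisionTasks`, `createFromOptions`, the wasm path, and the model path are assumptions taken from the usual MediaPipe Tasks setup, not part of this diff:

```ts
import {FilesetResolver, ImageSegmenter} from '@mediapipe/tasks-vision';

async function runSegmentation(image: HTMLImageElement): Promise<void> {
  // Hypothetical wasm and model locations.
  const vision = await FilesetResolver.forVisionTasks('/wasm');
  const segmenter = await ImageSegmenter.createFromOptions(vision, {
    baseOptions: {modelAssetPath: 'selfie_segmenter.tflite'},
    outputCategoryMask: true,        // new boolean option (replaces outputType)
    outputConfidenceMasks: false,
  });

  segmenter.segment(image, result => {
    // Masks are only valid inside the callback; copy them before returning if
    // they are needed later.
    if (result.categoryMask instanceof Uint8ClampedArray) {
      console.log(`category mask: ${result.width}x${result.height}`);
    }
  });
}
```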
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
index c17e7e421..f80a792a5 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
@@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
    */
   displayNamesLocale?: string|undefined;
 
-  /**
-   * The output type of segmentation results.
-   *
-   * The two supported modes are:
-   *  - Category Mask:   Gives a single output mask where each pixel represents
-   *                     the class which the pixel in the original image was
-   *                     predicted to belong to.
-   *  - Confidence Mask: Gives a list of output masks (one for each class). For
-   *                     each mask, the pixel represents the prediction
-   *                     confidence, usually in the [0.0, 0.1] range.
-   *
-   * Defaults to `CATEGORY_MASK`.
-   */
-  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
+  /** Whether to output confidence masks. Defaults to true. */
+  outputConfidenceMasks?: boolean|undefined;
+
+  /** Whether to output the category masks. Defaults to false. */
+  outputCategoryMask?: boolean|undefined;
 }
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
new file mode 100644
index 000000000..be082d516
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** The output result of ImageSegmenter. */
+export declare interface ImageSegmenterResult {
+  /**
+   * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
+   * pixel represents the prediction confidence, usually in the [0, 1] range.
+   */
+  confidenceMasks?: Float32Array[]|WebGLTexture[];
+
+  /**
+   * A category mask as a Uint8ClampedArray or WebGLTexture where each
+   * pixel represents the class which the pixel in the original image was
+   * predicted to belong to.
+   */
+  categoryMask?: Uint8ClampedArray|WebGLTexture;
+
+  /** The width of the masks. */
+  width: number;
+
+  /** The height of the masks. */
+  height: number;
+}
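For consumers of the new result type, a hedged sketch of unpacking a CPU confidence mask into an `ImageData` for canvas rendering. It assumes `ImageSegmenterResult` is re-exported from the package entry point (as the `export *` in image_segmenter.ts suggests) and that the mask arrived as a Float32Array rather than a WebGLTexture:

```ts
import {ImageSegmenterResult} from '@mediapipe/tasks-vision';

function confidenceMaskToImageData(
    result: ImageSegmenterResult, index = 0): ImageData {
  const mask = result.confidenceMasks?.[index];
  if (!(mask instanceof Float32Array)) {
    throw new Error('Expected a CPU confidence mask');
  }
  const rgba = new Uint8ClampedArray(result.width * result.height * 4);
  for (let i = 0; i < mask.length; i++) {
    // Confidence in [0, 1] mapped to the alpha channel.
    rgba[4 * i + 3] = Math.round(mask[i] * 255);
  }
  return new ImageData(rgba, result.width, result.height);
}
```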
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
index 4cf27b9a5..6b5c90080 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
@@ -18,7 +18,7 @@ import 'jasmine';
 
 // Placeholder for internal dependency on encodeByteArray
 import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
-import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
 import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 
 import {ImageSegmenter} from './image_segmenter';
@@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
   graph: CalculatorGraphConfig|undefined;
   fakeWasmModule: SpyWasmModule;
 
-  imageVectorListener:
+  categoryMaskListener:
+      ((images: WasmImage, timestamp: number) => void)|undefined;
+  confidenceMasksListener:
       ((images: WasmImage[], timestamp: number) => void)|undefined;
 
   constructor() {
@@ -38,11 +40,16 @@
     this.fakeWasmModule =
         this.graphRunner.wasmModule as unknown as SpyWasmModule;
 
-    this.attachListenerSpies[0] =
+    this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
+                                      .and.callFake((stream, listener) => {
+                                        expect(stream).toEqual('category_mask');
+                                        this.categoryMaskListener = listener;
+                                      });
+    this.attachListenerSpies[1] =
         spyOn(this.graphRunner, 'attachImageVectorListener')
            .and.callFake((stream, listener) => {
-              expect(stream).toEqual('segmented_masks');
-              this.imageVectorListener = listener;
+              expect(stream).toEqual('confidence_masks');
+              this.confidenceMasksListener = listener;
             });
     spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
       this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
@@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
 
   it('initializes graph', async () => {
     verifyGraph(imageSegmenter);
-    verifyListenersRegistered(imageSegmenter);
+
+    // Verify default options
+    expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
+    expect(imageSegmenter.confidenceMasksListener).toBeDefined();
   });
 
   it('reloads graph when settings are changed', async () => {
     await imageSegmenter.setOptions({displayNamesLocale: 'en'});
     verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
-    verifyListenersRegistered(imageSegmenter);
 
     await imageSegmenter.setOptions({displayNamesLocale: 'de'});
     verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
-    verifyListenersRegistered(imageSegmenter);
   });
 
   it('can use custom models', async () => {
@@ -100,9 +108,11 @@
   });
 
   it('merges options', async () => {
-    await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
+    await imageSegmenter.setOptions(
+        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
     await imageSegmenter.setOptions({displayNamesLocale: 'en'});
-    verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
+    verifyGraph(
+        imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
     verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
   });
 
@@ -115,22 +125,13 @@
     defaultValue: unknown;
   }
 
-  const testCases: TestCase[] = [
-    {
-      optionName: 'displayNamesLocale',
-      fieldPath: ['displayNamesLocale'],
-      userValue: 'en',
-      graphValue: 'en',
-      defaultValue: 'en'
-    },
-    {
-      optionName: 'outputType',
-      fieldPath: ['segmenterOptions', 'outputType'],
-      userValue: 'CONFIDENCE_MASK',
-      graphValue: 2,
-      defaultValue: 1
-    },
-  ];
+  const testCases: TestCase[] = [{
+    optionName: 'displayNamesLocale',
+    fieldPath: ['displayNamesLocale'],
+    userValue: 'en',
+    graphValue: 'en',
+    defaultValue: 'en'
+  }];
 
   for (const testCase of testCases) {
     it(`can set ${testCase.optionName}`, async () => {
@@ -158,27 +159,31 @@
     }).toThrowError('This task doesn\'t support region-of-interest.');
   });
 
-  it('supports category masks', (done) => {
+  it('supports category mask', async () => {
     const mask = new Uint8ClampedArray([1, 2, 3, 4]);
 
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: false});
+
     // Pass the test data to our listener
     imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(imageSegmenter);
-      imageSegmenter.imageVectorListener!(
-          [
-            {data: mask, width: 2, height: 2},
-          ],
-          /* timestamp= */ 1337);
+      expect(imageSegmenter.categoryMaskListener).toBeDefined();
+      imageSegmenter.categoryMaskListener!
+          ({data: mask, width: 2, height: 2},
+           /* timestamp= */ 1337);
     });
 
     // Invoke the image segmenter
-    imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
-      expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
-      expect(masks).toHaveSize(1);
-      expect(masks[0]).toEqual(mask);
-      expect(width).toEqual(2);
-      expect(height).toEqual(2);
-      done();
+
+    return new Promise(resolve => {
+      imageSegmenter.segment({} as HTMLImageElement, result => {
+        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+        expect(result.categoryMask).toEqual(mask);
+        expect(result.confidenceMasks).not.toBeDefined();
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
+        resolve();
+      });
     });
   });
 
@@ -186,12 +191,13 @@
     const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
     const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
 
-    await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: false, outputConfidenceMasks: true});
 
     // Pass the test data to our listener
     imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(imageSegmenter);
-      imageSegmenter.imageVectorListener!(
+      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
+      imageSegmenter.confidenceMasksListener!(
          [
            {data: mask1, width: 2, height: 2},
            {data: mask2, width: 2, height: 2},
@@ -201,13 +207,49 @@
     return new Promise(resolve => {
       // Invoke the image segmenter
-      imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
+      imageSegmenter.segment({} as HTMLImageElement, result => {
         expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
-        expect(masks).toHaveSize(2);
-        expect(masks[0]).toEqual(mask1);
-        expect(masks[1]).toEqual(mask2);
-        expect(width).toEqual(2);
-        expect(height).toEqual(2);
+        expect(result.categoryMask).not.toBeDefined();
+        expect(result.confidenceMasks).toEqual([mask1, mask2]);
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
         resolve();
       });
     });
+  });
+
+  it('supports combined category and confidence masks', async () => {
+    const categoryMask = new Uint8ClampedArray([1, 0]);
+    const confidenceMask1 = new Float32Array([0.0, 1.0]);
+    const confidenceMask2 = new Float32Array([1.0, 0.0]);
+
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: true});
+
+    // Pass the test data to our listener
+    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      expect(imageSegmenter.categoryMaskListener).toBeDefined();
+      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
+      imageSegmenter.categoryMaskListener!
+          ({data: categoryMask, width: 1, height: 1}, 1337);
+      imageSegmenter.confidenceMasksListener!(
+          [
+            {data: confidenceMask1, width: 1, height: 1},
+            {data: confidenceMask2, width: 1, height: 1},
+          ],
+          1337);
+    });
+
+    return new Promise(resolve => {
+      // Invoke the image segmenter
+      imageSegmenter.segment({} as HTMLImageElement, result => {
+        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+        expect(result.categoryMask).toEqual(categoryMask);
+        expect(result.confidenceMasks).toEqual([
+          confidenceMask1, confidenceMask2
+        ]);
+        expect(result.width).toEqual(1);
+        expect(result.height).toEqual(1);
+        resolve();
+      });
+    });
   });
 });
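Finally, a sketch of the video path with the new callback, assuming a segmenter created with `runningMode: 'video'`. The timestamps supplied by requestAnimationFrame are in milliseconds, which matches the `timestamp` parameter documented in image_segmenter.ts:

```ts
import {ImageSegmenter} from '@mediapipe/tasks-vision';

function segmentLoop(segmenter: ImageSegmenter, video: HTMLVideoElement): void {
  const onFrame = (timeMs: number) => {
    segmenter.segmentForVideo(video, timeMs, result => {
      // Copy anything needed later; the masks are only valid during the callback.
      console.log(`got ${result.confidenceMasks?.length ?? 0} confidence masks`);
    });
    requestAnimationFrame(onFrame);
  };
  requestAnimationFrame(onFrame);
}
```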