Support new output format for ImageSegmenter

PiperOrigin-RevId: 524371021
Sebastian Schmidt 2023-04-14 13:23:03 -07:00 committed by Copybara-Service
parent f5197a3adc
commit 92f45c98d8
6 changed files with 268 additions and 153 deletions

View File

@@ -59,13 +59,12 @@ export function drawCategoryMask(
   const isFloatArray = image instanceof Float32Array;
   for (let i = 0; i < image.length; i++) {
     const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
-    const color = COLOR_MAP[colorIndex];
-    // When we're given a confidence mask by accident, we just log and return.
-    // TODO: We should fix this.
+    let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
     if (!color) {
+      // TODO: We should fix this.
       console.warn('No color for ', colorIndex);
-      return;
+      color = COLOR_MAP[colorIndex % COLOR_MAP.length];
     }
     rgbaArray[4 * i] = color[0];

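Why the modulo wrap helps: when a confidence mask is passed in by accident, its float values are scaled to 0-255, which easily lands outside the palette, so the old lookup produced undefined and bailed out of the whole draw. Below is a minimal sketch of the wrapped lookup; the palette contents and the 0.87 sample value are made up for illustration, only the `% COLOR_MAP.length` indexing mirrors the change above.

// Hypothetical four-entry palette; only the wrapped indexing below mirrors the
// patched drawCategoryMask.
const COLOR_MAP: Array<[number, number, number, number]> = [
  [0, 0, 0, 255], [255, 0, 0, 255], [0, 255, 0, 255], [0, 0, 255, 255],
];

// A confidence value of 0.87 scales to index 222, far past the end of a
// four-entry palette; wrapping keeps the lookup in range: 222 % 4 === 2.
const colorIndex = Math.round(0.87 * 255);  // 222
const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
console.log(color);  // [0, 255, 0, 255]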
View File

@@ -29,7 +29,10 @@ mediapipe_ts_library(
 mediapipe_ts_declaration(
     name = "image_segmenter_types",
-    srcs = ["image_segmenter_options.d.ts"],
+    srcs = [
+        "image_segmenter_options.d.ts",
+        "image_segmenter_result.d.ts",
+    ],
     deps = [
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",

View File

@@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
-import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
+import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {LabelMapItem} from '../../../../util/label_map_pb';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
 
 import {ImageSegmenterOptions} from './image_segmenter_options';
+import {ImageSegmenterResult} from './image_segmenter_result';
 
 export * from './image_segmenter_options';
-export {SegmentationMask, SegmentationMaskCallback};
+export * from './image_segmenter_result';
+export {SegmentationMask};
 export {ImageSource}; // Used in the public API
 
 const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
-const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
+const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
+const CATEGORY_MASK_STREAM = 'category_mask';
 const IMAGE_SEGMENTER_GRAPH =
     'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
     'mediapipe.tasks.TensorsToSegmentationCalculator';
+const DEFAULT_OUTPUT_CATEGORY_MASK = false;
+const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
 
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern
 
+/**
+ * A callback that receives the computed masks from the image segmenter. The
+ * returned data is only valid for the duration of the callback. If
+ * asynchronous processing is needed, all data needs to be copied before the
+ * callback returns.
+ */
+export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
+
 /** Performs image segmentation on images. */
 export class ImageSegmenter extends VisionTaskRunner {
-  private userCallback: SegmentationMaskCallback = () => {};
+  private result: ImageSegmenterResult = {width: 0, height: 0};
   private labels: string[] = [];
+  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
+  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
   private readonly options: ImageSegmenterGraphOptionsProto;
   private readonly segmenterOptions: SegmenterOptionsProto;
@@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
     this.options.setBaseOptions(new BaseOptionsProto());
   }
 
   protected override get baseOptions(): BaseOptionsProto {
     return this.options.getBaseOptions()!;
   }
@@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
       this.options.clearDisplayNamesLocale();
     }
 
-    if (options.outputType === 'CONFIDENCE_MASK') {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
-    } else {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    if ('outputCategoryMask' in options) {
+      this.outputCategoryMask =
+          options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
+    }
+
+    if ('outputConfidenceMasks' in options) {
+      this.outputConfidenceMasks =
+          options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
     }
 
     return super.applyOptions(options);
@@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
    * lifetime of the returned data is only guaranteed for the duration of the
    * callback.
    */
-  segment(image: ImageSource, callback: SegmentationMaskCallback): void;
+  segment(image: ImageSource, callback: ImageSegmenterCallack): void;
   /**
    * Performs image segmentation on the provided single image and invokes the
    * callback with the response. The method returns synchronously once the
@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
*/ */
segment( segment(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions, image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void; callback: ImageSegmenterCallack): void;
segment( segment(
image: ImageSource, image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback, ImageSegmenterCallack,
callback?: SegmentationMaskCallback): void { callback?: ImageSegmenterCallack): void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
const userCallback =
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback!;
this.reset();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {}; userCallback(this.result);
}
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, timestamp: number,
callback: ImageSegmenterCallack): void;
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: ImageSegmenterCallack): void;
segmentForVideo(
videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|ImageSegmenterCallack,
callback?: ImageSegmenterCallack): void {
const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions :
{};
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
const userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.reset();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
userCallback(this.result);
} }
/** /**
@@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
     return this.labels;
   }
 
-  /**
-   * Performs image segmentation on the provided video frame and invokes the
-   * callback with the response. The method returns synchronously once the
-   * callback returns. Only use this method when the ImageSegmenter is
-   * created with running mode `video`.
-   *
-   * @param videoFrame A video frame to process.
-   * @param timestamp The timestamp of the current frame, in ms.
-   * @param callback The callback that is invoked with the segmented masks. The
-   *    lifetime of the returned data is only guaranteed for the duration of the
-   *    callback.
-   */
-  segmentForVideo(
-      videoFrame: ImageSource, timestamp: number,
-      callback: SegmentationMaskCallback): void;
-  /**
-   * Performs image segmentation on the provided video frame and invokes the
-   * callback with the response. The method returns synchronously once the
-   * callback returns. Only use this method when the ImageSegmenter is
-   * created with running mode `video`.
-   *
-   * @param videoFrame A video frame to process.
-   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
-   *    to process the input image before running inference.
-   * @param timestamp The timestamp of the current frame, in ms.
-   * @param callback The callback that is invoked with the segmented masks. The
-   *    lifetime of the returned data is only guaranteed for the duration of the
-   *    callback.
-   */
-  segmentForVideo(
-      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
-      timestamp: number, callback: SegmentationMaskCallback): void;
-  segmentForVideo(
-      videoFrame: ImageSource,
-      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
-      timestampOrCallback: number|SegmentationMaskCallback,
-      callback?: SegmentationMaskCallback): void {
-    const imageProcessingOptions =
-        typeof timestampOrImageProcessingOptions !== 'number' ?
-        timestampOrImageProcessingOptions :
-        {};
-    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
-        timestampOrImageProcessingOptions :
-        timestampOrCallback as number;
-    this.userCallback = typeof timestampOrCallback === 'function' ?
-        timestampOrCallback :
-        callback!;
-    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
-    this.userCallback = () => {};
+  private reset(): void {
+    this.result = {width: 0, height: 0};
   }
 
   /** Updates the MediaPipe graph configuration. */
@@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
-    graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(
@@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
     segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
     segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
     segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
-    segmenterNode.addOutputStream(
-        'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
     segmenterNode.setOptions(calculatorOptions);
 
     graphConfig.addNode(segmenterNode);
 
-    this.graphRunner.attachImageVectorListener(
-        GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
-          if (masks.length === 0) {
-            this.userCallback([], 0, 0);
-          } else {
-            this.userCallback(
-                masks.map(m => m.data), masks[0].width, masks[0].height);
-          }
-          this.setLatestOutputTimestamp(timestamp);
-        });
-    this.graphRunner.attachEmptyPacketListener(
-        GROUPED_SEGMENTATIONS_STREAM, timestamp => {
-          this.setLatestOutputTimestamp(timestamp);
-        });
+    if (this.outputConfidenceMasks) {
+      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
+      segmenterNode.addOutputStream(
+          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
+
+      this.graphRunner.attachImageVectorListener(
+          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
+            this.result.confidenceMasks = masks.map(m => m.data);
+            if (masks.length >= 0) {
+              this.result.width = masks[0].width;
+              this.result.height = masks[0].height;
+            }
+
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CONFIDENCE_MASKS_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }
+
+    if (this.outputCategoryMask) {
+      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
+      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
+
+      this.graphRunner.attachImageListener(
+          CATEGORY_MASK_STREAM, (mask, timestamp) => {
+            this.result.categoryMask = mask.data;
+            this.result.width = mask.width;
+            this.result.height = mask.height;
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CATEGORY_MASK_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

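Taken together, the changes in this file replace the positional (masks, width, height) callback with a single ImageSegmenterResult. Below is a usage sketch of the new shape; the wasm location, model path, and the FilesetResolver/createFromOptions bootstrap are illustrative assumptions rather than part of this diff, and mask data must be copied inside the callback if it is needed afterwards.

import {FilesetResolver, ImageSegmenter} from '@mediapipe/tasks-vision';

async function runSegmenter(image: HTMLImageElement): Promise<void> {
  // The wasm location and model file below are placeholders for illustration.
  const vision = await FilesetResolver.forVisionTasks(
      'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision/wasm');
  const segmenter = await ImageSegmenter.createFromOptions(vision, {
    baseOptions: {modelAssetPath: 'selfie_segmenter.tflite'},
    outputConfidenceMasks: true,
    outputCategoryMask: true,
  });

  segmenter.segment(image, result => {
    // The masks are only valid for the duration of the callback, so copy
    // anything that is needed asynchronously.
    const categoryMask = result.categoryMask instanceof Uint8ClampedArray ?
        new Uint8ClampedArray(result.categoryMask) :
        undefined;
    const numConfidenceMasks = result.confidenceMasks?.length ?? 0;
    console.log(
        `Got ${numConfidenceMasks} confidence mask(s) at ` +
        `${result.width}x${result.height}`, categoryMask);
  });
}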
View File

@@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
    */
   displayNamesLocale?: string|undefined;
 
-  /**
-   * The output type of segmentation results.
-   *
-   * The two supported modes are:
-   *  - Category Mask:   Gives a single output mask where each pixel represents
-   *                     the class which the pixel in the original image was
-   *                     predicted to belong to.
-   *  - Confidence Mask: Gives a list of output masks (one for each class). For
-   *                     each mask, the pixel represents the prediction
-   *                     confidence, usually in the [0.0, 0.1] range.
-   *
-   * Defaults to `CATEGORY_MASK`.
-   */
-  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
+  /** Whether to output confidence masks. Defaults to true. */
+  outputConfidenceMasks?: boolean|undefined;
+
+  /** Whether to output the category masks. Defaults to false. */
+  outputCategoryMask?: boolean|undefined;
 }

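The two booleans replace the old outputType enum and can be toggled independently; as the updated tests show, changing them reloads the graph so only the requested output streams are attached. A small sketch, assuming an already created `segmenter` inside an async function:

// Sketch only: `segmenter` is an existing ImageSegmenter instance.
await segmenter.setOptions({
  outputCategoryMask: true,      // opt in to the single class-id mask
  outputConfidenceMasks: false,  // skip the per-class float masks
});

// Restore the defaults (confidence masks on, category mask off).
await segmenter.setOptions({
  outputCategoryMask: false,
  outputConfidenceMasks: true,
});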
View File

@@ -0,0 +1,37 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** The output result of ImageSegmenter. */
+export declare interface ImageSegmenterResult {
+  /**
+   * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
+   * pixel represents the prediction confidence, usually in the [0, 1] range.
+   */
+  confidenceMasks?: Float32Array[]|WebGLTexture[];
+
+  /**
+   * A category mask as a Uint8ClampedArray or WebGLTexture where each
+   * pixel represents the class which the pixel in the original image was
+   * predicted to belong to.
+   */
+  categoryMask?: Uint8ClampedArray|WebGLTexture;
+
+  /** The width of the masks. */
+  width: number;
+
+  /** The height of the masks. */
+  height: number;
+}

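The width/height fields describe every returned mask, so a CPU category mask can be indexed per pixel. A sketch of one way to read it; the row-major indexing and the package-level re-export of ImageSegmenterResult are assumptions here.

import {ImageSegmenterResult} from '@mediapipe/tasks-vision';

/** Returns the predicted class id at (x, y), or undefined for GPU masks. */
function labelAt(
    result: ImageSegmenterResult, x: number, y: number): number|undefined {
  const mask = result.categoryMask;
  if (!(mask instanceof Uint8ClampedArray)) {
    // WebGLTexture masks live on the GPU and would need a readback first.
    return undefined;
  }
  // Assumed row-major layout: one byte per pixel, `width` pixels per row.
  return mask[y * result.width + x];
}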
View File

@@ -18,7 +18,7 @@ import 'jasmine';
 // Placeholder for internal dependency on encodeByteArray
 
 import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
-import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
 import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 
 import {ImageSegmenter} from './image_segmenter';
@@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
   graph: CalculatorGraphConfig|undefined;
   fakeWasmModule: SpyWasmModule;
-  imageVectorListener:
+  categoryMaskListener:
+      ((images: WasmImage, timestamp: number) => void)|undefined;
+  confidenceMasksListener:
       ((images: WasmImage[], timestamp: number) => void)|undefined;
 
   constructor() {
@@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
     this.fakeWasmModule =
         this.graphRunner.wasmModule as unknown as SpyWasmModule;
-    this.attachListenerSpies[0] =
+    this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
+        .and.callFake((stream, listener) => {
+          expect(stream).toEqual('category_mask');
+          this.categoryMaskListener = listener;
+        });
+    this.attachListenerSpies[1] =
         spyOn(this.graphRunner, 'attachImageVectorListener')
             .and.callFake((stream, listener) => {
-              expect(stream).toEqual('segmented_masks');
-              this.imageVectorListener = listener;
+              expect(stream).toEqual('confidence_masks');
+              this.confidenceMasksListener = listener;
             });
     spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
       this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
@@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
   it('initializes graph', async () => {
     verifyGraph(imageSegmenter);
-    verifyListenersRegistered(imageSegmenter);
+
+    // Verify default options
+    expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
+    expect(imageSegmenter.confidenceMasksListener).toBeDefined();
   });
 
   it('reloads graph when settings are changed', async () => {
     await imageSegmenter.setOptions({displayNamesLocale: 'en'});
     verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
-    verifyListenersRegistered(imageSegmenter);
 
     await imageSegmenter.setOptions({displayNamesLocale: 'de'});
     verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
-    verifyListenersRegistered(imageSegmenter);
   });
 
   it('can use custom models', async () => {
@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
}); });
it('merges options', async () => { it('merges options', async () => {
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); await imageSegmenter.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
await imageSegmenter.setOptions({displayNamesLocale: 'en'}); await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); verifyGraph(
imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
}); });
@@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
     defaultValue: unknown;
   }
 
-  const testCases: TestCase[] = [
-    {
-      optionName: 'displayNamesLocale',
-      fieldPath: ['displayNamesLocale'],
-      userValue: 'en',
-      graphValue: 'en',
-      defaultValue: 'en'
-    },
-    {
-      optionName: 'outputType',
-      fieldPath: ['segmenterOptions', 'outputType'],
-      userValue: 'CONFIDENCE_MASK',
-      graphValue: 2,
-      defaultValue: 1
-    },
-  ];
+  const testCases: TestCase[] = [{
+    optionName: 'displayNamesLocale',
+    fieldPath: ['displayNamesLocale'],
+    userValue: 'en',
+    graphValue: 'en',
+    defaultValue: 'en'
+  }];
 
   for (const testCase of testCases) {
     it(`can set ${testCase.optionName}`, async () => {
@@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
     }).toThrowError('This task doesn\'t support region-of-interest.');
   });
 
-  it('supports category masks', (done) => {
+  it('supports category mask', async () => {
     const mask = new Uint8ClampedArray([1, 2, 3, 4]);
 
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: false});
+
     // Pass the test data to our listener
     imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(imageSegmenter);
-      imageSegmenter.imageVectorListener!(
-          [
-            {data: mask, width: 2, height: 2},
-          ],
-          /* timestamp= */ 1337);
+      expect(imageSegmenter.categoryMaskListener).toBeDefined();
+      imageSegmenter.categoryMaskListener!
+          ({data: mask, width: 2, height: 2},
+           /* timestamp= */ 1337);
     });
 
     // Invoke the image segmenter
-    imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
-      expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
-      expect(masks).toHaveSize(1);
-      expect(masks[0]).toEqual(mask);
-      expect(width).toEqual(2);
-      expect(height).toEqual(2);
-      done();
+    return new Promise<void>(resolve => {
+      imageSegmenter.segment({} as HTMLImageElement, result => {
+        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+        expect(result.categoryMask).toEqual(mask);
+        expect(result.confidenceMasks).not.toBeDefined();
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
+        resolve();
+      });
     });
   });
@@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
     const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
     const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
 
-    await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: false, outputConfidenceMasks: true});
 
     // Pass the test data to our listener
     imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(imageSegmenter);
-      imageSegmenter.imageVectorListener!(
+      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
+      imageSegmenter.confidenceMasksListener!(
           [
             {data: mask1, width: 2, height: 2},
            {data: mask2, width: 2, height: 2},
@@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {
     return new Promise<void>(resolve => {
       // Invoke the image segmenter
-      imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
+      imageSegmenter.segment({} as HTMLImageElement, result => {
         expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
-        expect(masks).toHaveSize(2);
-        expect(masks[0]).toEqual(mask1);
-        expect(masks[1]).toEqual(mask2);
-        expect(width).toEqual(2);
-        expect(height).toEqual(2);
+        expect(result.categoryMask).not.toBeDefined();
+        expect(result.confidenceMasks).toEqual([mask1, mask2]);
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
+        resolve();
+      });
+    });
+  });
+
+  it('supports combined category and confidence masks', async () => {
+    const categoryMask = new Uint8ClampedArray([1, 0]);
+    const confidenceMask1 = new Float32Array([0.0, 1.0]);
+    const confidenceMask2 = new Float32Array([1.0, 0.0]);
+
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: true});
+
+    // Pass the test data to our listener
+    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      expect(imageSegmenter.categoryMaskListener).toBeDefined();
+      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
+      imageSegmenter.categoryMaskListener!
+          ({data: categoryMask, width: 1, height: 1}, 1337);
+      imageSegmenter.confidenceMasksListener!(
+          [
+            {data: confidenceMask1, width: 1, height: 1},
+            {data: confidenceMask2, width: 1, height: 1},
+          ],
+          1337);
+    });
+
+    return new Promise<void>(resolve => {
+      // Invoke the image segmenter
+      imageSegmenter.segment({} as HTMLImageElement, result => {
+        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+        expect(result.categoryMask).toEqual(categoryMask);
+        expect(result.confidenceMasks).toEqual([
+          confidenceMask1, confidenceMask2
+        ]);
+        expect(result.width).toEqual(1);
+        expect(result.height).toEqual(1);
         resolve();
       });
     });