Support new output format for InteractiveSegmenter
PiperOrigin-RevId: 524940992
This commit is contained in:
parent
48cc96cf3c
commit
b147002b7e
11
mediapipe/tasks/web/vision/core/types.d.ts
vendored
11
mediapipe/tasks/web/vision/core/types.d.ts
vendored
|
@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
|
|||
*/
|
||||
export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;
|
||||
|
||||
/**
|
||||
* A callback that receives the computed masks from the segmentation tasks. The
|
||||
* callback either receives a single element array with a category mask (as a
|
||||
* `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`).
|
||||
* The returned data is only valid for the duration of the callback. If
|
||||
* asynchronous processing is needed, all data needs to be copied before the
|
||||
* callback returns.
|
||||
*/
|
||||
export type SegmentationMaskCallback =
|
||||
(masks: SegmentationMask[], width: number, height: number) => void;
|
||||
|
||||
/**
|
||||
* A callback that receives an `ImageData` object from a Vision task. The
|
||||
* lifetime of the underlying data is limited to the duration of the callback.
|
||||
|
|
|
@ -30,7 +30,10 @@ mediapipe_ts_library(
|
|||
|
||||
mediapipe_ts_declaration(
|
||||
name = "interactive_segmenter_types",
|
||||
srcs = ["interactive_segmenter_options.d.ts"],
|
||||
srcs = [
|
||||
"interactive_segmenter_options.d.ts",
|
||||
"interactive_segmenter_result.d.ts",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
|
|
|
@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
|||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||
import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {Color as ColorProto} from '../../../../util/color_pb';
|
||||
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
|
@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
|
|||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
|
||||
import {InteractiveSegmenterResult} from './interactive_segmenter_result';
|
||||
|
||||
export * from './interactive_segmenter_options';
|
||||
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
|
||||
export * from './interactive_segmenter_result';
|
||||
export {SegmentationMask, RegionOfInterest};
|
||||
export {ImageSource};
|
||||
|
||||
const IMAGE_IN_STREAM = 'image_in';
|
||||
const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
||||
const ROI_IN_STREAM = 'roi_in';
|
||||
const IMAGE_OUT_STREAM = 'image_out';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* A callback that receives the computed masks from the interactive segmenter.
|
||||
* The returned data is only valid for the duration of the callback. If
|
||||
* asynchronous processing is needed, all data needs to be copied before the
|
||||
* callback returns.
|
||||
*/
|
||||
export type InteractiveSegmenterCallack =
|
||||
(result: InteractiveSegmenterResult) => void;
|
||||
|
||||
/**
|
||||
* Performs interactive segmentation on images.
|
||||
*
|
||||
|
@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH =
|
|||
* - batch is always 1
|
||||
*/
|
||||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private result: InteractiveSegmenterResult = {width: 0, height: 0};
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
|
@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
* @return A Promise that resolves when the settings have been applied.
|
||||
*/
|
||||
override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
|
||||
if (options.outputType === 'CONFIDENCE_MASK') {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
||||
} else {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
||||
if ('outputCategoryMask' in options) {
|
||||
this.outputCategoryMask =
|
||||
options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
}
|
||||
|
||||
if ('outputConfidenceMasks' in options) {
|
||||
this.outputConfidenceMasks =
|
||||
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
}
|
||||
|
||||
return super.applyOptions(options);
|
||||
|
@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
*/
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
callback: InteractiveSegmenterCallack): void;
|
||||
/**
|
||||
* Performs interactive segmentation on the provided single image and invokes
|
||||
* the callback with the response. The `roi` parameter is used to represent a
|
||||
|
@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptions: ImageProcessingOptions,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
callback: InteractiveSegmenterCallack): void;
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||
SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
InteractiveSegmenterCallack,
|
||||
callback?: InteractiveSegmenterCallack): void {
|
||||
const imageProcessingOptions =
|
||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
const userCallback =
|
||||
typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
|
||||
this.reset();
|
||||
this.processRenderData(roi, this.getSynctheticTimestamp());
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
this.userCallback = () => {};
|
||||
userCallback(this.result);
|
||||
}
|
||||
|
||||
private reset(): void {
|
||||
this.result = {width: 0, height: 0};
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
|
@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
graphConfig.addInputStream(IMAGE_IN_STREAM);
|
||||
graphConfig.addInputStream(ROI_IN_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_IN_STREAM);
|
||||
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
|
@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
|
||||
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
|
||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
|
||||
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
|
||||
segmenterNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(segmenterNode);
|
||||
|
||||
if (this.outputConfidenceMasks) {
|
||||
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
||||
segmenterNode.addOutputStream(
|
||||
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
IMAGE_OUT_STREAM, (masks, timestamp) => {
|
||||
if (masks.length === 0) {
|
||||
this.userCallback([], 0, 0);
|
||||
} else {
|
||||
this.userCallback(
|
||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||
this.result.confidenceMasks = masks.map(m => m.data);
|
||||
if (masks.length >= 0) {
|
||||
this.result.width = masks[0].width;
|
||||
this.result.height = masks[0].height;
|
||||
}
|
||||
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
if (this.outputCategoryMask) {
|
||||
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
||||
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
||||
|
||||
this.graphRunner.attachImageListener(
|
||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||
this.result.categoryMask = mask.data;
|
||||
this.result.width = mask.width;
|
||||
this.result.height = mask.height;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
CATEGORY_MASK_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
|
|
|
@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'
|
|||
|
||||
/** Options to configure the MediaPipe Interactive Segmenter Task */
|
||||
export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
|
||||
/**
|
||||
* The output type of segmentation results.
|
||||
*
|
||||
* The two supported modes are:
|
||||
* - Category Mask: Gives a single output mask where each pixel represents
|
||||
* the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
||||
* each mask, the pixel represents the prediction
|
||||
* confidence, usually in the [0.0, 0.1] range.
|
||||
*
|
||||
* Defaults to `CATEGORY_MASK`.
|
||||
*/
|
||||
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
||||
/** Whether to output confidence masks. Defaults to true. */
|
||||
outputConfidenceMasks?: boolean|undefined;
|
||||
|
||||
/** Whether to output the category masks. Defaults to false. */
|
||||
outputCategoryMask?: boolean|undefined;
|
||||
}
|
||||
|
|
37
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts
vendored
Normal file
37
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/** The output result of InteractiveSegmenter. */
|
||||
export declare interface InteractiveSegmenterResult {
|
||||
/**
|
||||
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
|
||||
* pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||
*/
|
||||
confidenceMasks?: Float32Array[]|WebGLTexture[];
|
||||
|
||||
/**
|
||||
* A category mask as a Uint8ClampedArray or WebGLTexture where each
|
||||
* pixel represents the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
*/
|
||||
categoryMask?: Uint8ClampedArray|WebGLTexture;
|
||||
|
||||
/** The width of the masks. */
|
||||
width: number;
|
||||
|
||||
/** The height of the masks. */
|
||||
height: number;
|
||||
}
|
|
@ -18,7 +18,7 @@ import 'jasmine';
|
|||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||
|
||||
|
@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
graph: CalculatorGraphConfig|undefined;
|
||||
|
||||
fakeWasmModule: SpyWasmModule;
|
||||
imageVectorListener:
|
||||
categoryMaskListener:
|
||||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
|
@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
this.fakeWasmModule =
|
||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||
|
||||
this.attachListenerSpies[0] =
|
||||
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('category_mask');
|
||||
this.categoryMaskListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[1] =
|
||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('image_out');
|
||||
this.imageVectorListener = listener;
|
||||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
|
@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => {
|
|||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(interactiveSegmenter);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
|
||||
// Verify default options
|
||||
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputConfidenceMasks: true, outputCategoryMask: false});
|
||||
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputConfidenceMasks: false, outputCategoryMask: true});
|
||||
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
|
||||
});
|
||||
|
||||
it('can use custom models', async () => {
|
||||
|
@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
|
||||
describe('setOptions()', () => {
|
||||
const fieldPath = ['segmenterOptions', 'outputType'];
|
||||
|
||||
it(`can set outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
});
|
||||
|
||||
it(`can clear outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
await interactiveSegmenter.setOptions({outputType: undefined});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
|
||||
});
|
||||
});
|
||||
|
||||
it('doesn\'t support region of interest', () => {
|
||||
expect(() => {
|
||||
interactiveSegmenter.segment(
|
||||
|
@ -153,29 +147,31 @@ describe('InteractiveSegmenter', () => {
|
|||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
|
||||
});
|
||||
|
||||
it('supports category masks', (done) => {
|
||||
it('supports category mask', async () => {
|
||||
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputCategoryMask: true, outputConfidenceMasks: false});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask, width: 2, height: 2},
|
||||
],
|
||||
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
|
||||
interactiveSegmenter.categoryMaskListener!
|
||||
({data: mask, width: 2, height: 2},
|
||||
/* timestamp= */ 1337);
|
||||
});
|
||||
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
return new Promise<void>(resolve => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(1);
|
||||
expect(masks[0]).toEqual(mask);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
done();
|
||||
expect(result.categoryMask).toEqual(mask);
|
||||
expect(result.confidenceMasks).not.toBeDefined();
|
||||
expect(result.width).toEqual(2);
|
||||
expect(result.height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -183,30 +179,67 @@ describe('InteractiveSegmenter', () => {
|
|||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputCategoryMask: false, outputConfidenceMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
interactiveSegmenter.confidenceMasksListener!(
|
||||
[
|
||||
{data: mask1, width: 2, height: 2},
|
||||
{data: mask2, width: 2, height: 2},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(result.categoryMask).not.toBeDefined();
|
||||
expect(result.confidenceMasks).toEqual([mask1, mask2]);
|
||||
expect(result.width).toEqual(2);
|
||||
expect(result.height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('supports combined category and confidence masks', async () => {
|
||||
const categoryMask = new Uint8ClampedArray([1, 0]);
|
||||
const confidenceMask1 = new Float32Array([0.0, 1.0]);
|
||||
const confidenceMask2 = new Float32Array([1.0, 0.0]);
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputCategoryMask: true, outputConfidenceMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
interactiveSegmenter.categoryMaskListener!
|
||||
({data: categoryMask, width: 1, height: 1}, 1337);
|
||||
interactiveSegmenter.confidenceMasksListener!(
|
||||
[
|
||||
{data: confidenceMask1, width: 1, height: 1},
|
||||
{data: confidenceMask2, width: 1, height: 1},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
{} as HTMLImageElement, ROI, result => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(2);
|
||||
expect(masks[0]).toEqual(mask1);
|
||||
expect(masks[1]).toEqual(mask2);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
expect(result.categoryMask).toEqual(categoryMask);
|
||||
expect(result.confidenceMasks).toEqual([
|
||||
confidenceMask1, confidenceMask2
|
||||
]);
|
||||
expect(result.width).toEqual(1);
|
||||
expect(result.height).toEqual(1);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user