From ec3cd45d615c161fdb8fc234679fc41efa2913e6 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 14 Mar 2023 15:45:07 -0700 Subject: [PATCH] Add InteractiveSegmenter Web API PiperOrigin-RevId: 516654090 --- mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/README.md | 18 ++ mediapipe/tasks/web/vision/index.ts | 3 + .../web/vision/interactive_segmenter/BUILD | 62 ++++ .../interactive_segmenter.ts | 306 ++++++++++++++++++ .../interactive_segmenter_options.d.ts | 36 +++ .../interactive_segmenter_test.ts | 214 ++++++++++++ mediapipe/tasks/web/vision/types.ts | 1 + 8 files changed, 641 insertions(+) create mode 100644 mediapipe/tasks/web/vision/interactive_segmenter/BUILD create mode 100644 mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts create mode 100644 mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts create mode 100644 mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index a229cbd2a..37709c055 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -24,6 +24,7 @@ VISION_LIBS = [ "//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/image_embedder", "//mediapipe/tasks/web/vision/image_segmenter", + "//mediapipe/tasks/web/vision/interactive_segmenter", "//mediapipe/tasks/web/vision/object_detector", ] diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index c1f15ec26..2ca4ff64e 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -75,6 +75,24 @@ imageSegmenter.segment(image, (masks, width, height) => { }); ``` +## Interactive Segmentation + +The MediaPipe Interactive Segmenter lets you select a region of interest to +segment an image by. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const interactiveSegmenter = await InteractiveSegmenter.createFromModelPath( + vision, "model.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +interactiveSegmenter.segment(image, { keypoint: { x: 0.1, y: 0.2 } }, + (masks, width, height) => { ... 
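+    // Illustrative only: `masks` holds one mask per output image (a single
+    // category mask by default) and is valid only while this callback runs.
+    console.log(`Received ${masks.length} mask(s) of size ${width}x${height}`);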
} +); +``` + ## Object Detection The MediaPipe Object Detector task lets you detect the presence and location of diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 5a87c7a82..fdbb1a65a 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -20,6 +20,7 @@ import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/ha import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter'; +import {InteractiveSegmenter as InteractiveSegmenterImpl} from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter'; import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector'; // Declare the variables locally so that Rollup in OSS includes them explicitly @@ -30,6 +31,7 @@ const HandLandmarker = HandLandmarkerImpl; const ImageClassifier = ImageClassifierImpl; const ImageEmbedder = ImageEmbedderImpl; const ImageSegmenter = ImageSegementerImpl; +const InteractiveSegmenter = InteractiveSegmenterImpl; const ObjectDetector = ObjectDetectorImpl; export { @@ -39,5 +41,6 @@ export { ImageClassifier, ImageEmbedder, ImageSegmenter, + InteractiveSegmenter, ObjectDetector }; diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD new file mode 100644 index 000000000..a4a3f27c9 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD @@ -0,0 +1,62 @@ +# This contains the MediaPipe Interactive Segmenter Task. 
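+# It defines the `interactive_segmenter` task library, the
+# `interactive_segmenter_types` option declarations, and an
+# `interactive_segmenter_test` jasmine_node_test wrapper around the test
+# sources.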
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "interactive_segmenter", + srcs = ["interactive_segmenter.ts"], + deps = [ + ":interactive_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:keypoint", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:types", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/util:color_jspb_proto", + "//mediapipe/util:render_data_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "interactive_segmenter_types", + srcs = ["interactive_segmenter_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "interactive_segmenter_test_lib", + testonly = True, + srcs = [ + "interactive_segmenter_test.ts", + ], + deps = [ + ":interactive_segmenter", + ":interactive_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/util:render_data_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + ], +) + +jasmine_node_test( + name = "interactive_segmenter_test", + tags = ["nomsan"], + deps = [":interactive_segmenter_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts new file mode 100644 index 000000000..1499a4c0c --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -0,0 +1,306 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; +import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {Color as ColorProto} from '../../../../util/color_pb'; +import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {InteractiveSegmenterOptions} from './interactive_segmenter_options'; + +export * from './interactive_segmenter_options'; +export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest}; +export {ImageSource}; + +const IMAGE_IN_STREAM = 'image_in'; +const NORM_RECT_IN_STREAM = 'norm_rect_in'; +const ROI_IN_STREAM = 'roi_in'; +const IMAGE_OUT_STREAM = 'image_out'; +const IMAGEA_SEGMENTER_GRAPH = + 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** + * Performs interactive segmentation on images. + * + * Users can represent user interaction through `RegionOfInterest`, which gives + * a hint to InteractiveSegmenter to perform segmentation focusing on the given + * region of interest. + * + * The API expects a TFLite model with mandatory TFLite Model Metadata. + * + * Input tensor: + * (kTfLiteUInt8/kTfLiteFloat32) + * - image input of size `[batch x height x width x channels]`. + * - batch inference is not supported (`batch` is required to be 1). + * - RGB inputs is supported (`channels` is required to be 3). + * - if type is kTfLiteFloat32, NormalizationOptions are required to be + * attached to the metadata for input normalization. + * Output tensors: + * (kTfLiteUInt8/kTfLiteFloat32) + * - list of segmented masks. + * - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. + * - if `output_type` is CONFIDENCE_MASK, float32 Image list of size + * `channels`. + * - batch is always 1 + */ +export class InteractiveSegmenter extends VisionTaskRunner { + private userCallback: SegmentationMaskCallback = () => {}; + private readonly options: ImageSegmenterGraphOptionsProto; + private readonly segmenterOptions: SegmenterOptionsProto; + + /** + * Initializes the Wasm runtime and creates a new interactive segmenter from + * the provided options. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param interactiveSegmenterOptions The options for the Interactive + * Segmenter. Note that either a path to the model asset or a model buffer + * needs to be provided (via `baseOptions`). 
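+ *
+ * For example, a minimal sketch (`wasmFileset` comes from
+ * `FilesetResolver.forVisionTasks()` and the model path is a placeholder):
+ *
+ *   const segmenter = await InteractiveSegmenter.createFromOptions(
+ *       wasmFileset, {baseOptions: {modelAssetPath: 'model.tflite'}});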
+ * @return A new `InteractiveSegmenter`. + */ + static createFromOptions( + wasmFileset: WasmFileset, + interactiveSegmenterOptions: InteractiveSegmenterOptions): + Promise<InteractiveSegmenter> { + return VisionTaskRunner.createInstance( + InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset, + interactiveSegmenterOptions); + } + + /** + * Initializes the Wasm runtime and creates a new interactive segmenter based + * on the provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + * @return A new `InteractiveSegmenter`. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise<InteractiveSegmenter> { + return VisionTaskRunner.createInstance( + InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new interactive segmenter based + * on the path to the model asset. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. + * @return A new `InteractiveSegmenter`. + */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise<InteractiveSegmenter> { + return VisionTaskRunner.createInstance( + InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_IN_STREAM, + NORM_RECT_IN_STREAM, /* roiAllowed= */ false); + this.options = new ImageSegmenterGraphOptionsProto(); + this.segmenterOptions = new SegmenterOptionsProto(); + this.options.setSegmenterOptions(this.segmenterOptions); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the interactive segmenter. + * + * Calling `setOptions()` with a subset of options only affects those + * options. You can reset an option back to its default value by + * explicitly setting it to `undefined`. + * + * @param options The options for the interactive segmenter. + * @return A Promise that resolves when the settings have been applied. + */ + override setOptions(options: InteractiveSegmenterOptions): Promise<void> { + if (options.outputType === 'CONFIDENCE_MASK') { + this.segmenterOptions.setOutputType( + SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); + } else { + this.segmenterOptions.setOutputType( + SegmenterOptionsProto.OutputType.CATEGORY_MASK); + } + + return super.applyOptions(options); + } + + /** + * Performs interactive segmentation on the provided single image and invokes + * the callback with the response. The `roi` parameter is used to represent a + * user's region of interest for segmentation. + * + * If the output_type is `CATEGORY_MASK`, the callback is invoked with vector + * of images that represent per-category segmented image mask. If the + * output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of + * images that contains only one confidence image mask. The method returns + * synchronously once the callback returns.
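+ *
+ * For example (an illustrative sketch; the keypoint is an arbitrary
+ * normalized image coordinate):
+ *
+ *   segmenter.segment(image, {keypoint: {x: 0.5, y: 0.5}},
+ *       (masks, width, height) => {
+ *         // One category mask of size `width` x `height` by default.
+ *       });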
+ * + * @param image An image to process. + * @param roi The region of interest for segmentation. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segment( + image: ImageSource, roi: RegionOfInterest, + callback: SegmentationMaskCallback): void; + /** + * Performs interactive segmentation on the provided single image and invokes + * the callback with the response. The `roi` parameter is used to represent a + * user's region of interest for segmentation. + * + * The 'image_processing_options' parameter can be used to specify the + * rotation to apply to the image before performing segmentation, by setting + * its 'rotationDegrees' field. Note that specifying a region-of-interest + * using the 'regionOfInterest' field is NOT supported and will result in an + * error. + * + * If the output_type is `CATEGORY_MASK`, the callback is invoked with vector + * of images that represent per-category segmented image mask. If the + * output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of + * images that contains only one confidence image mask. The method returns + * synchronously once the callback returns. + * + * @param image An image to process. + * @param roi The region of interest for segmentation. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segment( + image: ImageSource, roi: RegionOfInterest, + imageProcessingOptions: ImageProcessingOptions, + callback: SegmentationMaskCallback): void; + segment( + image: ImageSource, roi: RegionOfInterest, + imageProcessingOptionsOrCallback: ImageProcessingOptions| + SegmentationMaskCallback, + callback?: SegmentationMaskCallback): void { + const imageProcessingOptions = + typeof imageProcessingOptionsOrCallback !== 'function' ? + imageProcessingOptionsOrCallback : + {}; + + this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + imageProcessingOptionsOrCallback : + callback!; + + this.processRenderData(roi, this.getSynctheticTimestamp()); + this.processImageData(image, imageProcessingOptions); + this.userCallback = () => {}; + } + + /** Updates the MediaPipe graph configuration. 
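+ * Wires the `image_in`, `roi_in` and `norm_rect_in` input streams into the
+ * interactive segmenter graph node and attaches listeners that forward the
+ * grouped segmentation output to the user callback.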
*/ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_IN_STREAM); + graphConfig.addInputStream(ROI_IN_STREAM); + graphConfig.addInputStream(NORM_RECT_IN_STREAM); + graphConfig.addOutputStream(IMAGE_OUT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + ImageSegmenterGraphOptionsProto.ext, this.options); + + const segmenterNode = new CalculatorGraphConfig.Node(); + segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH); + segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM); + segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM); + segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM); + segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM); + segmenterNode.setOptions(calculatorOptions); + + graphConfig.addNode(segmenterNode); + + this.graphRunner.attachImageVectorListener( + IMAGE_OUT_STREAM, (masks, timestamp) => { + if (masks.length === 0) { + this.userCallback([], 0, 0); + } else { + this.userCallback( + masks.map(m => m.data), masks[0].width, masks[0].height); + } + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } + + /** + * Converts the user-facing RegionOfInterest message to the RenderData proto + * and sends it to the graph + */ + private processRenderData(roi: RegionOfInterest, timestamp: number): void { + const renderData = new RenderDataProto(); + + const renderAnnotation = new RenderAnnotationProto(); + + const color = new ColorProto(); + color.setR(255); + renderAnnotation.setColor(color); + + const point = new RenderAnnotationProto.Point(); + point.setNormalized(true); + point.setX(roi.keypoint.x); + point.setY(roi.keypoint.y); + renderAnnotation.setPoint(point); + + renderData.addRenderAnnotations(renderAnnotation); + + this.graphRunner.addProtoToStream( + renderData.serializeBinary(), 'mediapipe.RenderData', ROI_IN_STREAM, + timestamp); + } +} + + diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts new file mode 100644 index 000000000..beb43cd81 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Interactive Segmenter Task */ +export interface InteractiveSegmenterOptions extends TaskRunnerOptions { + /** + * The output type of segmentation results. 
+ * + * The two supported modes are: + * - Category Mask: Gives a single output mask where each pixel represents + * the class which the pixel in the original image was + * predicted to belong to. + * - Confidence Mask: Gives a list of output masks (one for each class). For + * each mask, the pixel represents the prediction + * confidence, usually in the [0.0, 0.1] range. + * + * Defaults to `CATEGORY_MASK`. + */ + outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; +} diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts new file mode 100644 index 000000000..4be9f7d37 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -0,0 +1,214 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; +import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; + +import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter'; + + +const ROI: RegionOfInterest = { + keypoint: {x: 0.1, y: 0.2} +}; + +class InteractiveSegmenterFake extends InteractiveSegmenter implements + MediapipeTasksFake { + calculatorName = + 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + imageVectorListener: + ((images: WasmImage[], timestamp: number) => void)|undefined; + lastRoi?: RenderDataProto; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachImageVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('image_out'); + this.imageVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + + spyOn(this.graphRunner, 'addProtoToStream') + .and.callFake((data, protoName, stream) => { + if (stream === 'roi_in') { + expect(protoName).toEqual('mediapipe.RenderData'); + this.lastRoi = RenderDataProto.deserializeBinary(data); + } + }); + } +} + +describe('InteractiveSegmenter', () => { + let interactiveSegmenter: InteractiveSegmenterFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + 
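+    // The fake uses a spy Wasm module and stubs the graph runner, so the
+    // empty model asset buffer below is enough for these tests.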
interactiveSegmenter = new InteractiveSegmenterFake(); + await interactiveSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(interactiveSegmenter); + verifyListenersRegistered(interactiveSegmenter); + }); + + it('reloads graph when settings are changed', async () => { + await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyListenersRegistered(interactiveSegmenter); + + await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]); + verifyListenersRegistered(interactiveSegmenter); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await interactiveSegmenter.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + interactiveSegmenter, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + + describe('setOptions()', () => { + const fieldPath = ['segmenterOptions', 'outputType']; + + it(`can set outputType`, async () => { + await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + verifyGraph(interactiveSegmenter, [fieldPath, 2]); + }); + + it(`can clear outputType`, async () => { + await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + verifyGraph(interactiveSegmenter, [fieldPath, 2]); + await interactiveSegmenter.setOptions({outputType: undefined}); + verifyGraph(interactiveSegmenter, [fieldPath, 1]); + }); + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + interactiveSegmenter.segment( + {} as HTMLImageElement, ROI, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('sends region-of-interest', (done) => { + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(interactiveSegmenter.lastRoi).toBeDefined(); + expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0]) + .toEqual(jasmine.objectContaining({ + color: {r: 255, b: undefined, g: undefined}, + })); + done(); + }); + + interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); + }); + + it('supports category masks', (done) => { + const mask = new Uint8Array([1, 2, 3, 4]); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(interactiveSegmenter); + interactiveSegmenter.imageVectorListener!( + [ + {data: mask, width: 2, height: 2}, + ], + /* timestamp= */ 1337); + }); + + // Invoke the image segmenter + interactiveSegmenter.segment( + {} as HTMLImageElement, ROI, (masks, width, height) => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(masks).toHaveSize(1); + expect(masks[0]).toEqual(mask); + expect(width).toEqual(2); + expect(height).toEqual(2); + done(); + }); + }); + + it('supports confidence masks', async () => { + const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); + const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); + + await interactiveSegmenter.setOptions({outputType: 
'CONFIDENCE_MASK'}); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(interactiveSegmenter); + interactiveSegmenter.imageVectorListener!( + [ + {data: mask1, width: 2, height: 2}, + {data: mask2, width: 2, height: 2}, + ], + 1337); + }); + + return new Promise(resolve => { + // Invoke the image segmenter + interactiveSegmenter.segment( + {} as HTMLImageElement, ROI, (masks, width, height) => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(masks).toHaveSize(2); + expect(masks[0]).toEqual(mask1); + expect(masks[1]).toEqual(mask2); + expect(width).toEqual(2); + expect(height).toEqual(2); + resolve(); + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index b9d951f60..fa6939460 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -20,4 +20,5 @@ export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; export * from '../../../tasks/web/vision/image_classifier/image_classifier'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; export * from '../../../tasks/web/vision/image_segmenter/image_segmenter'; +export * from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter'; export * from '../../../tasks/web/vision/object_detector/object_detector';
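Putting the new pieces together, a typical end-to-end use of the API added here looks roughly like the sketch below (the CDN URL matches the README snippet, `model.tflite` is a placeholder, and top-level `await` assumes an ES module context):

```
import {FilesetResolver, InteractiveSegmenter} from "@mediapipe/tasks-vision";

// Load the Wasm fileset and create a segmenter that emits confidence masks.
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const segmenter = await InteractiveSegmenter.createFromOptions(vision, {
  baseOptions: {modelAssetPath: "model.tflite"},
  outputType: "CONFIDENCE_MASK"
});

// Segment around a user-selected point, given in normalized image coordinates.
const image = document.getElementById("image") as HTMLImageElement;
segmenter.segment(image, {keypoint: {x: 0.66, y: 0.66}},
    (masks, width, height) => {
      // With CONFIDENCE_MASK output there is one mask per class; the data is
      // only valid for the duration of the callback.
      console.log(`Got ${masks.length} mask(s) of size ${width}x${height}`);
    });
```

The keypoint hint is converted into a `mediapipe.RenderData` proto and sent on the `roi_in` stream by `processRenderData()`, so the graph receives both the region of interest and the image before the synchronous `segment()` call returns.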