From 4d38557f116853ce8e90457d61c56b795a6ba86b Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 26 Jan 2023 12:30:05 -0800 Subject: [PATCH] Add MediaPipe Image Segmenter task for Web PiperOrigin-RevId: 504912518 --- mediapipe/tasks/web/vision/BUILD | 1 + mediapipe/tasks/web/vision/README.md | 17 + .../tasks/web/vision/image_segmenter/BUILD | 58 ++++ .../vision/image_segmenter/image_segmenter.ts | 300 ++++++++++++++++++ .../image_segmenter_options.d.ts | 41 +++ .../image_segmenter/image_segmenter_test.ts | 215 +++++++++++++ mediapipe/tasks/web/vision/index.ts | 3 + mediapipe/tasks/web/vision/types.ts | 1 + 8 files changed, 636 insertions(+) create mode 100644 mediapipe/tasks/web/vision/image_segmenter/BUILD create mode 100644 mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts create mode 100644 mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts create mode 100644 mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 8ba9c85b3..a229cbd2a 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -23,6 +23,7 @@ VISION_LIBS = [ "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/image_segmenter", "//mediapipe/tasks/web/vision/object_detector", ] diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index 51f43821c..9e86eafd3 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -39,6 +39,23 @@ const classifications = imageClassifier.classify(image); For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation. +## Image Segmentation + +The MediaPipe Image Segmenter lets you segment an image into categories. 
+ +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const imageSegmenter = await ImageSegmenter.createFromModelPath(vision, + "model.tflite" +); +const image = document.getElementById("image") as HTMLImageElement; +imageSegmenter.segment(image, (masks, width, height) => { + ... +}); +``` + ## Gesture Recognition The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD new file mode 100644 index 000000000..d15fe63f1 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -0,0 +1,58 @@ +# This contains the MediaPipe Image Segmenter Task. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "image_segmenter", + srcs = ["image_segmenter.ts"], + deps = [ + ":image_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "image_segmenter_types", + srcs = ["image_segmenter_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + 
"//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "image_segmenter_test_lib", + testonly = True, + srcs = [ + "image_segmenter_test.ts", + ], + deps = [ + ":image_segmenter", + ":image_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + ], +) + +jasmine_node_test( + name = "image_segmenter_test", + tags = ["nomsan"], + deps = [":image_segmenter_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts new file mode 100644 index 000000000..4f81977eb --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -0,0 +1,300 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; +import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {ImageSegmenterOptions} from './image_segmenter_options'; + +export * from './image_segmenter_options'; +export {ImageSource}; // Used in the public API + +/** + * The ImageSegmenter returns the segmentation result as a Uint8Array (when + * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for + * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved + * for future usage. + */ +export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture; + +/** + * A callback that receives the computed masks from the image segmenter. The + * callback either receives a single element array with a category mask (as a + * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`). + * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. 
+ */
+export type SegmentationMaskCallback =
+    (masks: SegmentationMask[], width: number, height: number) => void;
+
+const IMAGE_STREAM = 'image_in';
+const NORM_RECT_STREAM = 'norm_rect';
+const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
+const IMAGE_SEGMENTER_GRAPH =
+    'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
+
+// The OSS JS API does not support the builder pattern.
+// tslint:disable:jspb-use-builder-pattern
+
+/** Performs image segmentation on images. */
+export class ImageSegmenter extends VisionTaskRunner {
+  // Callback of the `segment()`/`segmentForVideo()` call that is currently
+  // in flight. It is a no-op outside of such a call.
+  private userCallback: SegmentationMaskCallback = () => {};
+  private readonly options: ImageSegmenterGraphOptionsProto;
+  private readonly segmenterOptions: SegmenterOptionsProto;
+
+  /**
+   * Initializes the Wasm runtime and creates a new image segmenter from the
+   * provided options.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param imageSegmenterOptions The options for the Image Segmenter. Note
+   *     that either a path to the model asset or a model buffer needs to be
+   *     provided (via `baseOptions`).
+   */
+  static createFromOptions(
+      wasmFileset: WasmFileset,
+      imageSegmenterOptions: ImageSegmenterOptions): Promise<ImageSegmenter> {
+    return VisionTaskRunner.createInstance(
+        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
+        imageSegmenterOptions);
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new image segmenter based on
+   * the provided model asset buffer.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the model.
+   */
+  static createFromModelBuffer(
+      wasmFileset: WasmFileset,
+      modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
+    return VisionTaskRunner.createInstance(
+        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new image segmenter based on
+   * the path to the model asset.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param modelAssetPath The path to the model asset.
+   */
+  static createFromModelPath(
+      wasmFileset: WasmFileset,
+      modelAssetPath: string): Promise<ImageSegmenter> {
+    return VisionTaskRunner.createInstance(
+        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  /** @hideconstructor */
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
+        NORM_RECT_STREAM, /* roiAllowed= */ false);
+    this.options = new ImageSegmenterGraphOptionsProto();
+    this.segmenterOptions = new SegmenterOptionsProto();
+    // Start from the `CATEGORY_MASK` default so that the graph always carries
+    // an explicit output type, even if `setOptions()` is never called with
+    // `outputType`.
+    this.segmenterOptions.setOutputType(
+        SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    this.options.setSegmenterOptions(this.segmenterOptions);
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the image segmenter.
+   *
+   * Calling `setOptions()` with a subset of options only affects those
+   * options. You can reset an option back to its default value by
+   * explicitly setting it to `undefined`.
+   *
+   * @param options The options for the image segmenter.
+   */
+  override setOptions(options: ImageSegmenterOptions): Promise<void> {
+    // Note that we have to support both JSPB and ProtobufJS, hence we
+    // have to explicitly clear the values instead of setting them to
+    // `undefined`.
+    if (options.displayNamesLocale !== undefined) {
+      this.options.setDisplayNamesLocale(options.displayNamesLocale);
+    } else if ('displayNamesLocale' in options) {  // Check for undefined
+      this.options.clearDisplayNamesLocale();
+    }
+
+    // Only touch the output type when the caller passed the key, so that a
+    // partial update such as `{displayNamesLocale: 'en'}` keeps a previously
+    // selected `CONFIDENCE_MASK`. An explicit `undefined` (or any other
+    // value) resets the option to its `CATEGORY_MASK` default.
+    if (options.outputType === 'CONFIDENCE_MASK') {
+      this.segmenterOptions.setOutputType(
+          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
+    } else if ('outputType' in options) {
+      this.segmenterOptions.setOutputType(
+          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    }
+
+    return super.applyOptions(options);
+  }
+
+  /**
+   * Performs image segmentation on the provided single image and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `image`.
+   *
+   * @param image An image to process.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segment(image: ImageSource, callback: SegmentationMaskCallback): void;
+  /**
+   * Performs image segmentation on the provided single image and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `image`.
+   *
+   * @param image An image to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segment(
+      image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
+      callback: SegmentationMaskCallback): void;
+  segment(
+      image: ImageSource,
+      imageProcessingOptionsOrCallback: ImageProcessingOptions|
+      SegmentationMaskCallback,
+      callback?: SegmentationMaskCallback): void {
+    const imageProcessingOptions =
+        typeof imageProcessingOptionsOrCallback !== 'function' ?
+        imageProcessingOptionsOrCallback :
+        {};
+
+    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
+        imageProcessingOptionsOrCallback :
+        callback!;
+    try {
+      this.processImageData(image, imageProcessingOptions);
+    } finally {
+      // Release the user callback even when processing throws, so that a
+      // failed call does not leave a stale callback registered.
+      this.userCallback = () => {};
+    }
+  }
+
+  /**
+   * Performs image segmentation on the provided video frame and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segmentForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      callback: SegmentationMaskCallback): void;
+  /**
+   * Performs image segmentation on the provided video frame and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segmentForVideo(
+      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
+      timestamp: number, callback: SegmentationMaskCallback): void;
+  segmentForVideo(
+      videoFrame: ImageSource,
+      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
+      timestampOrCallback: number|SegmentationMaskCallback,
+      callback?: SegmentationMaskCallback): void {
+    const imageProcessingOptions =
+        typeof timestampOrImageProcessingOptions !== 'number' ?
+        timestampOrImageProcessingOptions :
+        {};
+    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
+        timestampOrImageProcessingOptions :
+        timestampOrCallback as number;
+
+    this.userCallback = typeof timestampOrCallback === 'function' ?
+        timestampOrCallback :
+        callback!;
+    try {
+      this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    } finally {
+      // Release the user callback even when processing throws, so that a
+      // failed call does not leave a stale callback registered.
+      this.userCallback = () => {};
+    }
+  }
+
+  /** Updates the MediaPipe graph configuration. */
+  protected override refreshGraph(): void {
+    const graphConfig = new CalculatorGraphConfig();
+    graphConfig.addInputStream(IMAGE_STREAM);
+    graphConfig.addInputStream(NORM_RECT_STREAM);
+    graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
+
+    const calculatorOptions = new CalculatorOptions();
+    calculatorOptions.setExtension(
+        ImageSegmenterGraphOptionsProto.ext, this.options);
+
+    const segmenterNode = new CalculatorGraphConfig.Node();
+    segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
+    segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
+    segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
+    segmenterNode.addOutputStream(
+        'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
+    segmenterNode.setOptions(calculatorOptions);
+
+    graphConfig.addNode(segmenterNode);
+
+    // Forward every batch of masks to the callback of the call in flight.
+    // Width and height are taken from the first mask.
+    this.graphRunner.attachImageVectorListener(
+        GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
+          if (masks.length === 0) {
+            this.userCallback([], 0, 0);
+          } else {
+            this.userCallback(
+                masks.map(m => m.data), masks[0].width, masks[0].height);
+          }
+          this.setLatestOutputTimestamp(timestamp);
+        });
+
+    const binaryGraph = graphConfig.serializeBinary();
+    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
+  }
+}
+
+
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
new file mode 100644
index 000000000..c17e7e421
--- /dev/null
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Image Segmenter Task */ +export interface ImageSegmenterOptions extends VisionTaskOptions { + /** + * The locale to use for display names specified through the TFLite Model + * Metadata, if any. Defaults to English. + */ + displayNamesLocale?: string|undefined; + + /** + * The output type of segmentation results. + * + * The two supported modes are: + * - Category Mask: Gives a single output mask where each pixel represents + * the class which the pixel in the original image was + * predicted to belong to. + * - Confidence Mask: Gives a list of output masks (one for each class). For + * each mask, the pixel represents the prediction + * confidence, usually in the [0.0, 1.0] range. + * + * Defaults to `CATEGORY_MASK`. + */ + outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; +} diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts new file mode 100644 index 000000000..aa81be025 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -0,0 +1,215 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; + +import {ImageSegmenter} from './image_segmenter'; +import {ImageSegmenterOptions} from './image_segmenter_options'; + +class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + imageVectorListener: + ((images: WasmImage[], timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachImageVectorListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('segmented_masks'); + this.imageVectorListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('ImageSegmenter', () => { + let 
imageSegmenter: ImageSegmenterFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + imageSegmenter = new ImageSegmenterFake(); + await imageSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(imageSegmenter); + verifyListenersRegistered(imageSegmenter); + }); + + it('reloads graph when settings are changed', async () => { + await imageSegmenter.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); + verifyListenersRegistered(imageSegmenter); + + await imageSegmenter.setOptions({displayNamesLocale: 'de'}); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']); + verifyListenersRegistered(imageSegmenter); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await imageSegmenter.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + imageSegmenter, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('merges options', async () => { + await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + await imageSegmenter.setOptions({displayNamesLocale: 'en'}); + verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); + }); + + describe('setOptions()', () => { + interface TestCase { + optionName: keyof ImageSegmenterOptions; + fieldPath: string[]; + userValue: unknown; + graphValue: unknown; + defaultValue: unknown; + } + + const testCases: TestCase[] = [ + { + optionName: 'displayNamesLocale', + fieldPath: ['displayNamesLocale'], + userValue: 'en', + graphValue: 'en', + defaultValue: 'en' + }, + { + 
optionName: 'outputType', + fieldPath: ['segmenterOptions', 'outputType'], + userValue: 'CONFIDENCE_MASK', + graphValue: 2, + defaultValue: 1 + }, + ]; + + for (const testCase of testCases) { + it(`can set ${testCase.optionName}`, async () => { + await imageSegmenter.setOptions( + {[testCase.optionName]: testCase.userValue}); + verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]); + }); + + it(`can clear ${testCase.optionName}`, async () => { + await imageSegmenter.setOptions( + {[testCase.optionName]: testCase.userValue}); + verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]); + await imageSegmenter.setOptions({[testCase.optionName]: undefined}); + verifyGraph( + imageSegmenter, [testCase.fieldPath, testCase.defaultValue]); + }); + } + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + imageSegmenter.segment( + {} as HTMLImageElement, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('supports category masks', (done) => { + const mask = new Uint8Array([1, 2, 3, 4]); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(imageSegmenter); + imageSegmenter.imageVectorListener!( + [ + {data: mask, width: 2, height: 2}, + ], + /* timestamp= */ 1337); + }); + + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(masks).toHaveSize(1); + expect(masks[0]).toEqual(mask); + expect(width).toEqual(2); + expect(height).toEqual(2); + done(); + }); + }); + + it('supports confidence masks', async () => { + const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); + const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); + + await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + + // Pass the 
test data to our listener
+    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      verifyListenersRegistered(imageSegmenter);
+      imageSegmenter.imageVectorListener!(
+          [
+            {data: mask1, width: 2, height: 2},
+            {data: mask2, width: 2, height: 2},
+          ],
+          1337);
+    });
+
+    // `Promise<void>` is required so that `resolve()` may be called without
+    // an argument.
+    return new Promise<void>(resolve => {
+      // Invoke the image segmenter
+      imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
+        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+        expect(masks).toHaveSize(2);
+        expect(masks[0]).toEqual(mask1);
+        expect(masks[1]).toEqual(mask2);
+        expect(width).toEqual(2);
+        expect(height).toEqual(2);
+        resolve();
+      });
+    });
+  });
+});
diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts
index 49f23c243..5a87c7a82 100644
--- a/mediapipe/tasks/web/vision/index.ts
+++ b/mediapipe/tasks/web/vision/index.ts
@@ -19,6 +19,7 @@ import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vis
 import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
 import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
 import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
+import {ImageSegmenter as ImageSegmenterImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
 import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
 
 // Declare the variables locally so that Rollup in OSS includes them explicitly
@@ -28,6 +29,7 @@ const GestureRecognizer = GestureRecognizerImpl;
 const HandLandmarker = HandLandmarkerImpl;
 const ImageClassifier = ImageClassifierImpl;
 const ImageEmbedder = ImageEmbedderImpl;
+const ImageSegmenter = ImageSegmenterImpl;
 const ObjectDetector = ObjectDetectorImpl;
 
 export {
@@ -36,5 +38,6 @@ export {
   HandLandmarker,
   ImageClassifier,
   ImageEmbedder,
+  ImageSegmenter,
   ObjectDetector
 };
diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts
index dd1f58294..b9d951f60 100644
--- a/mediapipe/tasks/web/vision/types.ts
+++ b/mediapipe/tasks/web/vision/types.ts
@@ -19,4 +19,5 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
 export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
 export * from '../../../tasks/web/vision/image_classifier/image_classifier';
 export * from '../../../tasks/web/vision/image_embedder/image_embedder';
+export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
 export * from '../../../tasks/web/vision/object_detector/object_detector';