From dcef6df1cbf5be95fc7f74b714d328c3b73aa7a9 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 28 Apr 2023 16:11:34 -0700 Subject: [PATCH] Update InteractiveSegmenter to return MPImage PiperOrigin-RevId: 528010944 --- .../tasks/web/vision/core/render_utils.ts | 32 ++--------------- mediapipe/tasks/web/vision/core/types.d.ts | 10 ------ .../web/vision/interactive_segmenter/BUILD | 2 ++ .../interactive_segmenter.ts | 20 ++++------- .../interactive_segmenter_result.d.ts | 23 ++++++------- .../interactive_segmenter_test.ts | 34 +++++++++++-------- 6 files changed, 40 insertions(+), 81 deletions(-) diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts index 892cd8645..066494f57 100644 --- a/mediapipe/tasks/web/vision/core/render_utils.ts +++ b/mediapipe/tasks/web/vision/core/render_utils.ts @@ -35,11 +35,10 @@ const COLOR_MAP: Array<[number, number, number, number]> = [ [255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead? ]; - /** Helper function to draw a confidence mask */ export function drawConfidenceMask( - ctx: CanvasRenderingContext2D, image: Float32Array, width: number, - height: number): void { + ctx: CanvasRenderingContext2D, image: Float32Array, width: number, + height: number): void { const uint8ClampedArray = new Uint8ClampedArray(width * height * 4); for (let i = 0; i < image.length; i++) { uint8ClampedArray[4 * i] = 128; @@ -50,33 +49,6 @@ export function drawConfidenceMask( ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0); } -/** - * Helper function to draw a category mask. For GPU, we only have F32Arrays - * for now. - */ -export function drawCategoryMask( - ctx: CanvasRenderingContext2D, image: Uint8ClampedArray|Float32Array, - width: number, height: number): void { - const rgbaArray = new Uint8ClampedArray(width * height * 4); - const isFloatArray = image instanceof Float32Array; - for (let i = 0; i < image.length; i++) { - const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; - let color = COLOR_MAP[colorIndex % COLOR_MAP.length]; - - if (!color) { - // TODO: We should fix this. - console.warn('No color for ', colorIndex); - color = COLOR_MAP[colorIndex % COLOR_MAP.length]; - } - - rgbaArray[4 * i] = color[0]; - rgbaArray[4 * i + 1] = color[1]; - rgbaArray[4 * i + 2] = color[2]; - rgbaArray[4 * i + 3] = color[3]; - } - ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0); -} - /** The color converter we use in our demos. */ export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = { floatToRGBAConverter: v => [128, 0, 0, v * 255], diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index 1cc2e36fd..c985a9f36 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -16,16 +16,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint'; -/** - * The segmentation tasks return the segmentation either as a WebGLTexture (when - * the output is on GPU) or as a typed JavaScript arrays for CPU-based - * category or confidence masks. `Uint8ClampedArray`s are used to represent - * CPU-based category masks and `Float32Array`s are used for CPU-based - * confidence masks. - */ -export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture; - - /** A Region-Of-Interest (ROI) to represent a region within an image. */ export declare interface RegionOfInterest { /** The ROI in keypoint format. */ diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD index ead85d38a..c3be79ebf 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD @@ -37,6 +37,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) @@ -53,6 +54,7 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/tasks/web/vision/core:image", "//mediapipe/util:render_data_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", ], diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index 70d2c1f4e..f127792ce 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; -import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types'; +import {RegionOfInterest} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {Color as ColorProto} from '../../../../util/color_pb'; import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; @@ -33,7 +33,7 @@ import {InteractiveSegmenterResult} from './interactive_segmenter_result'; export * from './interactive_segmenter_options'; export * from './interactive_segmenter_result'; -export {SegmentationMask, RegionOfInterest}; +export {RegionOfInterest}; export {ImageSource}; const IMAGE_IN_STREAM = 'image_in'; @@ -83,7 +83,7 @@ export type InteractiveSegmenterCallback = * - batch is always 1 */ export class InteractiveSegmenter extends VisionTaskRunner { - private result: InteractiveSegmenterResult = {width: 0, height: 0}; + private result: InteractiveSegmenterResult = {}; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; @@ -253,7 +253,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { } private reset(): void { - this.result = {width: 0, height: 0}; + this.result = {}; } /** Updates the MediaPipe graph configuration. */ @@ -283,12 +283,8 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { - this.result.confidenceMasks = masks.map(m => m.data); - if (masks.length >= 0) { - this.result.width = masks[0].width; - this.result.height = masks[0].height; - } - + this.result.confidenceMasks = + masks.map(wasmImage => this.convertToMPImage(wasmImage)); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( @@ -303,9 +299,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { - this.result.categoryMask = mask.data; - this.result.width = mask.width; - this.result.height = mask.height; + this.result.categoryMask = this.convertToMPImage(mask); this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts index f1f134a77..bc2962936 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts @@ -14,24 +14,21 @@ * limitations under the License. */ +import {MPImage} from '../../../../tasks/web/vision/core/image'; + /** The output result of InteractiveSegmenter. */ export declare interface InteractiveSegmenterResult { /** - * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each - * pixel represents the prediction confidence, usually in the [0, 1] range. + * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed + * `MPImage`s where, for each mask, each pixel represents the prediction + * confidence, usually in the [0, 1] range. */ - confidenceMasks?: Float32Array[]|WebGLTexture[]; + confidenceMasks?: MPImage[]; /** - * A category mask as a Uint8ClampedArray or WebGLTexture where each - * pixel represents the class which the pixel in the original image was - * predicted to belong to. + * A category mask represented as a `Uint8ClampedArray` or + * `WebGLTexture`-backed `MPImage` where each pixel represents the class which + * the pixel in the original image was predicted to belong to. */ - categoryMask?: Uint8ClampedArray|WebGLTexture; - - /** The width of the masks. */ - width: number; - - /** The height of the masks. */ - height: number; + categoryMask?: MPImage; } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index 0a9477605..cbaf76630 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -19,6 +19,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; +import {MPImage} from '../../../../tasks/web/vision/core/image'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; @@ -170,10 +171,10 @@ describe('InteractiveSegmenter', () => { interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); - expect(result.categoryMask).toEqual(mask); + expect(result.categoryMask).toBeInstanceOf(MPImage); + expect(result.categoryMask!.width).toEqual(2); + expect(result.categoryMask!.height).toEqual(2); expect(result.confidenceMasks).not.toBeDefined(); - expect(result.width).toEqual(2); - expect(result.height).toEqual(2); resolve(); }); }); @@ -202,18 +203,21 @@ describe('InteractiveSegmenter', () => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); expect(result.categoryMask).not.toBeDefined(); - expect(result.confidenceMasks).toEqual([mask1, mask2]); - expect(result.width).toEqual(2); - expect(result.height).toEqual(2); + + expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); + expect(result.confidenceMasks![0].width).toEqual(2); + expect(result.confidenceMasks![0].height).toEqual(2); + + expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage); resolve(); }); }); }); it('supports combined category and confidence masks', async () => { - const categoryMask = new Uint8ClampedArray([1, 0]); - const confidenceMask1 = new Float32Array([0.0, 1.0]); - const confidenceMask2 = new Float32Array([1.0, 0.0]); + const categoryMask = new Uint8ClampedArray([1]); + const confidenceMask1 = new Float32Array([0.0]); + const confidenceMask2 = new Float32Array([1.0]); await interactiveSegmenter.setOptions( {outputCategoryMask: true, outputConfidenceMasks: true}); @@ -238,12 +242,12 @@ describe('InteractiveSegmenter', () => { {} as HTMLImageElement, ROI, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); - expect(result.categoryMask).toEqual(categoryMask); - expect(result.confidenceMasks).toEqual([ - confidenceMask1, confidenceMask2 - ]); - expect(result.width).toEqual(1); - expect(result.height).toEqual(1); + expect(result.categoryMask).toBeInstanceOf(MPImage); + expect(result.categoryMask!.width).toEqual(1); + expect(result.categoryMask!.height).toEqual(1); + + expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); + expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage); resolve(); }); });