Update InteractiveSegmenter to return MPImage

PiperOrigin-RevId: 528010944
This commit is contained in:
Sebastian Schmidt 2023-04-28 16:11:34 -07:00 committed by Copybara-Service
parent bbbc0f98c5
commit dcef6df1cb
6 changed files with 40 additions and 81 deletions

View File

@ -35,7 +35,6 @@ const COLOR_MAP: Array<[number, number, number, number]> = [
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead? [255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
]; ];
/** Helper function to draw a confidence mask */ /** Helper function to draw a confidence mask */
export function drawConfidenceMask( export function drawConfidenceMask(
ctx: CanvasRenderingContext2D, image: Float32Array, width: number, ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
@ -50,33 +49,6 @@ export function drawConfidenceMask(
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0); ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
} }
/**
* Helper function to draw a category mask. For GPU, we only have F32Arrays
* for now.
*/
export function drawCategoryMask(
ctx: CanvasRenderingContext2D, image: Uint8ClampedArray|Float32Array,
width: number, height: number): void {
const rgbaArray = new Uint8ClampedArray(width * height * 4);
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
if (!color) {
// TODO: We should fix this.
console.warn('No color for ', colorIndex);
color = COLOR_MAP[colorIndex % COLOR_MAP.length];
}
rgbaArray[4 * i] = color[0];
rgbaArray[4 * i + 1] = color[1];
rgbaArray[4 * i + 2] = color[2];
rgbaArray[4 * i + 3] = color[3];
}
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
}
/** The color converter we use in our demos. */ /** The color converter we use in our demos. */
export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = { export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = {
floatToRGBAConverter: v => [128, 0, 0, v * 255], floatToRGBAConverter: v => [128, 0, 0, v * 255],

View File

@ -16,16 +16,6 @@
import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint'; import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';
/**
* The segmentation tasks return the segmentation either as a WebGLTexture (when
* the output is on GPU) or as a typed JavaScript arrays for CPU-based
* category or confidence masks. `Uint8ClampedArray`s are used to represent
* CPU-based category masks and `Float32Array`s are used for CPU-based
* confidence masks.
*/
export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;
/** A Region-Of-Interest (ROI) to represent a region within an image. */ /** A Region-Of-Interest (ROI) to represent a region within an image. */
export declare interface RegionOfInterest { export declare interface RegionOfInterest {
/** The ROI in keypoint format. */ /** The ROI in keypoint format. */

View File

@ -37,6 +37,7 @@ mediapipe_ts_declaration(
deps = [ deps = [
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_options",
], ],
) )
@ -53,6 +54,7 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/util:render_data_jspb_proto", "//mediapipe/util:render_data_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
], ],

View File

@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types'; import {RegionOfInterest} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {Color as ColorProto} from '../../../../util/color_pb'; import {Color as ColorProto} from '../../../../util/color_pb';
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
@ -33,7 +33,7 @@ import {InteractiveSegmenterResult} from './interactive_segmenter_result';
export * from './interactive_segmenter_options'; export * from './interactive_segmenter_options';
export * from './interactive_segmenter_result'; export * from './interactive_segmenter_result';
export {SegmentationMask, RegionOfInterest}; export {RegionOfInterest};
export {ImageSource}; export {ImageSource};
const IMAGE_IN_STREAM = 'image_in'; const IMAGE_IN_STREAM = 'image_in';
@ -83,7 +83,7 @@ export type InteractiveSegmenterCallback =
* - batch is always 1 * - batch is always 1
*/ */
export class InteractiveSegmenter extends VisionTaskRunner { export class InteractiveSegmenter extends VisionTaskRunner {
private result: InteractiveSegmenterResult = {width: 0, height: 0}; private result: InteractiveSegmenterResult = {};
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private readonly options: ImageSegmenterGraphOptionsProto; private readonly options: ImageSegmenterGraphOptionsProto;
@ -253,7 +253,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
} }
private reset(): void { private reset(): void {
this.result = {width: 0, height: 0}; this.result = {};
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
@ -283,12 +283,8 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map(m => m.data); this.result.confidenceMasks =
if (masks.length >= 0) { masks.map(wasmImage => this.convertToMPImage(wasmImage));
this.result.width = masks[0].width;
this.result.height = masks[0].height;
}
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
@ -303,9 +299,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = mask.data; this.result.categoryMask = this.convertToMPImage(mask);
this.result.width = mask.width;
this.result.height = mask.height;
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(

View File

@ -14,24 +14,21 @@
* limitations under the License. * limitations under the License.
*/ */
import {MPImage} from '../../../../tasks/web/vision/core/image';
/** The output result of InteractiveSegmenter. */ /** The output result of InteractiveSegmenter. */
export declare interface InteractiveSegmenterResult { export declare interface InteractiveSegmenterResult {
/** /**
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
* pixel represents the prediction confidence, usually in the [0, 1] range. * `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range.
*/ */
confidenceMasks?: Float32Array[]|WebGLTexture[]; confidenceMasks?: MPImage[];
/** /**
* A category mask as a Uint8ClampedArray or WebGLTexture where each * A category mask represented as a `Uint8ClampedArray` or
* pixel represents the class which the pixel in the original image was * `WebGLTexture`-backed `MPImage` where each pixel represents the class which
* predicted to belong to. * the pixel in the original image was predicted to belong to.
*/ */
categoryMask?: Uint8ClampedArray|WebGLTexture; categoryMask?: MPImage;
/** The width of the masks. */
width: number;
/** The height of the masks. */
height: number;
} }

View File

@ -19,6 +19,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
@ -170,10 +171,10 @@ describe('InteractiveSegmenter', () => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).toEqual(mask); expect(result.categoryMask).toBeInstanceOf(MPImage);
expect(result.categoryMask!.width).toEqual(2);
expect(result.categoryMask!.height).toEqual(2);
expect(result.confidenceMasks).not.toBeDefined(); expect(result.confidenceMasks).not.toBeDefined();
expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve(); resolve();
}); });
}); });
@ -202,18 +203,21 @@ describe('InteractiveSegmenter', () => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined(); expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks).toEqual([mask1, mask2]);
expect(result.width).toEqual(2); expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.height).toEqual(2); expect(result.confidenceMasks![0].width).toEqual(2);
expect(result.confidenceMasks![0].height).toEqual(2);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
resolve(); resolve();
}); });
}); });
}); });
it('supports combined category and confidence masks', async () => { it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1, 0]); const categoryMask = new Uint8ClampedArray([1]);
const confidenceMask1 = new Float32Array([0.0, 1.0]); const confidenceMask1 = new Float32Array([0.0]);
const confidenceMask2 = new Float32Array([1.0, 0.0]); const confidenceMask2 = new Float32Array([1.0]);
await interactiveSegmenter.setOptions( await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: true}); {outputCategoryMask: true, outputConfidenceMasks: true});
@ -238,12 +242,12 @@ describe('InteractiveSegmenter', () => {
{} as HTMLImageElement, ROI, result => { {} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).toEqual(categoryMask); expect(result.categoryMask).toBeInstanceOf(MPImage);
expect(result.confidenceMasks).toEqual([ expect(result.categoryMask!.width).toEqual(1);
confidenceMask1, confidenceMask2 expect(result.categoryMask!.height).toEqual(1);
]);
expect(result.width).toEqual(1); expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.height).toEqual(1); expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
resolve(); resolve();
}); });
}); });