PiperOrigin-RevId: 529890599
This commit is contained in:
Sebastian Schmidt 2023-05-05 21:49:52 -07:00 committed by Copybara-Service
parent e707c84a3d
commit 6aad5742c3
17 changed files with 117 additions and 68 deletions

View File

@ -87,6 +87,7 @@ mediapipe_ts_library(
deps = [
":image",
":image_processing_options",
":mask",
":vision_task_options",
"//mediapipe/framework/formats:rect_jspb_proto",
"//mediapipe/tasks/web/core",

View File

@ -16,8 +16,6 @@
* limitations under the License.
*/
import {MPImageChannelConverter} from '../../../../tasks/web/vision/core/image';
// Pre-baked color table for a maximum of 12 classes.
const CM_ALPHA = 128;
const COLOR_MAP: Array<[number, number, number, number]> = [
@ -35,8 +33,37 @@ const COLOR_MAP: Array<[number, number, number, number]> = [
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
];
/** The color converter we use in our demos. */
export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = {
floatToRGBAConverter: v => [128, 0, 0, v * 255],
uint8ToRGBAConverter: v => COLOR_MAP[v % COLOR_MAP.length],
};
/** Helper function to draw a confidence mask */
export function drawConfidenceMask(
ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
height: number): void {
const uint8Array = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < image.length; i++) {
uint8Array[4 * i] = 128;
uint8Array[4 * i + 1] = 0;
uint8Array[4 * i + 2] = 0;
uint8Array[4 * i + 3] = image[i] * 255;
}
ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
}
/**
* Helper function to draw a category mask. For GPU, we only have F32Arrays
* for now.
*/
export function drawCategoryMask(
ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
width: number, height: number): void {
const rgbaArray = new Uint8ClampedArray(width * height * 4);
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
rgbaArray[4 * i] = color[0];
rgbaArray[4 * i + 1] = color[1];
rgbaArray[4 * i + 2] = color[2];
rgbaArray[4 * i + 3] = color[3];
}
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
}

View File

@ -20,6 +20,7 @@ import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {isWebKit} from '../../../../web/graph_runner/platform_utils';
@ -226,7 +227,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
/**
* Converts a WasmImage to an MPImage.
*
* Converts the underlying Uint8ClampedArray-backed images to ImageData
* Converts the underlying Uint8Array-backed images to ImageData
* (adding an alpha channel if necessary), passes through WebGLTextures and
* throws for Float32Array-backed images.
*/
@ -235,11 +236,9 @@ export abstract class VisionTaskRunner extends TaskRunner {
const {data, width, height} = wasmImage;
const pixels = width * height;
let container: ImageData|WebGLTexture|Uint8ClampedArray;
if (data instanceof Uint8ClampedArray) {
if (data.length === pixels) {
container = data; // Mask
} else if (data.length === pixels * 3) {
let container: ImageData|WebGLTexture;
if (data instanceof Uint8Array) {
if (data.length === pixels * 3) {
// TODO: Convert in C++
const rgba = new Uint8ClampedArray(pixels * 4);
for (let i = 0; i < pixels; ++i) {
@ -250,18 +249,16 @@ export abstract class VisionTaskRunner extends TaskRunner {
}
container = new ImageData(rgba, width, height);
} else if (data.length === pixels * 4) {
container = new ImageData(data, width, height);
container = new ImageData(
new Uint8ClampedArray(data.buffer, data.byteOffset, data.length),
width, height);
} else {
throw new Error(`Unsupported channel count: ${data.length/pixels}`);
}
} else if (data instanceof Float32Array) {
if (data.length === pixels) {
container = data; // Mask
} else {
throw new Error(`Unsupported channel count: ${data.length/pixels}`);
}
} else { // WebGLTexture
} else if (data instanceof WebGLTexture) {
container = data;
} else {
throw new Error(`Unsupported format: ${data.constructor.name}`);
}
const image = new MPImage(
@ -271,6 +268,30 @@ export abstract class VisionTaskRunner extends TaskRunner {
return shouldCopyData ? image.clone() : image;
}
/** Converts a WasmImage to an MPMask. */
protected convertToMPMask(wasmImage: WasmImage, shouldCopyData: boolean):
MPMask {
const {data, width, height} = wasmImage;
const pixels = width * height;
let container: WebGLTexture|Uint8Array|Float32Array;
if (data instanceof Uint8Array || data instanceof Float32Array) {
if (data.length === pixels) {
container = data;
} else {
throw new Error(`Unsupported channel count: ${data.length / pixels}`);
}
} else {
container = data;
}
const mask = new MPMask(
[container],
/* ownsWebGLTexture= */ false, this.graphRunner.wasmModule.canvas!,
this.shaderContext, width, height);
return shouldCopyData ? mask.clone() : mask;
}
/** Closes and cleans up the resources held by this task. */
override close(): void {
this.shaderContext.close();

View File

@ -109,7 +109,7 @@ describe('FaceStylizer', () => {
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer);
faceStylizer.imageListener!
({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1},
({data: new Uint8Array([1, 1, 1, 1]), width: 1, height: 1},
/* timestamp= */ 1337);
});
@ -134,7 +134,7 @@ describe('FaceStylizer', () => {
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer);
faceStylizer.imageListener!
({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1},
({data: new Uint8Array([1, 1, 1, 1]), width: 1, height: 1},
/* timestamp= */ 1337);
});

View File

@ -35,7 +35,7 @@ mediapipe_ts_declaration(
deps = [
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
@ -52,7 +52,7 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
],
)

View File

@ -412,7 +412,7 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map(
wasmImage => this.convertToMPImage(
wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
@ -431,7 +431,7 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage(
this.result.categoryMask = this.convertToMPMask(
mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
/** The output result of ImageSegmenter. */
export declare interface ImageSegmenterResult {
@ -23,12 +23,12 @@ export declare interface ImageSegmenterResult {
* `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range.
*/
confidenceMasks?: MPImage[];
confidenceMasks?: MPMask[];
/**
* A category mask represented as a `Uint8ClampedArray` or
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which
* the pixel in the original image was predicted to belong to.
*/
categoryMask?: MPImage;
categoryMask?: MPMask;
}

View File

@ -19,8 +19,8 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {ImageSegmenter} from './image_segmenter';
import {ImageSegmenterOptions} from './image_segmenter_options';
@ -165,7 +165,7 @@ describe('ImageSegmenter', () => {
});
it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
const mask = new Uint8Array([1, 2, 3, 4]);
await imageSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false});
@ -183,7 +183,7 @@ describe('ImageSegmenter', () => {
return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage);
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks).not.toBeDefined();
expect(result.categoryMask!.width).toEqual(2);
expect(result.categoryMask!.height).toEqual(2);
@ -216,18 +216,18 @@ describe('ImageSegmenter', () => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0].width).toEqual(2);
expect(result.confidenceMasks![0].height).toEqual(2);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve();
});
});
});
it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1]);
const categoryMask = new Uint8Array([1]);
const confidenceMask1 = new Float32Array([0.0]);
const confidenceMask2 = new Float32Array([1.0]);
@ -252,19 +252,19 @@ describe('ImageSegmenter', () => {
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage);
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.categoryMask!.width).toEqual(1);
expect(result.categoryMask!.height).toEqual(1);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve();
});
});
});
it('invokes listener once masks are available', async () => {
const categoryMask = new Uint8ClampedArray([1]);
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false;
@ -306,7 +306,7 @@ describe('ImageSegmenter', () => {
});
const result = imageSegmenter.segment({} as HTMLImageElement);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
result.confidenceMasks![0].close();
});
});

View File

@ -37,7 +37,7 @@ mediapipe_ts_declaration(
deps = [
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
@ -54,7 +54,7 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/util:render_data_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
],

View File

@ -328,7 +328,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map(
wasmImage => this.convertToMPImage(
wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
@ -347,7 +347,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage(
this.result.categoryMask = this.convertToMPMask(
mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
/** The output result of InteractiveSegmenter. */
export declare interface InteractiveSegmenterResult {
@ -23,12 +23,12 @@ export declare interface InteractiveSegmenterResult {
* `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range.
*/
confidenceMasks?: MPImage[];
confidenceMasks?: MPMask[];
/**
* A category mask represented as a `Uint8ClampedArray` or
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which
* the pixel in the original image was predicted to belong to.
*/
categoryMask?: MPImage;
categoryMask?: MPMask;
}

View File

@ -19,7 +19,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
@ -177,7 +177,7 @@ describe('InteractiveSegmenter', () => {
});
it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
const mask = new Uint8Array([1, 2, 3, 4]);
await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false});
@ -195,7 +195,7 @@ describe('InteractiveSegmenter', () => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage);
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.categoryMask!.width).toEqual(2);
expect(result.categoryMask!.height).toEqual(2);
expect(result.confidenceMasks).not.toBeDefined();
@ -228,18 +228,18 @@ describe('InteractiveSegmenter', () => {
.toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0].width).toEqual(2);
expect(result.confidenceMasks![0].height).toEqual(2);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve();
});
});
});
it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1]);
const categoryMask = new Uint8Array([1]);
const confidenceMask1 = new Float32Array([0.0]);
const confidenceMask2 = new Float32Array([1.0]);
@ -266,19 +266,19 @@ describe('InteractiveSegmenter', () => {
{} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage);
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.categoryMask!.width).toEqual(1);
expect(result.categoryMask!.height).toEqual(1);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve();
});
});
});
it('invokes listener once masks are avaiblae', async () => {
const categoryMask = new Uint8ClampedArray([1]);
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false;
@ -321,7 +321,7 @@ describe('InteractiveSegmenter', () => {
const result =
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
result.confidenceMasks![0].close();
});
});

View File

@ -45,7 +45,7 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
@ -63,7 +63,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
],
)

View File

@ -504,7 +504,7 @@ export class PoseLandmarker extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener(
SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
this.result.segmentationMasks = masks.map(
wasmImage => this.convertToMPImage(
wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();

View File

@ -16,7 +16,7 @@
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
export {Category, Landmark, NormalizedLandmark};
@ -35,5 +35,5 @@ export declare interface PoseLandmarkerResult {
auxilaryLandmarks: NormalizedLandmark[][];
/** Segmentation mask for the detected pose. */
segmentationMasks?: MPImage[];
segmentationMasks?: MPMask[];
}

View File

@ -18,7 +18,7 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {PoseLandmarker} from './pose_landmarker';
@ -225,7 +225,7 @@ describe('PoseLandmarker', () => {
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.segmentationMasks![0]).toBeInstanceOf(MPImage);
expect(result.segmentationMasks![0]).toBeInstanceOf(MPMask);
done();
});
});

View File

@ -10,7 +10,7 @@ type LibConstructor = new (...args: any[]) => GraphRunner;
/** An image returned from a MediaPipe graph. */
export interface WasmImage {
data: Uint8ClampedArray|Float32Array|WebGLTexture;
data: Uint8Array|Float32Array|WebGLTexture;
width: number;
height: number;
}