PiperOrigin-RevId: 529890599
This commit is contained in:
Sebastian Schmidt 2023-05-05 21:49:52 -07:00 committed by Copybara-Service
parent e707c84a3d
commit 6aad5742c3
17 changed files with 117 additions and 68 deletions

View File

@ -87,6 +87,7 @@ mediapipe_ts_library(
deps = [ deps = [
":image", ":image",
":image_processing_options", ":image_processing_options",
":mask",
":vision_task_options", ":vision_task_options",
"//mediapipe/framework/formats:rect_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",

View File

@ -16,8 +16,6 @@
* limitations under the License. * limitations under the License.
*/ */
import {MPImageChannelConverter} from '../../../../tasks/web/vision/core/image';
// Pre-baked color table for a maximum of 12 classes. // Pre-baked color table for a maximum of 12 classes.
const CM_ALPHA = 128; const CM_ALPHA = 128;
const COLOR_MAP: Array<[number, number, number, number]> = [ const COLOR_MAP: Array<[number, number, number, number]> = [
@ -35,8 +33,37 @@ const COLOR_MAP: Array<[number, number, number, number]> = [
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead? [255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
]; ];
/** The color converter we use in our demos. */
export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = { /** Helper function to draw a confidence mask */
floatToRGBAConverter: v => [128, 0, 0, v * 255], export function drawConfidenceMask(
uint8ToRGBAConverter: v => COLOR_MAP[v % COLOR_MAP.length], ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
}; height: number): void {
const uint8Array = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < image.length; i++) {
uint8Array[4 * i] = 128;
uint8Array[4 * i + 1] = 0;
uint8Array[4 * i + 2] = 0;
uint8Array[4 * i + 3] = image[i] * 255;
}
ctx.putImageData(new ImageData(uint8Array, width, height), 0, 0);
}
/**
* Helper function to draw a category mask. For GPU, we only have F32Arrays
* for now.
*/
export function drawCategoryMask(
ctx: CanvasRenderingContext2D, image: Uint8Array|Float32Array,
width: number, height: number): void {
const rgbaArray = new Uint8ClampedArray(width * height * 4);
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex % COLOR_MAP.length];
rgbaArray[4 * i] = color[0];
rgbaArray[4 * i + 1] = color[1];
rgbaArray[4 * i + 2] = color[2];
rgbaArray[4 * i + 3] = color[3];
}
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
}

View File

@ -20,6 +20,7 @@ import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {MPImage} from '../../../../tasks/web/vision/core/image'; import {MPImage} from '../../../../tasks/web/vision/core/image';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner'; import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {isWebKit} from '../../../../web/graph_runner/platform_utils'; import {isWebKit} from '../../../../web/graph_runner/platform_utils';
@ -226,7 +227,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
/** /**
* Converts a WasmImage to an MPImage. * Converts a WasmImage to an MPImage.
* *
* Converts the underlying Uint8ClampedArray-backed images to ImageData * Converts the underlying Uint8Array-backed images to ImageData
* (adding an alpha channel if necessary), passes through WebGLTextures and * (adding an alpha channel if necessary), passes through WebGLTextures and
* throws for Float32Array-backed images. * throws for Float32Array-backed images.
*/ */
@ -235,11 +236,9 @@ export abstract class VisionTaskRunner extends TaskRunner {
const {data, width, height} = wasmImage; const {data, width, height} = wasmImage;
const pixels = width * height; const pixels = width * height;
let container: ImageData|WebGLTexture|Uint8ClampedArray; let container: ImageData|WebGLTexture;
if (data instanceof Uint8ClampedArray) { if (data instanceof Uint8Array) {
if (data.length === pixels) { if (data.length === pixels * 3) {
container = data; // Mask
} else if (data.length === pixels * 3) {
// TODO: Convert in C++ // TODO: Convert in C++
const rgba = new Uint8ClampedArray(pixels * 4); const rgba = new Uint8ClampedArray(pixels * 4);
for (let i = 0; i < pixels; ++i) { for (let i = 0; i < pixels; ++i) {
@ -249,19 +248,17 @@ export abstract class VisionTaskRunner extends TaskRunner {
rgba[4 * i + 3] = 255; rgba[4 * i + 3] = 255;
} }
container = new ImageData(rgba, width, height); container = new ImageData(rgba, width, height);
} else if (data.length ===pixels * 4) { } else if (data.length === pixels * 4) {
container = new ImageData(data, width, height); container = new ImageData(
new Uint8ClampedArray(data.buffer, data.byteOffset, data.length),
width, height);
} else { } else {
throw new Error(`Unsupported channel count: ${data.length/pixels}`); throw new Error(`Unsupported channel count: ${data.length/pixels}`);
} }
} else if (data instanceof Float32Array) { } else if (data instanceof WebGLTexture) {
if (data.length === pixels) {
container = data; // Mask
} else {
throw new Error(`Unsupported channel count: ${data.length/pixels}`);
}
} else { // WebGLTexture
container = data; container = data;
} else {
throw new Error(`Unsupported format: ${data.constructor.name}`);
} }
const image = new MPImage( const image = new MPImage(
@ -271,6 +268,30 @@ export abstract class VisionTaskRunner extends TaskRunner {
return shouldCopyData ? image.clone() : image; return shouldCopyData ? image.clone() : image;
} }
/** Converts a WasmImage to an MPMask. */
protected convertToMPMask(wasmImage: WasmImage, shouldCopyData: boolean):
MPMask {
const {data, width, height} = wasmImage;
const pixels = width * height;
let container: WebGLTexture|Uint8Array|Float32Array;
if (data instanceof Uint8Array || data instanceof Float32Array) {
if (data.length === pixels) {
container = data;
} else {
throw new Error(`Unsupported channel count: ${data.length / pixels}`);
}
} else {
container = data;
}
const mask = new MPMask(
[container],
/* ownsWebGLTexture= */ false, this.graphRunner.wasmModule.canvas!,
this.shaderContext, width, height);
return shouldCopyData ? mask.clone() : mask;
}
/** Closes and cleans up the resources held by this task. */ /** Closes and cleans up the resources held by this task. */
override close(): void { override close(): void {
this.shaderContext.close(); this.shaderContext.close();

View File

@ -109,7 +109,7 @@ describe('FaceStylizer', () => {
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer); verifyListenersRegistered(faceStylizer);
faceStylizer.imageListener! faceStylizer.imageListener!
({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1}, ({data: new Uint8Array([1, 1, 1, 1]), width: 1, height: 1},
/* timestamp= */ 1337); /* timestamp= */ 1337);
}); });
@ -134,7 +134,7 @@ describe('FaceStylizer', () => {
faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(faceStylizer); verifyListenersRegistered(faceStylizer);
faceStylizer.imageListener! faceStylizer.imageListener!
({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1}, ({data: new Uint8Array([1, 1, 1, 1]), width: 1, height: 1},
/* timestamp= */ 1337); /* timestamp= */ 1337);
}); });

View File

@ -35,7 +35,7 @@ mediapipe_ts_declaration(
deps = [ deps = [
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_options",
], ],
) )
@ -52,7 +52,7 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
], ],
) )

View File

@ -412,7 +412,7 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map( this.result.confidenceMasks = masks.map(
wasmImage => this.convertToMPImage( wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback)); wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
@ -431,7 +431,7 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage( this.result.categoryMask = this.convertToMPMask(
mask, /* shouldCopyData= */ !this.userCallback); mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {MPImage} from '../../../../tasks/web/vision/core/image'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
/** The output result of ImageSegmenter. */ /** The output result of ImageSegmenter. */
export declare interface ImageSegmenterResult { export declare interface ImageSegmenterResult {
@ -23,12 +23,12 @@ export declare interface ImageSegmenterResult {
* `MPImage`s where, for each mask, each pixel represents the prediction * `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range. * confidence, usually in the [0, 1] range.
*/ */
confidenceMasks?: MPImage[]; confidenceMasks?: MPMask[];
/** /**
* A category mask represented as a `Uint8ClampedArray` or * A category mask represented as a `Uint8ClampedArray` or
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which * `WebGLTexture`-backed `MPImage` where each pixel represents the class which
* the pixel in the original image was predicted to belong to. * the pixel in the original image was predicted to belong to.
*/ */
categoryMask?: MPImage; categoryMask?: MPMask;
} }

View File

@ -19,8 +19,8 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {ImageSegmenter} from './image_segmenter'; import {ImageSegmenter} from './image_segmenter';
import {ImageSegmenterOptions} from './image_segmenter_options'; import {ImageSegmenterOptions} from './image_segmenter_options';
@ -165,7 +165,7 @@ describe('ImageSegmenter', () => {
}); });
it('supports category mask', async () => { it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]); const mask = new Uint8Array([1, 2, 3, 4]);
await imageSegmenter.setOptions( await imageSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false}); {outputCategoryMask: true, outputConfidenceMasks: false});
@ -183,7 +183,7 @@ describe('ImageSegmenter', () => {
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, result => { imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage); expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks).not.toBeDefined(); expect(result.confidenceMasks).not.toBeDefined();
expect(result.categoryMask!.width).toEqual(2); expect(result.categoryMask!.width).toEqual(2);
expect(result.categoryMask!.height).toEqual(2); expect(result.categoryMask!.height).toEqual(2);
@ -216,18 +216,18 @@ describe('ImageSegmenter', () => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined(); expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0].width).toEqual(2); expect(result.confidenceMasks![0].width).toEqual(2);
expect(result.confidenceMasks![0].height).toEqual(2); expect(result.confidenceMasks![0].height).toEqual(2);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve(); resolve();
}); });
}); });
}); });
it('supports combined category and confidence masks', async () => { it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask1 = new Float32Array([0.0]); const confidenceMask1 = new Float32Array([0.0]);
const confidenceMask2 = new Float32Array([1.0]); const confidenceMask2 = new Float32Array([1.0]);
@ -252,19 +252,19 @@ describe('ImageSegmenter', () => {
// Invoke the image segmenter // Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, result => { imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage); expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.categoryMask!.width).toEqual(1); expect(result.categoryMask!.width).toEqual(1);
expect(result.categoryMask!.height).toEqual(1); expect(result.categoryMask!.height).toEqual(1);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve(); resolve();
}); });
}); });
}); });
it('invokes listener once masks are available', async () => { it('invokes listener once masks are available', async () => {
const categoryMask = new Uint8ClampedArray([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false; let listenerCalled = false;
@ -306,7 +306,7 @@ describe('ImageSegmenter', () => {
}); });
const result = imageSegmenter.segment({} as HTMLImageElement); const result = imageSegmenter.segment({} as HTMLImageElement);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
result.confidenceMasks![0].close(); result.confidenceMasks![0].close();
}); });
}); });

View File

@ -37,7 +37,7 @@ mediapipe_ts_declaration(
deps = [ deps = [
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_options",
], ],
) )
@ -54,7 +54,7 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/util:render_data_jspb_proto", "//mediapipe/util:render_data_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
], ],

View File

@ -328,7 +328,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map( this.result.confidenceMasks = masks.map(
wasmImage => this.convertToMPImage( wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback)); wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();
@ -347,7 +347,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage( this.result.categoryMask = this.convertToMPMask(
mask, /* shouldCopyData= */ !this.userCallback); mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {MPImage} from '../../../../tasks/web/vision/core/image'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
/** The output result of InteractiveSegmenter. */ /** The output result of InteractiveSegmenter. */
export declare interface InteractiveSegmenterResult { export declare interface InteractiveSegmenterResult {
@ -23,12 +23,12 @@ export declare interface InteractiveSegmenterResult {
* `MPImage`s where, for each mask, each pixel represents the prediction * `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range. * confidence, usually in the [0, 1] range.
*/ */
confidenceMasks?: MPImage[]; confidenceMasks?: MPMask[];
/** /**
* A category mask represented as a `Uint8ClampedArray` or * A category mask represented as a `Uint8ClampedArray` or
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which * `WebGLTexture`-backed `MPImage` where each pixel represents the class which
* the pixel in the original image was predicted to belong to. * the pixel in the original image was predicted to belong to.
*/ */
categoryMask?: MPImage; categoryMask?: MPMask;
} }

View File

@ -19,7 +19,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
@ -177,7 +177,7 @@ describe('InteractiveSegmenter', () => {
}); });
it('supports category mask', async () => { it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]); const mask = new Uint8Array([1, 2, 3, 4]);
await interactiveSegmenter.setOptions( await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false}); {outputCategoryMask: true, outputConfidenceMasks: false});
@ -195,7 +195,7 @@ describe('InteractiveSegmenter', () => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => { interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage); expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.categoryMask!.width).toEqual(2); expect(result.categoryMask!.width).toEqual(2);
expect(result.categoryMask!.height).toEqual(2); expect(result.categoryMask!.height).toEqual(2);
expect(result.confidenceMasks).not.toBeDefined(); expect(result.confidenceMasks).not.toBeDefined();
@ -228,18 +228,18 @@ describe('InteractiveSegmenter', () => {
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined(); expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0].width).toEqual(2); expect(result.confidenceMasks![0].width).toEqual(2);
expect(result.confidenceMasks![0].height).toEqual(2); expect(result.confidenceMasks![0].height).toEqual(2);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve(); resolve();
}); });
}); });
}); });
it('supports combined category and confidence masks', async () => { it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask1 = new Float32Array([0.0]); const confidenceMask1 = new Float32Array([0.0]);
const confidenceMask2 = new Float32Array([1.0]); const confidenceMask2 = new Float32Array([1.0]);
@ -266,19 +266,19 @@ describe('InteractiveSegmenter', () => {
{} as HTMLImageElement, KEYPOINT, result => { {} as HTMLImageElement, KEYPOINT, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(result.categoryMask).toBeInstanceOf(MPImage); expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.categoryMask!.width).toEqual(1); expect(result.categoryMask!.width).toEqual(1);
expect(result.categoryMask!.height).toEqual(1); expect(result.categoryMask!.height).toEqual(1);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![1]).toBeInstanceOf(MPMask);
resolve(); resolve();
}); });
}); });
}); });
it('invokes listener once masks are avaiblae', async () => { it('invokes listener once masks are avaiblae', async () => {
const categoryMask = new Uint8ClampedArray([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false; let listenerCalled = false;
@ -321,7 +321,7 @@ describe('InteractiveSegmenter', () => {
const result = const result =
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT); interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
result.confidenceMasks![0].close(); result.confidenceMasks![0].close();
}); });
}); });

View File

@ -45,7 +45,7 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_options",
], ],
) )
@ -63,7 +63,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/processors:landmark_result", "//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/tasks/web/vision/core:vision_task_runner",
], ],
) )

View File

@ -504,7 +504,7 @@ export class PoseLandmarker extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
SEGMENTATION_MASK_STREAM, (masks, timestamp) => { SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
this.result.segmentationMasks = masks.map( this.result.segmentationMasks = masks.map(
wasmImage => this.convertToMPImage( wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback)); wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback(); this.maybeInvokeCallback();

View File

@ -16,7 +16,7 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {MPImage} from '../../../../tasks/web/vision/core/image'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
export {Category, Landmark, NormalizedLandmark}; export {Category, Landmark, NormalizedLandmark};
@ -35,5 +35,5 @@ export declare interface PoseLandmarkerResult {
auxilaryLandmarks: NormalizedLandmark[][]; auxilaryLandmarks: NormalizedLandmark[][];
/** Segmentation mask for the detected pose. */ /** Segmentation mask for the detected pose. */
segmentationMasks?: MPImage[]; segmentationMasks?: MPMask[];
} }

View File

@ -18,7 +18,7 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib'; import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {PoseLandmarker} from './pose_landmarker'; import {PoseLandmarker} from './pose_landmarker';
@ -225,7 +225,7 @@ describe('PoseLandmarker', () => {
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.segmentationMasks![0]).toBeInstanceOf(MPImage); expect(result.segmentationMasks![0]).toBeInstanceOf(MPMask);
done(); done();
}); });
}); });

View File

@ -10,7 +10,7 @@ type LibConstructor = new (...args: any[]) => GraphRunner;
/** An image returned from a MediaPipe graph. */ /** An image returned from a MediaPipe graph. */
export interface WasmImage { export interface WasmImage {
data: Uint8ClampedArray|Float32Array|WebGLTexture; data: Uint8Array|Float32Array|WebGLTexture;
width: number; width: number;
height: number; height: number;
} }