Update ImageSegmenter to return MPImage
PiperOrigin-RevId: 527990991
This commit is contained in:
parent
a544098100
commit
2c1d9c6582
|
@ -91,6 +91,7 @@ mediapipe_ts_library(
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "render_utils",
|
name = "render_utils",
|
||||||
srcs = ["render_utils.ts"],
|
srcs = ["render_utils.ts"],
|
||||||
|
deps = [":image"],
|
||||||
)
|
)
|
||||||
|
|
||||||
jasmine_node_test(
|
jasmine_node_test(
|
||||||
|
|
|
@ -16,9 +16,11 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import {MPImageChannelConverter} from '../../../../tasks/web/vision/core/image';
|
||||||
|
|
||||||
// Pre-baked color table for a maximum of 12 classes.
|
// Pre-baked color table for a maximum of 12 classes.
|
||||||
const CM_ALPHA = 128;
|
const CM_ALPHA = 128;
|
||||||
const COLOR_MAP = [
|
const COLOR_MAP: Array<[number, number, number, number]> = [
|
||||||
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
|
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
|
||||||
[255, 0, 0, CM_ALPHA], // class 1 is red
|
[255, 0, 0, CM_ALPHA], // class 1 is red
|
||||||
[0, 255, 0, CM_ALPHA], // class 2 is light green
|
[0, 255, 0, CM_ALPHA], // class 2 is light green
|
||||||
|
@ -74,3 +76,9 @@ export function drawCategoryMask(
|
||||||
}
|
}
|
||||||
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
|
ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** The color converter we use in our demos. */
|
||||||
|
export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = {
|
||||||
|
floatToRGBAConverter: v => [128, 0, 0, v * 255],
|
||||||
|
uint8ToRGBAConverter: v => COLOR_MAP[v % COLOR_MAP.length],
|
||||||
|
};
|
||||||
|
|
|
@ -231,39 +231,41 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
||||||
*/
|
*/
|
||||||
protected convertToMPImage(wasmImage: WasmImage): MPImage {
|
protected convertToMPImage(wasmImage: WasmImage): MPImage {
|
||||||
const {data, width, height} = wasmImage;
|
const {data, width, height} = wasmImage;
|
||||||
|
const pixels = width * height;
|
||||||
|
|
||||||
|
let container: ImageData|WebGLTexture|Uint8ClampedArray;
|
||||||
if (data instanceof Uint8ClampedArray) {
|
if (data instanceof Uint8ClampedArray) {
|
||||||
let rgba: Uint8ClampedArray;
|
if (data.length === pixels) {
|
||||||
if (data.length === width * height * 4) {
|
container = data; // Mask
|
||||||
rgba = data;
|
} else if (data.length === pixels * 3) {
|
||||||
} else if (data.length === width * height * 3) {
|
|
||||||
// TODO: Convert in C++
|
// TODO: Convert in C++
|
||||||
rgba = new Uint8ClampedArray(width * height * 4);
|
const rgba = new Uint8ClampedArray(pixels * 4);
|
||||||
for (let i = 0; i < width * height; ++i) {
|
for (let i = 0; i < pixels; ++i) {
|
||||||
rgba[4 * i] = data[3 * i];
|
rgba[4 * i] = data[3 * i];
|
||||||
rgba[4 * i + 1] = data[3 * i + 1];
|
rgba[4 * i + 1] = data[3 * i + 1];
|
||||||
rgba[4 * i + 2] = data[3 * i + 2];
|
rgba[4 * i + 2] = data[3 * i + 2];
|
||||||
rgba[4 * i + 3] = 255;
|
rgba[4 * i + 3] = 255;
|
||||||
}
|
}
|
||||||
|
container = new ImageData(rgba, width, height);
|
||||||
|
} else if (data.length ===pixels * 4) {
|
||||||
|
container = new ImageData(data, width, height);
|
||||||
} else {
|
} else {
|
||||||
throw new Error(
|
throw new Error(`Unsupported channel count: ${data.length/pixels}`);
|
||||||
`Unsupported channel count: ${data.length / width / height}`);
|
}
|
||||||
|
} else if (data instanceof Float32Array) {
|
||||||
|
if (data.length === pixels) {
|
||||||
|
container = data; // Mask
|
||||||
|
} else {
|
||||||
|
throw new Error(`Unsupported channel count: ${data.length/pixels}`);
|
||||||
|
}
|
||||||
|
} else { // WebGLTexture
|
||||||
|
container = data;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new MPImage(
|
return new MPImage(
|
||||||
[new ImageData(rgba, width, height)],
|
[container], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
|
||||||
/* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
|
|
||||||
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
|
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
|
||||||
height);
|
height);
|
||||||
} else if (data instanceof WebGLTexture) {
|
|
||||||
return new MPImage(
|
|
||||||
[data], /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false,
|
|
||||||
this.graphRunner.wasmModule.canvas!, this.shaderContext, width,
|
|
||||||
height);
|
|
||||||
} else {
|
|
||||||
throw new Error(
|
|
||||||
`Cannot convert type ${data.constructor.name} to MPImage.`);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Closes and cleans up the resources held by this task. */
|
/** Closes and cleans up the resources held by this task. */
|
||||||
|
|
|
@ -20,7 +20,6 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/web/vision/core:types",
|
|
||||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||||
"//mediapipe/util:label_map_jspb_proto",
|
"//mediapipe/util:label_map_jspb_proto",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
|
@ -36,6 +35,7 @@ mediapipe_ts_declaration(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:classifier_options",
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/vision/core:image",
|
||||||
"//mediapipe/tasks/web/vision/core:vision_task_options",
|
"//mediapipe/tasks/web/vision/core:vision_task_options",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -52,6 +52,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/framework:calculator_jspb_proto",
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||||
|
"//mediapipe/tasks/web/vision/core:image",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,7 +22,6 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
|
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {LabelMapItem} from '../../../../util/label_map_pb';
|
import {LabelMapItem} from '../../../../util/label_map_pb';
|
||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
@ -33,7 +32,6 @@ import {ImageSegmenterResult} from './image_segmenter_result';
|
||||||
|
|
||||||
export * from './image_segmenter_options';
|
export * from './image_segmenter_options';
|
||||||
export * from './image_segmenter_result';
|
export * from './image_segmenter_result';
|
||||||
export {SegmentationMask};
|
|
||||||
export {ImageSource}; // Used in the public API
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
const IMAGE_STREAM = 'image_in';
|
const IMAGE_STREAM = 'image_in';
|
||||||
|
@ -60,7 +58,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
||||||
|
|
||||||
/** Performs image segmentation on images. */
|
/** Performs image segmentation on images. */
|
||||||
export class ImageSegmenter extends VisionTaskRunner {
|
export class ImageSegmenter extends VisionTaskRunner {
|
||||||
private result: ImageSegmenterResult = {width: 0, height: 0};
|
private result: ImageSegmenterResult = {};
|
||||||
private labels: string[] = [];
|
private labels: string[] = [];
|
||||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||||
|
@ -313,7 +311,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
private reset(): void {
|
private reset(): void {
|
||||||
this.result = {width: 0, height: 0};
|
this.result = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
@ -341,12 +339,8 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||||
this.result.confidenceMasks = masks.map(m => m.data);
|
this.result.confidenceMasks =
|
||||||
if (masks.length >= 0) {
|
masks.map(wasmImage => this.convertToMPImage(wasmImage));
|
||||||
this.result.width = masks[0].width;
|
|
||||||
this.result.height = masks[0].height;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
|
@ -361,9 +355,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageListener(
|
this.graphRunner.attachImageListener(
|
||||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
this.result.categoryMask = mask.data;
|
this.result.categoryMask = this.convertToMPImage(mask);
|
||||||
this.result.width = mask.width;
|
|
||||||
this.result.height = mask.height;
|
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
|
|
|
@ -14,24 +14,21 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import {MPImage} from '../../../../tasks/web/vision/core/image';
|
||||||
|
|
||||||
/** The output result of ImageSegmenter. */
|
/** The output result of ImageSegmenter. */
|
||||||
export declare interface ImageSegmenterResult {
|
export declare interface ImageSegmenterResult {
|
||||||
/**
|
/**
|
||||||
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
|
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
|
||||||
* pixel represents the prediction confidence, usually in the [0, 1] range.
|
* `MPImage`s where, for each mask, each pixel represents the prediction
|
||||||
|
* confidence, usually in the [0, 1] range.
|
||||||
*/
|
*/
|
||||||
confidenceMasks?: Float32Array[]|WebGLTexture[];
|
confidenceMasks?: MPImage[];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A category mask as a Uint8ClampedArray or WebGLTexture where each
|
* A category mask represented as a `Uint8ClampedArray` or
|
||||||
* pixel represents the class which the pixel in the original image was
|
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which
|
||||||
* predicted to belong to.
|
* the pixel in the original image was predicted to belong to.
|
||||||
*/
|
*/
|
||||||
categoryMask?: Uint8ClampedArray|WebGLTexture;
|
categoryMask?: MPImage;
|
||||||
|
|
||||||
/** The width of the masks. */
|
|
||||||
width: number;
|
|
||||||
|
|
||||||
/** The height of the masks. */
|
|
||||||
height: number;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import 'jasmine';
|
||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||||
|
import {MPImage} from '../../../../tasks/web/vision/core/image';
|
||||||
|
|
||||||
import {ImageSegmenter} from './image_segmenter';
|
import {ImageSegmenter} from './image_segmenter';
|
||||||
import {ImageSegmenterOptions} from './image_segmenter_options';
|
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||||
|
@ -182,10 +183,10 @@ describe('ImageSegmenter', () => {
|
||||||
return new Promise<void>(resolve => {
|
return new Promise<void>(resolve => {
|
||||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(result.categoryMask).toEqual(mask);
|
expect(result.categoryMask).toBeInstanceOf(MPImage);
|
||||||
expect(result.confidenceMasks).not.toBeDefined();
|
expect(result.confidenceMasks).not.toBeDefined();
|
||||||
expect(result.width).toEqual(2);
|
expect(result.categoryMask!.width).toEqual(2);
|
||||||
expect(result.height).toEqual(2);
|
expect(result.categoryMask!.height).toEqual(2);
|
||||||
resolve();
|
resolve();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@ -214,18 +215,21 @@ describe('ImageSegmenter', () => {
|
||||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(result.categoryMask).not.toBeDefined();
|
expect(result.categoryMask).not.toBeDefined();
|
||||||
expect(result.confidenceMasks).toEqual([mask1, mask2]);
|
|
||||||
expect(result.width).toEqual(2);
|
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
|
||||||
expect(result.height).toEqual(2);
|
expect(result.confidenceMasks![0].width).toEqual(2);
|
||||||
|
expect(result.confidenceMasks![0].height).toEqual(2);
|
||||||
|
|
||||||
|
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
|
||||||
resolve();
|
resolve();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('supports combined category and confidence masks', async () => {
|
it('supports combined category and confidence masks', async () => {
|
||||||
const categoryMask = new Uint8ClampedArray([1, 0]);
|
const categoryMask = new Uint8ClampedArray([1]);
|
||||||
const confidenceMask1 = new Float32Array([0.0, 1.0]);
|
const confidenceMask1 = new Float32Array([0.0]);
|
||||||
const confidenceMask2 = new Float32Array([1.0, 0.0]);
|
const confidenceMask2 = new Float32Array([1.0]);
|
||||||
|
|
||||||
await imageSegmenter.setOptions(
|
await imageSegmenter.setOptions(
|
||||||
{outputCategoryMask: true, outputConfidenceMasks: true});
|
{outputCategoryMask: true, outputConfidenceMasks: true});
|
||||||
|
@ -248,12 +252,12 @@ describe('ImageSegmenter', () => {
|
||||||
// Invoke the image segmenter
|
// Invoke the image segmenter
|
||||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(result.categoryMask).toEqual(categoryMask);
|
expect(result.categoryMask).toBeInstanceOf(MPImage);
|
||||||
expect(result.confidenceMasks).toEqual([
|
expect(result.categoryMask!.width).toEqual(1);
|
||||||
confidenceMask1, confidenceMask2
|
expect(result.categoryMask!.height).toEqual(1);
|
||||||
]);
|
|
||||||
expect(result.width).toEqual(1);
|
expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
|
||||||
expect(result.height).toEqual(1);
|
expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
|
||||||
resolve();
|
resolve();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user