Support new output format for ImageSegmenter
PiperOrigin-RevId: 524371021
This commit is contained in:
parent
f5197a3adc
commit
92f45c98d8
|
@ -59,13 +59,12 @@ export function drawCategoryMask(
|
||||||
const isFloatArray = image instanceof Float32Array;
|
const isFloatArray = image instanceof Float32Array;
|
||||||
for (let i = 0; i < image.length; i++) {
|
for (let i = 0; i < image.length; i++) {
|
||||||
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
|
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
|
||||||
const color = COLOR_MAP[colorIndex];
|
let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
|
||||||
|
|
||||||
// When we're given a confidence mask by accident, we just log and return.
|
|
||||||
// TODO: We should fix this.
|
|
||||||
if (!color) {
|
if (!color) {
|
||||||
|
// TODO: We should fix this.
|
||||||
console.warn('No color for ', colorIndex);
|
console.warn('No color for ', colorIndex);
|
||||||
return;
|
color = COLOR_MAP[colorIndex % COLOR_MAP.length];
|
||||||
}
|
}
|
||||||
|
|
||||||
rgbaArray[4 * i] = color[0];
|
rgbaArray[4 * i] = color[0];
|
||||||
|
|
|
@ -29,7 +29,10 @@ mediapipe_ts_library(
|
||||||
|
|
||||||
mediapipe_ts_declaration(
|
mediapipe_ts_declaration(
|
||||||
name = "image_segmenter_types",
|
name = "image_segmenter_types",
|
||||||
srcs = ["image_segmenter_options.d.ts"],
|
srcs = [
|
||||||
|
"image_segmenter_options.d.ts",
|
||||||
|
"image_segmenter_result.d.ts",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/core:classifier_options",
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
|
|
@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {LabelMapItem} from '../../../../util/label_map_pb';
|
import {LabelMapItem} from '../../../../util/label_map_pb';
|
||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {ImageSegmenterOptions} from './image_segmenter_options';
|
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||||
|
import {ImageSegmenterResult} from './image_segmenter_result';
|
||||||
|
|
||||||
export * from './image_segmenter_options';
|
export * from './image_segmenter_options';
|
||||||
export {SegmentationMask, SegmentationMaskCallback};
|
export * from './image_segmenter_result';
|
||||||
|
export {SegmentationMask};
|
||||||
export {ImageSource}; // Used in the public API
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
const IMAGE_STREAM = 'image_in';
|
const IMAGE_STREAM = 'image_in';
|
||||||
const NORM_RECT_STREAM = 'norm_rect';
|
const NORM_RECT_STREAM = 'norm_rect';
|
||||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||||
|
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||||
const IMAGE_SEGMENTER_GRAPH =
|
const IMAGE_SEGMENTER_GRAPH =
|
||||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
'mediapipe.tasks.TensorsToSegmentationCalculator';
|
'mediapipe.tasks.TensorsToSegmentationCalculator';
|
||||||
|
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||||
|
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
|
||||||
|
|
||||||
// The OSS JS API does not support the builder pattern.
|
// The OSS JS API does not support the builder pattern.
|
||||||
// tslint:disable:jspb-use-builder-pattern
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A callback that receives the computed masks from the image segmenter. The
|
||||||
|
* returned data is only valid for the duration of the callback. If
|
||||||
|
* asynchronous processing is needed, all data needs to be copied before the
|
||||||
|
* callback returns.
|
||||||
|
*/
|
||||||
|
export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
|
||||||
|
|
||||||
/** Performs image segmentation on images. */
|
/** Performs image segmentation on images. */
|
||||||
export class ImageSegmenter extends VisionTaskRunner {
|
export class ImageSegmenter extends VisionTaskRunner {
|
||||||
private userCallback: SegmentationMaskCallback = () => {};
|
private result: ImageSegmenterResult = {width: 0, height: 0};
|
||||||
private labels: string[] = [];
|
private labels: string[] = [];
|
||||||
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
|
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||||
|
|
||||||
|
@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
this.options.setBaseOptions(new BaseOptionsProto());
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected override get baseOptions(): BaseOptionsProto {
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
return this.options.getBaseOptions()!;
|
return this.options.getBaseOptions()!;
|
||||||
}
|
}
|
||||||
|
@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
this.options.clearDisplayNamesLocale();
|
this.options.clearDisplayNamesLocale();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.outputType === 'CONFIDENCE_MASK') {
|
if ('outputCategoryMask' in options) {
|
||||||
this.segmenterOptions.setOutputType(
|
this.outputCategoryMask =
|
||||||
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
} else {
|
}
|
||||||
this.segmenterOptions.setOutputType(
|
|
||||||
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
if ('outputConfidenceMasks' in options) {
|
||||||
|
this.outputConfidenceMasks =
|
||||||
|
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||||
}
|
}
|
||||||
|
|
||||||
return super.applyOptions(options);
|
return super.applyOptions(options);
|
||||||
|
@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
* lifetime of the returned data is only guaranteed for the duration of the
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
* callback.
|
* callback.
|
||||||
*/
|
*/
|
||||||
segment(image: ImageSource, callback: SegmentationMaskCallback): void;
|
segment(image: ImageSource, callback: ImageSegmenterCallack): void;
|
||||||
/**
|
/**
|
||||||
* Performs image segmentation on the provided single image and invokes the
|
* Performs image segmentation on the provided single image and invokes the
|
||||||
* callback with the response. The method returns synchronously once the
|
* callback with the response. The method returns synchronously once the
|
||||||
|
@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
*/
|
*/
|
||||||
segment(
|
segment(
|
||||||
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||||
callback: SegmentationMaskCallback): void;
|
callback: ImageSegmenterCallack): void;
|
||||||
segment(
|
segment(
|
||||||
image: ImageSource,
|
image: ImageSource,
|
||||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||||
SegmentationMaskCallback,
|
ImageSegmenterCallack,
|
||||||
callback?: SegmentationMaskCallback): void {
|
callback?: ImageSegmenterCallack): void {
|
||||||
const imageProcessingOptions =
|
const imageProcessingOptions =
|
||||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||||
imageProcessingOptionsOrCallback :
|
imageProcessingOptionsOrCallback :
|
||||||
{};
|
{};
|
||||||
|
const userCallback =
|
||||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||||
imageProcessingOptionsOrCallback :
|
imageProcessingOptionsOrCallback :
|
||||||
callback!;
|
callback!;
|
||||||
|
|
||||||
|
this.reset();
|
||||||
this.processImageData(image, imageProcessingOptions);
|
this.processImageData(image, imageProcessingOptions);
|
||||||
this.userCallback = () => {};
|
userCallback(this.result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs image segmentation on the provided video frame and invokes the
|
||||||
|
* callback with the response. The method returns synchronously once the
|
||||||
|
* callback returns. Only use this method when the ImageSegmenter is
|
||||||
|
* created with running mode `video`.
|
||||||
|
*
|
||||||
|
* @param videoFrame A video frame to process.
|
||||||
|
* @param timestamp The timestamp of the current frame, in ms.
|
||||||
|
* @param callback The callback that is invoked with the segmented masks. The
|
||||||
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
segmentForVideo(
|
||||||
|
videoFrame: ImageSource, timestamp: number,
|
||||||
|
callback: ImageSegmenterCallack): void;
|
||||||
|
/**
|
||||||
|
* Performs image segmentation on the provided video frame and invokes the
|
||||||
|
* callback with the response. The method returns synchronously once the
|
||||||
|
* callback returns. Only use this method when the ImageSegmenter is
|
||||||
|
* created with running mode `video`.
|
||||||
|
*
|
||||||
|
* @param videoFrame A video frame to process.
|
||||||
|
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||||
|
* to process the input image before running inference.
|
||||||
|
* @param timestamp The timestamp of the current frame, in ms.
|
||||||
|
* @param callback The callback that is invoked with the segmented masks. The
|
||||||
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
segmentForVideo(
|
||||||
|
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||||
|
timestamp: number, callback: ImageSegmenterCallack): void;
|
||||||
|
segmentForVideo(
|
||||||
|
videoFrame: ImageSource,
|
||||||
|
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
|
||||||
|
timestampOrCallback: number|ImageSegmenterCallack,
|
||||||
|
callback?: ImageSegmenterCallack): void {
|
||||||
|
const imageProcessingOptions =
|
||||||
|
typeof timestampOrImageProcessingOptions !== 'number' ?
|
||||||
|
timestampOrImageProcessingOptions :
|
||||||
|
{};
|
||||||
|
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||||
|
timestampOrImageProcessingOptions :
|
||||||
|
timestampOrCallback as number;
|
||||||
|
const userCallback = typeof timestampOrCallback === 'function' ?
|
||||||
|
timestampOrCallback :
|
||||||
|
callback!;
|
||||||
|
|
||||||
|
this.reset();
|
||||||
|
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||||
|
userCallback(this.result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
return this.labels;
|
return this.labels;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private reset(): void {
|
||||||
* Performs image segmentation on the provided video frame and invokes the
|
this.result = {width: 0, height: 0};
|
||||||
* callback with the response. The method returns synchronously once the
|
|
||||||
* callback returns. Only use this method when the ImageSegmenter is
|
|
||||||
* created with running mode `video`.
|
|
||||||
*
|
|
||||||
* @param videoFrame A video frame to process.
|
|
||||||
* @param timestamp The timestamp of the current frame, in ms.
|
|
||||||
* @param callback The callback that is invoked with the segmented masks. The
|
|
||||||
* lifetime of the returned data is only guaranteed for the duration of the
|
|
||||||
* callback.
|
|
||||||
*/
|
|
||||||
segmentForVideo(
|
|
||||||
videoFrame: ImageSource, timestamp: number,
|
|
||||||
callback: SegmentationMaskCallback): void;
|
|
||||||
/**
|
|
||||||
* Performs image segmentation on the provided video frame and invokes the
|
|
||||||
* callback with the response. The method returns synchronously once the
|
|
||||||
* callback returns. Only use this method when the ImageSegmenter is
|
|
||||||
* created with running mode `video`.
|
|
||||||
*
|
|
||||||
* @param videoFrame A video frame to process.
|
|
||||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
|
||||||
* to process the input image before running inference.
|
|
||||||
* @param timestamp The timestamp of the current frame, in ms.
|
|
||||||
* @param callback The callback that is invoked with the segmented masks. The
|
|
||||||
* lifetime of the returned data is only guaranteed for the duration of the
|
|
||||||
* callback.
|
|
||||||
*/
|
|
||||||
segmentForVideo(
|
|
||||||
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
|
||||||
timestamp: number, callback: SegmentationMaskCallback): void;
|
|
||||||
segmentForVideo(
|
|
||||||
videoFrame: ImageSource,
|
|
||||||
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
|
|
||||||
timestampOrCallback: number|SegmentationMaskCallback,
|
|
||||||
callback?: SegmentationMaskCallback): void {
|
|
||||||
const imageProcessingOptions =
|
|
||||||
typeof timestampOrImageProcessingOptions !== 'number' ?
|
|
||||||
timestampOrImageProcessingOptions :
|
|
||||||
{};
|
|
||||||
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
|
||||||
timestampOrImageProcessingOptions :
|
|
||||||
timestampOrCallback as number;
|
|
||||||
|
|
||||||
this.userCallback = typeof timestampOrCallback === 'function' ?
|
|
||||||
timestampOrCallback :
|
|
||||||
callback!;
|
|
||||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
|
||||||
this.userCallback = () => {};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(IMAGE_STREAM);
|
graphConfig.addInputStream(IMAGE_STREAM);
|
||||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||||
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
|
|
||||||
|
|
||||||
const calculatorOptions = new CalculatorOptions();
|
const calculatorOptions = new CalculatorOptions();
|
||||||
calculatorOptions.setExtension(
|
calculatorOptions.setExtension(
|
||||||
|
@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
|
segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
|
||||||
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
|
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
|
||||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
|
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
|
||||||
segmenterNode.addOutputStream(
|
|
||||||
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
|
|
||||||
segmenterNode.setOptions(calculatorOptions);
|
segmenterNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(segmenterNode);
|
graphConfig.addNode(segmenterNode);
|
||||||
|
|
||||||
|
if (this.outputConfidenceMasks) {
|
||||||
|
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
||||||
|
segmenterNode.addOutputStream(
|
||||||
|
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
|
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||||
if (masks.length === 0) {
|
this.result.confidenceMasks = masks.map(m => m.data);
|
||||||
this.userCallback([], 0, 0);
|
if (masks.length >= 0) {
|
||||||
} else {
|
this.result.width = masks[0].width;
|
||||||
this.userCallback(
|
this.result.height = masks[0].height;
|
||||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
GROUPED_SEGMENTATIONS_STREAM, timestamp => {
|
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.outputCategoryMask) {
|
||||||
|
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
||||||
|
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
||||||
|
|
||||||
|
this.graphRunner.attachImageListener(
|
||||||
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
|
this.result.categoryMask = mask.data;
|
||||||
|
this.result.width = mask.width;
|
||||||
|
this.result.height = mask.height;
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
|
CATEGORY_MASK_STREAM, timestamp => {
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const binaryGraph = graphConfig.serializeBinary();
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
|
|
@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
|
||||||
*/
|
*/
|
||||||
displayNamesLocale?: string|undefined;
|
displayNamesLocale?: string|undefined;
|
||||||
|
|
||||||
/**
|
/** Whether to output confidence masks. Defaults to true. */
|
||||||
* The output type of segmentation results.
|
outputConfidenceMasks?: boolean|undefined;
|
||||||
*
|
|
||||||
* The two supported modes are:
|
/** Whether to output the category masks. Defaults to false. */
|
||||||
* - Category Mask: Gives a single output mask where each pixel represents
|
outputCategoryMask?: boolean|undefined;
|
||||||
* the class which the pixel in the original image was
|
|
||||||
* predicted to belong to.
|
|
||||||
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
|
||||||
* each mask, the pixel represents the prediction
|
|
||||||
* confidence, usually in the [0.0, 0.1] range.
|
|
||||||
*
|
|
||||||
* Defaults to `CATEGORY_MASK`.
|
|
||||||
*/
|
|
||||||
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
|
||||||
}
|
}
|
||||||
|
|
37
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
vendored
Normal file
37
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
vendored
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** The output result of ImageSegmenter. */
|
||||||
|
export declare interface ImageSegmenterResult {
|
||||||
|
/**
|
||||||
|
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
|
||||||
|
* pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||||
|
*/
|
||||||
|
confidenceMasks?: Float32Array[]|WebGLTexture[];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A category mask as a Uint8ClampedArray or WebGLTexture where each
|
||||||
|
* pixel represents the class which the pixel in the original image was
|
||||||
|
* predicted to belong to.
|
||||||
|
*/
|
||||||
|
categoryMask?: Uint8ClampedArray|WebGLTexture;
|
||||||
|
|
||||||
|
/** The width of the masks. */
|
||||||
|
width: number;
|
||||||
|
|
||||||
|
/** The height of the masks. */
|
||||||
|
height: number;
|
||||||
|
}
|
|
@ -18,7 +18,7 @@ import 'jasmine';
|
||||||
|
|
||||||
// Placeholder for internal dependency on encodeByteArray
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||||
|
|
||||||
import {ImageSegmenter} from './image_segmenter';
|
import {ImageSegmenter} from './image_segmenter';
|
||||||
|
@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
graph: CalculatorGraphConfig|undefined;
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
|
||||||
fakeWasmModule: SpyWasmModule;
|
fakeWasmModule: SpyWasmModule;
|
||||||
imageVectorListener:
|
categoryMaskListener:
|
||||||
|
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||||
|
confidenceMasksListener:
|
||||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
|
@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
this.fakeWasmModule =
|
this.fakeWasmModule =
|
||||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
this.attachListenerSpies[0] =
|
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('category_mask');
|
||||||
|
this.categoryMaskListener = listener;
|
||||||
|
});
|
||||||
|
this.attachListenerSpies[1] =
|
||||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||||
.and.callFake((stream, listener) => {
|
.and.callFake((stream, listener) => {
|
||||||
expect(stream).toEqual('segmented_masks');
|
expect(stream).toEqual('confidence_masks');
|
||||||
this.imageVectorListener = listener;
|
this.confidenceMasksListener = listener;
|
||||||
});
|
});
|
||||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
verifyGraph(imageSegmenter);
|
verifyGraph(imageSegmenter);
|
||||||
verifyListenersRegistered(imageSegmenter);
|
|
||||||
|
// Verify default options
|
||||||
|
expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
|
||||||
|
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('reloads graph when settings are changed', async () => {
|
it('reloads graph when settings are changed', async () => {
|
||||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||||
verifyListenersRegistered(imageSegmenter);
|
|
||||||
|
|
||||||
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
||||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
||||||
verifyListenersRegistered(imageSegmenter);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can use custom models', async () => {
|
it('can use custom models', async () => {
|
||||||
|
@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('merges options', async () => {
|
it('merges options', async () => {
|
||||||
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
await imageSegmenter.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||||
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
verifyGraph(
|
||||||
|
imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
|
||||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
|
||||||
defaultValue: unknown;
|
defaultValue: unknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
const testCases: TestCase[] = [
|
const testCases: TestCase[] = [{
|
||||||
{
|
|
||||||
optionName: 'displayNamesLocale',
|
optionName: 'displayNamesLocale',
|
||||||
fieldPath: ['displayNamesLocale'],
|
fieldPath: ['displayNamesLocale'],
|
||||||
userValue: 'en',
|
userValue: 'en',
|
||||||
graphValue: 'en',
|
graphValue: 'en',
|
||||||
defaultValue: 'en'
|
defaultValue: 'en'
|
||||||
},
|
}];
|
||||||
{
|
|
||||||
optionName: 'outputType',
|
|
||||||
fieldPath: ['segmenterOptions', 'outputType'],
|
|
||||||
userValue: 'CONFIDENCE_MASK',
|
|
||||||
graphValue: 2,
|
|
||||||
defaultValue: 1
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const testCase of testCases) {
|
for (const testCase of testCases) {
|
||||||
it(`can set ${testCase.optionName}`, async () => {
|
it(`can set ${testCase.optionName}`, async () => {
|
||||||
|
@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
|
||||||
}).toThrowError('This task doesn\'t support region-of-interest.');
|
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('supports category masks', (done) => {
|
it('supports category mask', async () => {
|
||||||
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
||||||
|
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{outputCategoryMask: true, outputConfidenceMasks: false});
|
||||||
|
|
||||||
// Pass the test data to our listener
|
// Pass the test data to our listener
|
||||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
verifyListenersRegistered(imageSegmenter);
|
expect(imageSegmenter.categoryMaskListener).toBeDefined();
|
||||||
imageSegmenter.imageVectorListener!(
|
imageSegmenter.categoryMaskListener!
|
||||||
[
|
({data: mask, width: 2, height: 2},
|
||||||
{data: mask, width: 2, height: 2},
|
|
||||||
],
|
|
||||||
/* timestamp= */ 1337);
|
/* timestamp= */ 1337);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Invoke the image segmenter
|
// Invoke the image segmenter
|
||||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
|
||||||
|
return new Promise<void>(resolve => {
|
||||||
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(masks).toHaveSize(1);
|
expect(result.categoryMask).toEqual(mask);
|
||||||
expect(masks[0]).toEqual(mask);
|
expect(result.confidenceMasks).not.toBeDefined();
|
||||||
expect(width).toEqual(2);
|
expect(result.width).toEqual(2);
|
||||||
expect(height).toEqual(2);
|
expect(result.height).toEqual(2);
|
||||||
done();
|
resolve();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
|
||||||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||||
|
|
||||||
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
await imageSegmenter.setOptions(
|
||||||
|
{outputCategoryMask: false, outputConfidenceMasks: true});
|
||||||
|
|
||||||
// Pass the test data to our listener
|
// Pass the test data to our listener
|
||||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
verifyListenersRegistered(imageSegmenter);
|
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||||
imageSegmenter.imageVectorListener!(
|
imageSegmenter.confidenceMasksListener!(
|
||||||
[
|
[
|
||||||
{data: mask1, width: 2, height: 2},
|
{data: mask1, width: 2, height: 2},
|
||||||
{data: mask2, width: 2, height: 2},
|
{data: mask2, width: 2, height: 2},
|
||||||
|
@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {
|
||||||
|
|
||||||
return new Promise<void>(resolve => {
|
return new Promise<void>(resolve => {
|
||||||
// Invoke the image segmenter
|
// Invoke the image segmenter
|
||||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(masks).toHaveSize(2);
|
expect(result.categoryMask).not.toBeDefined();
|
||||||
expect(masks[0]).toEqual(mask1);
|
expect(result.confidenceMasks).toEqual([mask1, mask2]);
|
||||||
expect(masks[1]).toEqual(mask2);
|
expect(result.width).toEqual(2);
|
||||||
expect(width).toEqual(2);
|
expect(result.height).toEqual(2);
|
||||||
expect(height).toEqual(2);
|
resolve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('supports combined category and confidence masks', async () => {
|
||||||
|
const categoryMask = new Uint8ClampedArray([1, 0]);
|
||||||
|
const confidenceMask1 = new Float32Array([0.0, 1.0]);
|
||||||
|
const confidenceMask2 = new Float32Array([1.0, 0.0]);
|
||||||
|
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{outputCategoryMask: true, outputConfidenceMasks: true});
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
expect(imageSegmenter.categoryMaskListener).toBeDefined();
|
||||||
|
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||||
|
imageSegmenter.categoryMaskListener!
|
||||||
|
({data: categoryMask, width: 1, height: 1}, 1337);
|
||||||
|
imageSegmenter.confidenceMasksListener!(
|
||||||
|
[
|
||||||
|
{data: confidenceMask1, width: 1, height: 1},
|
||||||
|
{data: confidenceMask2, width: 1, height: 1},
|
||||||
|
],
|
||||||
|
1337);
|
||||||
|
});
|
||||||
|
|
||||||
|
return new Promise<void>(resolve => {
|
||||||
|
// Invoke the image segmenter
|
||||||
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(result.categoryMask).toEqual(categoryMask);
|
||||||
|
expect(result.confidenceMasks).toEqual([
|
||||||
|
confidenceMask1, confidenceMask2
|
||||||
|
]);
|
||||||
|
expect(result.width).toEqual(1);
|
||||||
|
expect(result.height).toEqual(1);
|
||||||
resolve();
|
resolve();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user