Support new output format for ImageSegmenter
PiperOrigin-RevId: 524371021
This commit is contained in:
		
							parent
							
								
									f5197a3adc
								
							
						
					
					
						commit
						92f45c98d8
					
				| 
						 | 
					@ -59,13 +59,12 @@ export function drawCategoryMask(
 | 
				
			||||||
  const isFloatArray = image instanceof Float32Array;
 | 
					  const isFloatArray = image instanceof Float32Array;
 | 
				
			||||||
  for (let i = 0; i < image.length; i++) {
 | 
					  for (let i = 0; i < image.length; i++) {
 | 
				
			||||||
    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
 | 
					    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
 | 
				
			||||||
    const color = COLOR_MAP[colorIndex];
 | 
					    let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // When we're given a confidence mask by accident, we just log and return.
 | 
					 | 
				
			||||||
    // TODO: We should fix this.
 | 
					 | 
				
			||||||
    if (!color) {
 | 
					    if (!color) {
 | 
				
			||||||
 | 
					      // TODO: We should fix this.
 | 
				
			||||||
      console.warn('No color for ', colorIndex);
 | 
					      console.warn('No color for ', colorIndex);
 | 
				
			||||||
      return;
 | 
					      color = COLOR_MAP[colorIndex % COLOR_MAP.length];
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    rgbaArray[4 * i] = color[0];
 | 
					    rgbaArray[4 * i] = color[0];
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,7 +29,10 @@ mediapipe_ts_library(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mediapipe_ts_declaration(
 | 
					mediapipe_ts_declaration(
 | 
				
			||||||
    name = "image_segmenter_types",
 | 
					    name = "image_segmenter_types",
 | 
				
			||||||
    srcs = ["image_segmenter_options.d.ts"],
 | 
					    srcs = [
 | 
				
			||||||
 | 
					        "image_segmenter_options.d.ts",
 | 
				
			||||||
 | 
					        "image_segmenter_result.d.ts",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
        "//mediapipe/tasks/web/core:classifier_options",
 | 
					        "//mediapipe/tasks/web/core:classifier_options",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 | 
				
			||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
					import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
				
			||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
					import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
				
			||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
					import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
				
			||||||
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
 | 
					import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
 | 
				
			||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
					import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
				
			||||||
import {LabelMapItem} from '../../../../util/label_map_pb';
 | 
					import {LabelMapItem} from '../../../../util/label_map_pb';
 | 
				
			||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 | 
					import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 | 
				
			||||||
// Placeholder for internal dependency on trusted resource url
 | 
					// Placeholder for internal dependency on trusted resource url
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import {ImageSegmenterOptions} from './image_segmenter_options';
 | 
					import {ImageSegmenterOptions} from './image_segmenter_options';
 | 
				
			||||||
 | 
					import {ImageSegmenterResult} from './image_segmenter_result';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export * from './image_segmenter_options';
 | 
					export * from './image_segmenter_options';
 | 
				
			||||||
export {SegmentationMask, SegmentationMaskCallback};
 | 
					export * from './image_segmenter_result';
 | 
				
			||||||
 | 
					export {SegmentationMask};
 | 
				
			||||||
export {ImageSource};  // Used in the public API
 | 
					export {ImageSource};  // Used in the public API
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const IMAGE_STREAM = 'image_in';
 | 
					const IMAGE_STREAM = 'image_in';
 | 
				
			||||||
const NORM_RECT_STREAM = 'norm_rect';
 | 
					const NORM_RECT_STREAM = 'norm_rect';
 | 
				
			||||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
 | 
					const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
 | 
				
			||||||
 | 
					const CATEGORY_MASK_STREAM = 'category_mask';
 | 
				
			||||||
const IMAGE_SEGMENTER_GRAPH =
 | 
					const IMAGE_SEGMENTER_GRAPH =
 | 
				
			||||||
    'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 | 
					    'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 | 
				
			||||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
					const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
 | 
				
			||||||
    'mediapipe.tasks.TensorsToSegmentationCalculator';
 | 
					    'mediapipe.tasks.TensorsToSegmentationCalculator';
 | 
				
			||||||
 | 
					const DEFAULT_OUTPUT_CATEGORY_MASK = false;
 | 
				
			||||||
 | 
					const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// The OSS JS API does not support the builder pattern.
 | 
					// The OSS JS API does not support the builder pattern.
 | 
				
			||||||
// tslint:disable:jspb-use-builder-pattern
 | 
					// tslint:disable:jspb-use-builder-pattern
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * A callback that receives the computed masks from the image segmenter. The
 | 
				
			||||||
 | 
					 * returned data is only valid for the duration of the callback. If
 | 
				
			||||||
 | 
					 * asynchronous processing is needed, all data needs to be copied before the
 | 
				
			||||||
 | 
					 * callback returns.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** Performs image segmentation on images. */
 | 
					/** Performs image segmentation on images. */
 | 
				
			||||||
export class ImageSegmenter extends VisionTaskRunner {
 | 
					export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
  private userCallback: SegmentationMaskCallback = () => {};
 | 
					  private result: ImageSegmenterResult = {width: 0, height: 0};
 | 
				
			||||||
  private labels: string[] = [];
 | 
					  private labels: string[] = [];
 | 
				
			||||||
 | 
					  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
 | 
				
			||||||
 | 
					  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
 | 
				
			||||||
  private readonly options: ImageSegmenterGraphOptionsProto;
 | 
					  private readonly options: ImageSegmenterGraphOptionsProto;
 | 
				
			||||||
  private readonly segmenterOptions: SegmenterOptionsProto;
 | 
					  private readonly segmenterOptions: SegmenterOptionsProto;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    this.options.setBaseOptions(new BaseOptionsProto());
 | 
					    this.options.setBaseOptions(new BaseOptionsProto());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
  protected override get baseOptions(): BaseOptionsProto {
 | 
					  protected override get baseOptions(): BaseOptionsProto {
 | 
				
			||||||
    return this.options.getBaseOptions()!;
 | 
					    return this.options.getBaseOptions()!;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
      this.options.clearDisplayNamesLocale();
 | 
					      this.options.clearDisplayNamesLocale();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (options.outputType === 'CONFIDENCE_MASK') {
 | 
					    if ('outputCategoryMask' in options) {
 | 
				
			||||||
      this.segmenterOptions.setOutputType(
 | 
					      this.outputCategoryMask =
 | 
				
			||||||
          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
 | 
					          options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
 | 
				
			||||||
    } else {
 | 
					    }
 | 
				
			||||||
      this.segmenterOptions.setOutputType(
 | 
					
 | 
				
			||||||
          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
 | 
					    if ('outputConfidenceMasks' in options) {
 | 
				
			||||||
 | 
					      this.outputConfidenceMasks =
 | 
				
			||||||
 | 
					          options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return super.applyOptions(options);
 | 
					    return super.applyOptions(options);
 | 
				
			||||||
| 
						 | 
					@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
					   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
				
			||||||
   *    callback.
 | 
					   *    callback.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  segment(image: ImageSource, callback: SegmentationMaskCallback): void;
 | 
					  segment(image: ImageSource, callback: ImageSegmenterCallack): void;
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Performs image segmentation on the provided single image and invokes the
 | 
					   * Performs image segmentation on the provided single image and invokes the
 | 
				
			||||||
   * callback with the response. The method returns synchronously once the
 | 
					   * callback with the response. The method returns synchronously once the
 | 
				
			||||||
| 
						 | 
					@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  segment(
 | 
					  segment(
 | 
				
			||||||
      image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
 | 
					      image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
 | 
				
			||||||
      callback: SegmentationMaskCallback): void;
 | 
					      callback: ImageSegmenterCallack): void;
 | 
				
			||||||
  segment(
 | 
					  segment(
 | 
				
			||||||
      image: ImageSource,
 | 
					      image: ImageSource,
 | 
				
			||||||
      imageProcessingOptionsOrCallback: ImageProcessingOptions|
 | 
					      imageProcessingOptionsOrCallback: ImageProcessingOptions|
 | 
				
			||||||
      SegmentationMaskCallback,
 | 
					      ImageSegmenterCallack,
 | 
				
			||||||
      callback?: SegmentationMaskCallback): void {
 | 
					      callback?: ImageSegmenterCallack): void {
 | 
				
			||||||
    const imageProcessingOptions =
 | 
					    const imageProcessingOptions =
 | 
				
			||||||
        typeof imageProcessingOptionsOrCallback !== 'function' ?
 | 
					        typeof imageProcessingOptionsOrCallback !== 'function' ?
 | 
				
			||||||
        imageProcessingOptionsOrCallback :
 | 
					        imageProcessingOptionsOrCallback :
 | 
				
			||||||
        {};
 | 
					        {};
 | 
				
			||||||
 | 
					    const userCallback =
 | 
				
			||||||
    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
 | 
					        typeof imageProcessingOptionsOrCallback === 'function' ?
 | 
				
			||||||
        imageProcessingOptionsOrCallback :
 | 
					        imageProcessingOptionsOrCallback :
 | 
				
			||||||
        callback!;
 | 
					        callback!;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    this.reset();
 | 
				
			||||||
    this.processImageData(image, imageProcessingOptions);
 | 
					    this.processImageData(image, imageProcessingOptions);
 | 
				
			||||||
    this.userCallback = () => {};
 | 
					    userCallback(this.result);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Performs image segmentation on the provided video frame and invokes the
 | 
				
			||||||
 | 
					   * callback with the response. The method returns synchronously once the
 | 
				
			||||||
 | 
					   * callback returns. Only use this method when the ImageSegmenter is
 | 
				
			||||||
 | 
					   * created with running mode `video`.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param videoFrame A video frame to process.
 | 
				
			||||||
 | 
					   * @param timestamp The timestamp of the current frame, in ms.
 | 
				
			||||||
 | 
					   * @param callback The callback that is invoked with the segmented masks. The
 | 
				
			||||||
 | 
					   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
				
			||||||
 | 
					   *    callback.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  segmentForVideo(
 | 
				
			||||||
 | 
					      videoFrame: ImageSource, timestamp: number,
 | 
				
			||||||
 | 
					      callback: ImageSegmenterCallack): void;
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Performs image segmentation on the provided video frame and invokes the
 | 
				
			||||||
 | 
					   * callback with the response. The method returns synchronously once the
 | 
				
			||||||
 | 
					   * callback returns. Only use this method when the ImageSegmenter is
 | 
				
			||||||
 | 
					   * created with running mode `video`.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param videoFrame A video frame to process.
 | 
				
			||||||
 | 
					   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
 | 
				
			||||||
 | 
					   *    to process the input image before running inference.
 | 
				
			||||||
 | 
					   * @param timestamp The timestamp of the current frame, in ms.
 | 
				
			||||||
 | 
					   * @param callback The callback that is invoked with the segmented masks. The
 | 
				
			||||||
 | 
					   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
				
			||||||
 | 
					   *    callback.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  segmentForVideo(
 | 
				
			||||||
 | 
					      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
 | 
				
			||||||
 | 
					      timestamp: number, callback: ImageSegmenterCallack): void;
 | 
				
			||||||
 | 
					  segmentForVideo(
 | 
				
			||||||
 | 
					      videoFrame: ImageSource,
 | 
				
			||||||
 | 
					      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
 | 
				
			||||||
 | 
					      timestampOrCallback: number|ImageSegmenterCallack,
 | 
				
			||||||
 | 
					      callback?: ImageSegmenterCallack): void {
 | 
				
			||||||
 | 
					    const imageProcessingOptions =
 | 
				
			||||||
 | 
					        typeof timestampOrImageProcessingOptions !== 'number' ?
 | 
				
			||||||
 | 
					        timestampOrImageProcessingOptions :
 | 
				
			||||||
 | 
					        {};
 | 
				
			||||||
 | 
					    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
 | 
				
			||||||
 | 
					        timestampOrImageProcessingOptions :
 | 
				
			||||||
 | 
					        timestampOrCallback as number;
 | 
				
			||||||
 | 
					    const userCallback = typeof timestampOrCallback === 'function' ?
 | 
				
			||||||
 | 
					        timestampOrCallback :
 | 
				
			||||||
 | 
					        callback!;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    this.reset();
 | 
				
			||||||
 | 
					    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
 | 
				
			||||||
 | 
					    userCallback(this.result);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
| 
						 | 
					@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    return this.labels;
 | 
					    return this.labels;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  private reset(): void {
 | 
				
			||||||
   * Performs image segmentation on the provided video frame and invokes the
 | 
					    this.result = {width: 0, height: 0};
 | 
				
			||||||
   * callback with the response. The method returns synchronously once the
 | 
					 | 
				
			||||||
   * callback returns. Only use this method when the ImageSegmenter is
 | 
					 | 
				
			||||||
   * created with running mode `video`.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param videoFrame A video frame to process.
 | 
					 | 
				
			||||||
   * @param timestamp The timestamp of the current frame, in ms.
 | 
					 | 
				
			||||||
   * @param callback The callback that is invoked with the segmented masks. The
 | 
					 | 
				
			||||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
					 | 
				
			||||||
   *    callback.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  segmentForVideo(
 | 
					 | 
				
			||||||
      videoFrame: ImageSource, timestamp: number,
 | 
					 | 
				
			||||||
      callback: SegmentationMaskCallback): void;
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Performs image segmentation on the provided video frame and invokes the
 | 
					 | 
				
			||||||
   * callback with the response. The method returns synchronously once the
 | 
					 | 
				
			||||||
   * callback returns. Only use this method when the ImageSegmenter is
 | 
					 | 
				
			||||||
   * created with running mode `video`.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param videoFrame A video frame to process.
 | 
					 | 
				
			||||||
   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
 | 
					 | 
				
			||||||
   *    to process the input image before running inference.
 | 
					 | 
				
			||||||
   * @param timestamp The timestamp of the current frame, in ms.
 | 
					 | 
				
			||||||
   * @param callback The callback that is invoked with the segmented masks. The
 | 
					 | 
				
			||||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
					 | 
				
			||||||
   *    callback.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  segmentForVideo(
 | 
					 | 
				
			||||||
      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
 | 
					 | 
				
			||||||
      timestamp: number, callback: SegmentationMaskCallback): void;
 | 
					 | 
				
			||||||
  segmentForVideo(
 | 
					 | 
				
			||||||
      videoFrame: ImageSource,
 | 
					 | 
				
			||||||
      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
 | 
					 | 
				
			||||||
      timestampOrCallback: number|SegmentationMaskCallback,
 | 
					 | 
				
			||||||
      callback?: SegmentationMaskCallback): void {
 | 
					 | 
				
			||||||
    const imageProcessingOptions =
 | 
					 | 
				
			||||||
        typeof timestampOrImageProcessingOptions !== 'number' ?
 | 
					 | 
				
			||||||
        timestampOrImageProcessingOptions :
 | 
					 | 
				
			||||||
        {};
 | 
					 | 
				
			||||||
    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
 | 
					 | 
				
			||||||
        timestampOrImageProcessingOptions :
 | 
					 | 
				
			||||||
        timestampOrCallback as number;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    this.userCallback = typeof timestampOrCallback === 'function' ?
 | 
					 | 
				
			||||||
        timestampOrCallback :
 | 
					 | 
				
			||||||
        callback!;
 | 
					 | 
				
			||||||
    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
 | 
					 | 
				
			||||||
    this.userCallback = () => {};
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Updates the MediaPipe graph configuration. */
 | 
					  /** Updates the MediaPipe graph configuration. */
 | 
				
			||||||
| 
						 | 
					@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    const graphConfig = new CalculatorGraphConfig();
 | 
					    const graphConfig = new CalculatorGraphConfig();
 | 
				
			||||||
    graphConfig.addInputStream(IMAGE_STREAM);
 | 
					    graphConfig.addInputStream(IMAGE_STREAM);
 | 
				
			||||||
    graphConfig.addInputStream(NORM_RECT_STREAM);
 | 
					    graphConfig.addInputStream(NORM_RECT_STREAM);
 | 
				
			||||||
    graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const calculatorOptions = new CalculatorOptions();
 | 
					    const calculatorOptions = new CalculatorOptions();
 | 
				
			||||||
    calculatorOptions.setExtension(
 | 
					    calculatorOptions.setExtension(
 | 
				
			||||||
| 
						 | 
					@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
 | 
					    segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
 | 
				
			||||||
    segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
 | 
					    segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
 | 
				
			||||||
    segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
 | 
					    segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
 | 
				
			||||||
    segmenterNode.addOutputStream(
 | 
					 | 
				
			||||||
        'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
 | 
					 | 
				
			||||||
    segmenterNode.setOptions(calculatorOptions);
 | 
					    segmenterNode.setOptions(calculatorOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    graphConfig.addNode(segmenterNode);
 | 
					    graphConfig.addNode(segmenterNode);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (this.outputConfidenceMasks) {
 | 
				
			||||||
 | 
					      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
 | 
					      segmenterNode.addOutputStream(
 | 
				
			||||||
 | 
					          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageVectorListener(
 | 
					      this.graphRunner.attachImageVectorListener(
 | 
				
			||||||
        GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
 | 
					          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
				
			||||||
          if (masks.length === 0) {
 | 
					            this.result.confidenceMasks = masks.map(m => m.data);
 | 
				
			||||||
            this.userCallback([], 0, 0);
 | 
					            // Only read the dimensions when at least one mask was produced;
					            // `masks[0]` is undefined for an empty vector, in which case the
					            // result keeps the 0x0 size assigned by reset().
					            if (masks.length > 0) {
 | 
				
			||||||
          } else {
 | 
					              this.result.width = masks[0].width;
 | 
				
			||||||
            this.userCallback(
 | 
					              this.result.height = masks[0].height;
 | 
				
			||||||
                masks.map(m => m.data), masks[0].width, masks[0].height);
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            this.setLatestOutputTimestamp(timestamp);
 | 
					            this.setLatestOutputTimestamp(timestamp);
 | 
				
			||||||
          });
 | 
					          });
 | 
				
			||||||
      this.graphRunner.attachEmptyPacketListener(
 | 
					      this.graphRunner.attachEmptyPacketListener(
 | 
				
			||||||
        GROUPED_SEGMENTATIONS_STREAM, timestamp => {
 | 
					          CONFIDENCE_MASKS_STREAM, timestamp => {
 | 
				
			||||||
            this.setLatestOutputTimestamp(timestamp);
 | 
					            this.setLatestOutputTimestamp(timestamp);
 | 
				
			||||||
          });
 | 
					          });
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (this.outputCategoryMask) {
 | 
				
			||||||
 | 
					      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
 | 
				
			||||||
 | 
					      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      this.graphRunner.attachImageListener(
 | 
				
			||||||
 | 
					          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
				
			||||||
 | 
					            this.result.categoryMask = mask.data;
 | 
				
			||||||
 | 
					            this.result.width = mask.width;
 | 
				
			||||||
 | 
					            this.result.height = mask.height;
 | 
				
			||||||
 | 
					            this.setLatestOutputTimestamp(timestamp);
 | 
				
			||||||
 | 
					          });
 | 
				
			||||||
 | 
					      this.graphRunner.attachEmptyPacketListener(
 | 
				
			||||||
 | 
					          CATEGORY_MASK_STREAM, timestamp => {
 | 
				
			||||||
 | 
					            this.setLatestOutputTimestamp(timestamp);
 | 
				
			||||||
 | 
					          });
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const binaryGraph = graphConfig.serializeBinary();
 | 
					    const binaryGraph = graphConfig.serializeBinary();
 | 
				
			||||||
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 | 
					    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  displayNamesLocale?: string|undefined;
 | 
					  displayNamesLocale?: string|undefined;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /** Whether to output confidence masks. Defaults to true. */
 | 
				
			||||||
   * The output type of segmentation results.
 | 
					  outputConfidenceMasks?: boolean|undefined;
 | 
				
			||||||
   *
 | 
					
 | 
				
			||||||
   * The two supported modes are:
 | 
					  /** Whether to output the category mask. Defaults to false. */
 | 
				
			||||||
   * - Category Mask:   Gives a single output mask where each pixel represents
 | 
					  outputCategoryMask?: boolean|undefined;
 | 
				
			||||||
   *                    the class which the pixel in the original image was
 | 
					 | 
				
			||||||
   *                    predicted to belong to.
 | 
					 | 
				
			||||||
   * - Confidence Mask: Gives a list of output masks (one for each class). For
 | 
					 | 
				
			||||||
   *                    each mask, the pixel represents the prediction
 | 
					 | 
				
			||||||
   *                    confidence, usually in the [0.0, 0.1] range.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * Defaults to `CATEGORY_MASK`.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										37
									
								
								mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,37 @@
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					 * you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					 * You may obtain a copy of the License at
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					 * See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					 * limitations under the License.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** The output result of ImageSegmenter. */
 | 
				
			||||||
 | 
					export declare interface ImageSegmenterResult {
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
 | 
				
			||||||
 | 
					   * pixel represents the prediction confidence, usually in the [0, 1] range.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  confidenceMasks?: Float32Array[]|WebGLTexture[];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * A category mask as a Uint8ClampedArray or WebGLTexture where each
 | 
				
			||||||
 | 
					   * pixel represents the class which the pixel in the original image was
 | 
				
			||||||
 | 
					   * predicted to belong to.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  categoryMask?: Uint8ClampedArray|WebGLTexture;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** The width of the masks. */
 | 
				
			||||||
 | 
					  width: number;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** The height of the masks. */
 | 
				
			||||||
 | 
					  height: number;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -18,7 +18,7 @@ import 'jasmine';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Placeholder for internal dependency on encodeByteArray
 | 
					// Placeholder for internal dependency on encodeByteArray
 | 
				
			||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
					import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
				
			||||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
 | 
					import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
 | 
				
			||||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 | 
					import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import {ImageSegmenter} from './image_segmenter';
 | 
					import {ImageSegmenter} from './image_segmenter';
 | 
				
			||||||
| 
						 | 
					@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
 | 
				
			||||||
  graph: CalculatorGraphConfig|undefined;
 | 
					  graph: CalculatorGraphConfig|undefined;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  fakeWasmModule: SpyWasmModule;
 | 
					  fakeWasmModule: SpyWasmModule;
 | 
				
			||||||
  imageVectorListener:
 | 
					  categoryMaskListener:
 | 
				
			||||||
 | 
					      ((images: WasmImage, timestamp: number) => void)|undefined;
 | 
				
			||||||
 | 
					  confidenceMasksListener:
 | 
				
			||||||
      ((images: WasmImage[], timestamp: number) => void)|undefined;
 | 
					      ((images: WasmImage[], timestamp: number) => void)|undefined;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  constructor() {
 | 
					  constructor() {
 | 
				
			||||||
| 
						 | 
					@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
 | 
				
			||||||
    this.fakeWasmModule =
 | 
					    this.fakeWasmModule =
 | 
				
			||||||
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
 | 
					        this.graphRunner.wasmModule as unknown as SpyWasmModule;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    this.attachListenerSpies[0] =
 | 
					    this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
 | 
				
			||||||
 | 
					                                      .and.callFake((stream, listener) => {
 | 
				
			||||||
 | 
					                                        expect(stream).toEqual('category_mask');
 | 
				
			||||||
 | 
					                                        this.categoryMaskListener = listener;
 | 
				
			||||||
 | 
					                                      });
 | 
				
			||||||
 | 
					    this.attachListenerSpies[1] =
 | 
				
			||||||
        spyOn(this.graphRunner, 'attachImageVectorListener')
 | 
					        spyOn(this.graphRunner, 'attachImageVectorListener')
 | 
				
			||||||
            .and.callFake((stream, listener) => {
 | 
					            .and.callFake((stream, listener) => {
 | 
				
			||||||
              expect(stream).toEqual('segmented_masks');
 | 
					              expect(stream).toEqual('confidence_masks');
 | 
				
			||||||
              this.imageVectorListener = listener;
 | 
					              this.confidenceMasksListener = listener;
 | 
				
			||||||
            });
 | 
					            });
 | 
				
			||||||
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
 | 
					    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
 | 
				
			||||||
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
 | 
					      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
 | 
				
			||||||
| 
						 | 
					@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it('initializes graph', async () => {
 | 
					  it('initializes graph', async () => {
 | 
				
			||||||
    verifyGraph(imageSegmenter);
 | 
					    verifyGraph(imageSegmenter);
 | 
				
			||||||
    verifyListenersRegistered(imageSegmenter);
 | 
					
 | 
				
			||||||
 | 
					    // Verify default options
 | 
				
			||||||
 | 
					    expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
 | 
				
			||||||
 | 
					    expect(imageSegmenter.confidenceMasksListener).toBeDefined();
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it('reloads graph when settings are changed', async () => {
 | 
					  it('reloads graph when settings are changed', async () => {
 | 
				
			||||||
    await imageSegmenter.setOptions({displayNamesLocale: 'en'});
 | 
					    await imageSegmenter.setOptions({displayNamesLocale: 'en'});
 | 
				
			||||||
    verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
 | 
					    verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
 | 
				
			||||||
    verifyListenersRegistered(imageSegmenter);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await imageSegmenter.setOptions({displayNamesLocale: 'de'});
 | 
					    await imageSegmenter.setOptions({displayNamesLocale: 'de'});
 | 
				
			||||||
    verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
 | 
					    verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
 | 
				
			||||||
    verifyListenersRegistered(imageSegmenter);
 | 
					 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it('can use custom models', async () => {
 | 
					  it('can use custom models', async () => {
 | 
				
			||||||
| 
						 | 
					@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it('merges options', async () => {
 | 
					  it('merges options', async () => {
 | 
				
			||||||
    await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
 | 
					    await imageSegmenter.setOptions(
 | 
				
			||||||
 | 
					        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
 | 
				
			||||||
    await imageSegmenter.setOptions({displayNamesLocale: 'en'});
 | 
					    await imageSegmenter.setOptions({displayNamesLocale: 'en'});
 | 
				
			||||||
    verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
 | 
					    verifyGraph(
 | 
				
			||||||
 | 
					        imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
 | 
				
			||||||
    verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
 | 
					    verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
 | 
				
			||||||
      defaultValue: unknown;
 | 
					      defaultValue: unknown;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const testCases: TestCase[] = [
 | 
					    const testCases: TestCase[] = [{
 | 
				
			||||||
      {
 | 
					 | 
				
			||||||
      optionName: 'displayNamesLocale',
 | 
					      optionName: 'displayNamesLocale',
 | 
				
			||||||
      fieldPath: ['displayNamesLocale'],
 | 
					      fieldPath: ['displayNamesLocale'],
 | 
				
			||||||
      userValue: 'en',
 | 
					      userValue: 'en',
 | 
				
			||||||
      graphValue: 'en',
 | 
					      graphValue: 'en',
 | 
				
			||||||
      defaultValue: 'en'
 | 
					      defaultValue: 'en'
 | 
				
			||||||
      },
 | 
					    }];
 | 
				
			||||||
      {
 | 
					 | 
				
			||||||
        optionName: 'outputType',
 | 
					 | 
				
			||||||
        fieldPath: ['segmenterOptions', 'outputType'],
 | 
					 | 
				
			||||||
        userValue: 'CONFIDENCE_MASK',
 | 
					 | 
				
			||||||
        graphValue: 2,
 | 
					 | 
				
			||||||
        defaultValue: 1
 | 
					 | 
				
			||||||
      },
 | 
					 | 
				
			||||||
    ];
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (const testCase of testCases) {
 | 
					    for (const testCase of testCases) {
 | 
				
			||||||
      it(`can set ${testCase.optionName}`, async () => {
 | 
					      it(`can set ${testCase.optionName}`, async () => {
 | 
				
			||||||
| 
						 | 
					@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
 | 
				
			||||||
    }).toThrowError('This task doesn\'t support region-of-interest.');
 | 
					    }).toThrowError('This task doesn\'t support region-of-interest.');
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it('supports category masks', (done) => {
 | 
					  it('supports category mask', async () => {
 | 
				
			||||||
    const mask = new Uint8ClampedArray([1, 2, 3, 4]);
 | 
					    const mask = new Uint8ClampedArray([1, 2, 3, 4]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await imageSegmenter.setOptions(
 | 
				
			||||||
 | 
					        {outputCategoryMask: true, outputConfidenceMasks: false});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Pass the test data to our listener
 | 
					    // Pass the test data to our listener
 | 
				
			||||||
    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
					    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
				
			||||||
      verifyListenersRegistered(imageSegmenter);
 | 
					      expect(imageSegmenter.categoryMaskListener).toBeDefined();
 | 
				
			||||||
      imageSegmenter.imageVectorListener!(
 | 
					      imageSegmenter.categoryMaskListener!
 | 
				
			||||||
          [
 | 
					          ({data: mask, width: 2, height: 2},
 | 
				
			||||||
            {data: mask, width: 2, height: 2},
 | 
					 | 
				
			||||||
          ],
 | 
					 | 
				
			||||||
           /* timestamp= */ 1337);
 | 
					           /* timestamp= */ 1337);
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Invoke the image segmenter
 | 
					    // Invoke the image segmenter
 | 
				
			||||||
    imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
 | 
					
 | 
				
			||||||
 | 
					    return new Promise<void>(resolve => {
 | 
				
			||||||
 | 
					      imageSegmenter.segment({} as HTMLImageElement, result => {
 | 
				
			||||||
        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
					        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
				
			||||||
      expect(masks).toHaveSize(1);
 | 
					        expect(result.categoryMask).toEqual(mask);
 | 
				
			||||||
      expect(masks[0]).toEqual(mask);
 | 
					        expect(result.confidenceMasks).not.toBeDefined();
 | 
				
			||||||
      expect(width).toEqual(2);
 | 
					        expect(result.width).toEqual(2);
 | 
				
			||||||
      expect(height).toEqual(2);
 | 
					        expect(result.height).toEqual(2);
 | 
				
			||||||
      done();
 | 
					        resolve();
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
 | 
				
			||||||
    const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
 | 
					    const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
 | 
				
			||||||
    const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
 | 
					    const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
 | 
					    await imageSegmenter.setOptions(
 | 
				
			||||||
 | 
					        {outputCategoryMask: false, outputConfidenceMasks: true});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Pass the test data to our listener
 | 
					    // Pass the test data to our listener
 | 
				
			||||||
    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
					    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
				
			||||||
      verifyListenersRegistered(imageSegmenter);
 | 
					      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
 | 
				
			||||||
      imageSegmenter.imageVectorListener!(
 | 
					      imageSegmenter.confidenceMasksListener!(
 | 
				
			||||||
          [
 | 
					          [
 | 
				
			||||||
            {data: mask1, width: 2, height: 2},
 | 
					            {data: mask1, width: 2, height: 2},
 | 
				
			||||||
            {data: mask2, width: 2, height: 2},
 | 
					            {data: mask2, width: 2, height: 2},
 | 
				
			||||||
| 
						 | 
					@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return new Promise<void>(resolve => {
 | 
					    return new Promise<void>(resolve => {
 | 
				
			||||||
      // Invoke the image segmenter
 | 
					      // Invoke the image segmenter
 | 
				
			||||||
      imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
 | 
					      imageSegmenter.segment({} as HTMLImageElement, result => {
 | 
				
			||||||
        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
					        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
				
			||||||
        expect(masks).toHaveSize(2);
 | 
					        expect(result.categoryMask).not.toBeDefined();
 | 
				
			||||||
        expect(masks[0]).toEqual(mask1);
 | 
					        expect(result.confidenceMasks).toEqual([mask1, mask2]);
 | 
				
			||||||
        expect(masks[1]).toEqual(mask2);
 | 
					        expect(result.width).toEqual(2);
 | 
				
			||||||
        expect(width).toEqual(2);
 | 
					        expect(result.height).toEqual(2);
 | 
				
			||||||
        expect(height).toEqual(2);
 | 
					        resolve();
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  it('supports combined category and confidence masks', async () => {
 | 
				
			||||||
 | 
					    const categoryMask = new Uint8ClampedArray([1, 0]);
 | 
				
			||||||
 | 
					    const confidenceMask1 = new Float32Array([0.0, 1.0]);
 | 
				
			||||||
 | 
					    const confidenceMask2 = new Float32Array([1.0, 0.0]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    await imageSegmenter.setOptions(
 | 
				
			||||||
 | 
					        {outputCategoryMask: true, outputConfidenceMasks: true});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Pass the test data to our listener
 | 
				
			||||||
 | 
					    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
				
			||||||
 | 
					      expect(imageSegmenter.categoryMaskListener).toBeDefined();
 | 
				
			||||||
 | 
					      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
 | 
				
			||||||
 | 
					      imageSegmenter.categoryMaskListener!
 | 
				
			||||||
 | 
					          ({data: categoryMask, width: 1, height: 1}, 1337);
 | 
				
			||||||
 | 
					      imageSegmenter.confidenceMasksListener!(
 | 
				
			||||||
 | 
					          [
 | 
				
			||||||
 | 
					            {data: confidenceMask1, width: 1, height: 1},
 | 
				
			||||||
 | 
					            {data: confidenceMask2, width: 1, height: 1},
 | 
				
			||||||
 | 
					          ],
 | 
				
			||||||
 | 
					          1337);
 | 
				
			||||||
 | 
					    });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return new Promise<void>(resolve => {
 | 
				
			||||||
 | 
					      // Invoke the image segmenter
 | 
				
			||||||
 | 
					      imageSegmenter.segment({} as HTMLImageElement, result => {
 | 
				
			||||||
 | 
					        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
				
			||||||
 | 
					        expect(result.categoryMask).toEqual(categoryMask);
 | 
				
			||||||
 | 
					        expect(result.confidenceMasks).toEqual([
 | 
				
			||||||
 | 
					          confidenceMask1, confidenceMask2
 | 
				
			||||||
 | 
					        ]);
 | 
				
			||||||
 | 
					        expect(result.width).toEqual(1);
 | 
				
			||||||
 | 
					        expect(result.height).toEqual(1);
 | 
				
			||||||
        resolve();
 | 
					        resolve();
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user