Update InteractiveSegmenter to return MPImage
PiperOrigin-RevId: 528010944
This commit is contained in:
		
							parent
							
								
									bbbc0f98c5
								
							
						
					
					
						commit
						dcef6df1cb
					
				| 
						 | 
					@ -35,7 +35,6 @@ const COLOR_MAP: Array<[number, number, number, number]> = [
 | 
				
			||||||
  [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
 | 
					  [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
 | 
				
			||||||
];
 | 
					];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
/** Helper function to draw a confidence mask */
 | 
					/** Helper function to draw a confidence mask */
 | 
				
			||||||
export function drawConfidenceMask(
 | 
					export function drawConfidenceMask(
 | 
				
			||||||
  ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
 | 
					  ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
 | 
				
			||||||
| 
						 | 
					@ -50,33 +49,6 @@ export function drawConfidenceMask(
 | 
				
			||||||
  ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
 | 
					  ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					 | 
				
			||||||
 * Helper function to draw a category mask. For GPU, we only have F32Arrays
 | 
					 | 
				
			||||||
 * for now.
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
export function drawCategoryMask(
 | 
					 | 
				
			||||||
    ctx: CanvasRenderingContext2D, image: Uint8ClampedArray|Float32Array,
 | 
					 | 
				
			||||||
    width: number, height: number): void {
 | 
					 | 
				
			||||||
  const rgbaArray = new Uint8ClampedArray(width * height * 4);
 | 
					 | 
				
			||||||
  const isFloatArray = image instanceof Float32Array;
 | 
					 | 
				
			||||||
  for (let i = 0; i < image.length; i++) {
 | 
					 | 
				
			||||||
    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
 | 
					 | 
				
			||||||
    let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (!color) {
 | 
					 | 
				
			||||||
      // TODO: We should fix this.
 | 
					 | 
				
			||||||
      console.warn('No color for ', colorIndex);
 | 
					 | 
				
			||||||
      color = COLOR_MAP[colorIndex % COLOR_MAP.length];
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    rgbaArray[4 * i] = color[0];
 | 
					 | 
				
			||||||
    rgbaArray[4 * i + 1] = color[1];
 | 
					 | 
				
			||||||
    rgbaArray[4 * i + 2] = color[2];
 | 
					 | 
				
			||||||
    rgbaArray[4 * i + 3] = color[3];
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/** The color converter we use in our demos. */
 | 
					/** The color converter we use in our demos. */
 | 
				
			||||||
export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = {
 | 
					export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = {
 | 
				
			||||||
  floatToRGBAConverter: v => [128, 0, 0, v * 255],
 | 
					  floatToRGBAConverter: v => [128, 0, 0, v * 255],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										10
									
								
								mediapipe/tasks/web/vision/core/types.d.ts
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										10
									
								
								mediapipe/tasks/web/vision/core/types.d.ts
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -16,16 +16,6 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';
 | 
					import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					 | 
				
			||||||
 * The segmentation tasks return the segmentation either as a WebGLTexture (when
 | 
					 | 
				
			||||||
 * the output is on GPU) or as a typed JavaScript arrays for CPU-based
 | 
					 | 
				
			||||||
 * category or confidence masks. `Uint8ClampedArray`s are used to represent
 | 
					 | 
				
			||||||
 * CPU-based category masks and `Float32Array`s are used for CPU-based
 | 
					 | 
				
			||||||
 * confidence masks.
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/** A Region-Of-Interest (ROI) to represent a region within an image. */
 | 
					/** A Region-Of-Interest (ROI) to represent a region within an image. */
 | 
				
			||||||
export declare interface RegionOfInterest {
 | 
					export declare interface RegionOfInterest {
 | 
				
			||||||
  /** The ROI in keypoint format. */
 | 
					  /** The ROI in keypoint format. */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,6 +37,7 @@ mediapipe_ts_declaration(
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
        "//mediapipe/tasks/web/core:classifier_options",
 | 
					        "//mediapipe/tasks/web/core:classifier_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/vision/core:image",
 | 
				
			||||||
        "//mediapipe/tasks/web/vision/core:vision_task_options",
 | 
					        "//mediapipe/tasks/web/vision/core:vision_task_options",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -53,6 +54,7 @@ mediapipe_ts_library(
 | 
				
			||||||
        "//mediapipe/framework:calculator_jspb_proto",
 | 
					        "//mediapipe/framework:calculator_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
        "//mediapipe/tasks/web/core:task_runner_test_utils",
 | 
					        "//mediapipe/tasks/web/core:task_runner_test_utils",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/vision/core:image",
 | 
				
			||||||
        "//mediapipe/util:render_data_jspb_proto",
 | 
					        "//mediapipe/util:render_data_jspb_proto",
 | 
				
			||||||
        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
 | 
					        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 | 
				
			||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
					import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
				
			||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
					import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
				
			||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
					import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
				
			||||||
import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types';
 | 
					import {RegionOfInterest} from '../../../../tasks/web/vision/core/types';
 | 
				
			||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
					import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
				
			||||||
import {Color as ColorProto} from '../../../../util/color_pb';
 | 
					import {Color as ColorProto} from '../../../../util/color_pb';
 | 
				
			||||||
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
 | 
					import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
 | 
				
			||||||
| 
						 | 
					@ -33,7 +33,7 @@ import {InteractiveSegmenterResult} from './interactive_segmenter_result';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export * from './interactive_segmenter_options';
 | 
					export * from './interactive_segmenter_options';
 | 
				
			||||||
export * from './interactive_segmenter_result';
 | 
					export * from './interactive_segmenter_result';
 | 
				
			||||||
export {SegmentationMask, RegionOfInterest};
 | 
					export {RegionOfInterest};
 | 
				
			||||||
export {ImageSource};
 | 
					export {ImageSource};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const IMAGE_IN_STREAM = 'image_in';
 | 
					const IMAGE_IN_STREAM = 'image_in';
 | 
				
			||||||
| 
						 | 
					@ -83,7 +83,7 @@ export type InteractiveSegmenterCallback =
 | 
				
			||||||
 *   - batch is always 1
 | 
					 *   - batch is always 1
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
export class InteractiveSegmenter extends VisionTaskRunner {
 | 
					export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
  private result: InteractiveSegmenterResult = {width: 0, height: 0};
 | 
					  private result: InteractiveSegmenterResult = {};
 | 
				
			||||||
  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
 | 
					  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
 | 
				
			||||||
  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
 | 
					  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
 | 
				
			||||||
  private readonly options: ImageSegmenterGraphOptionsProto;
 | 
					  private readonly options: ImageSegmenterGraphOptionsProto;
 | 
				
			||||||
| 
						 | 
					@ -253,7 +253,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private reset(): void {
 | 
					  private reset(): void {
 | 
				
			||||||
    this.result = {width: 0, height: 0};
 | 
					    this.result = {};
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Updates the MediaPipe graph configuration. */
 | 
					  /** Updates the MediaPipe graph configuration. */
 | 
				
			||||||
| 
						 | 
					@ -283,12 +283,8 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageVectorListener(
 | 
					      this.graphRunner.attachImageVectorListener(
 | 
				
			||||||
          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
					          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
				
			||||||
            this.result.confidenceMasks = masks.map(m => m.data);
 | 
					            this.result.confidenceMasks =
 | 
				
			||||||
            if (masks.length >= 0) {
 | 
					                masks.map(wasmImage => this.convertToMPImage(wasmImage));
 | 
				
			||||||
              this.result.width = masks[0].width;
 | 
					 | 
				
			||||||
              this.result.height = masks[0].height;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            this.setLatestOutputTimestamp(timestamp);
 | 
					            this.setLatestOutputTimestamp(timestamp);
 | 
				
			||||||
          });
 | 
					          });
 | 
				
			||||||
      this.graphRunner.attachEmptyPacketListener(
 | 
					      this.graphRunner.attachEmptyPacketListener(
 | 
				
			||||||
| 
						 | 
					@ -303,9 +299,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageListener(
 | 
					      this.graphRunner.attachImageListener(
 | 
				
			||||||
          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
					          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
				
			||||||
            this.result.categoryMask = mask.data;
 | 
					            this.result.categoryMask = this.convertToMPImage(mask);
 | 
				
			||||||
            this.result.width = mask.width;
 | 
					 | 
				
			||||||
            this.result.height = mask.height;
 | 
					 | 
				
			||||||
            this.setLatestOutputTimestamp(timestamp);
 | 
					            this.setLatestOutputTimestamp(timestamp);
 | 
				
			||||||
          });
 | 
					          });
 | 
				
			||||||
      this.graphRunner.attachEmptyPacketListener(
 | 
					      this.graphRunner.attachEmptyPacketListener(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,24 +14,21 @@
 | 
				
			||||||
 * limitations under the License.
 | 
					 * limitations under the License.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import {MPImage} from '../../../../tasks/web/vision/core/image';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/** The output result of InteractiveSegmenter. */
 | 
					/** The output result of InteractiveSegmenter. */
 | 
				
			||||||
export declare interface InteractiveSegmenterResult {
 | 
					export declare interface InteractiveSegmenterResult {
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
 | 
					   * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
 | 
				
			||||||
   * pixel represents the prediction confidence, usually in the [0, 1] range.
 | 
					   * `MPImage`s where, for each mask, each pixel represents the prediction
 | 
				
			||||||
 | 
					   * confidence, usually in the [0, 1] range.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  confidenceMasks?: Float32Array[]|WebGLTexture[];
 | 
					  confidenceMasks?: MPImage[];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * A category mask as a Uint8ClampedArray or WebGLTexture where each
 | 
					   * A category mask represented as a `Uint8ClampedArray` or
 | 
				
			||||||
   * pixel represents the class which the pixel in the original image was
 | 
					   * `WebGLTexture`-backed `MPImage` where each pixel represents the class which
 | 
				
			||||||
   * predicted to belong to.
 | 
					   * the pixel in the original image was predicted to belong to.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  categoryMask?: Uint8ClampedArray|WebGLTexture;
 | 
					  categoryMask?: MPImage;
 | 
				
			||||||
 | 
					 | 
				
			||||||
  /** The width of the masks. */
 | 
					 | 
				
			||||||
  width: number;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /** The height of the masks. */
 | 
					 | 
				
			||||||
  height: number;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,6 +19,7 @@ import 'jasmine';
 | 
				
			||||||
// Placeholder for internal dependency on encodeByteArray
 | 
					// Placeholder for internal dependency on encodeByteArray
 | 
				
			||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
					import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
				
			||||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
 | 
					import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
 | 
				
			||||||
 | 
					import {MPImage} from '../../../../tasks/web/vision/core/image';
 | 
				
			||||||
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
 | 
					import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
 | 
				
			||||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 | 
					import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,10 +171,10 @@ describe('InteractiveSegmenter', () => {
 | 
				
			||||||
      interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
 | 
					      interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
 | 
				
			||||||
        expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
 | 
					        expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
 | 
				
			||||||
            .toHaveBeenCalled();
 | 
					            .toHaveBeenCalled();
 | 
				
			||||||
        expect(result.categoryMask).toEqual(mask);
 | 
					        expect(result.categoryMask).toBeInstanceOf(MPImage);
 | 
				
			||||||
 | 
					        expect(result.categoryMask!.width).toEqual(2);
 | 
				
			||||||
 | 
					        expect(result.categoryMask!.height).toEqual(2);
 | 
				
			||||||
        expect(result.confidenceMasks).not.toBeDefined();
 | 
					        expect(result.confidenceMasks).not.toBeDefined();
 | 
				
			||||||
        expect(result.width).toEqual(2);
 | 
					 | 
				
			||||||
        expect(result.height).toEqual(2);
 | 
					 | 
				
			||||||
        resolve();
 | 
					        resolve();
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
| 
						 | 
					@ -202,18 +203,21 @@ describe('InteractiveSegmenter', () => {
 | 
				
			||||||
        expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
 | 
					        expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
 | 
				
			||||||
            .toHaveBeenCalled();
 | 
					            .toHaveBeenCalled();
 | 
				
			||||||
        expect(result.categoryMask).not.toBeDefined();
 | 
					        expect(result.categoryMask).not.toBeDefined();
 | 
				
			||||||
        expect(result.confidenceMasks).toEqual([mask1, mask2]);
 | 
					
 | 
				
			||||||
        expect(result.width).toEqual(2);
 | 
					        expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
 | 
				
			||||||
        expect(result.height).toEqual(2);
 | 
					        expect(result.confidenceMasks![0].width).toEqual(2);
 | 
				
			||||||
 | 
					        expect(result.confidenceMasks![0].height).toEqual(2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
 | 
				
			||||||
        resolve();
 | 
					        resolve();
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  it('supports combined category and confidence masks', async () => {
 | 
					  it('supports combined category and confidence masks', async () => {
 | 
				
			||||||
    const categoryMask = new Uint8ClampedArray([1, 0]);
 | 
					    const categoryMask = new Uint8ClampedArray([1]);
 | 
				
			||||||
    const confidenceMask1 = new Float32Array([0.0, 1.0]);
 | 
					    const confidenceMask1 = new Float32Array([0.0]);
 | 
				
			||||||
    const confidenceMask2 = new Float32Array([1.0, 0.0]);
 | 
					    const confidenceMask2 = new Float32Array([1.0]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    await interactiveSegmenter.setOptions(
 | 
					    await interactiveSegmenter.setOptions(
 | 
				
			||||||
        {outputCategoryMask: true, outputConfidenceMasks: true});
 | 
					        {outputCategoryMask: true, outputConfidenceMasks: true});
 | 
				
			||||||
| 
						 | 
					@ -238,12 +242,12 @@ describe('InteractiveSegmenter', () => {
 | 
				
			||||||
          {} as HTMLImageElement, ROI, result => {
 | 
					          {} as HTMLImageElement, ROI, result => {
 | 
				
			||||||
            expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
 | 
					            expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
 | 
				
			||||||
                .toHaveBeenCalled();
 | 
					                .toHaveBeenCalled();
 | 
				
			||||||
            expect(result.categoryMask).toEqual(categoryMask);
 | 
					            expect(result.categoryMask).toBeInstanceOf(MPImage);
 | 
				
			||||||
            expect(result.confidenceMasks).toEqual([
 | 
					            expect(result.categoryMask!.width).toEqual(1);
 | 
				
			||||||
              confidenceMask1, confidenceMask2
 | 
					            expect(result.categoryMask!.height).toEqual(1);
 | 
				
			||||||
            ]);
 | 
					
 | 
				
			||||||
            expect(result.width).toEqual(1);
 | 
					            expect(result.confidenceMasks![0]).toBeInstanceOf(MPImage);
 | 
				
			||||||
            expect(result.height).toEqual(1);
 | 
					            expect(result.confidenceMasks![1]).toBeInstanceOf(MPImage);
 | 
				
			||||||
            resolve();
 | 
					            resolve();
 | 
				
			||||||
          });
 | 
					          });
 | 
				
			||||||
    });
 | 
					    });
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user