Add MediaPipe Image Segmenter task for Web
PiperOrigin-RevId: 504912518
This commit is contained in:
		
							parent
							
								
									29001234d5
								
							
						
					
					
						commit
						4d38557f11
					
				| 
						 | 
				
			
			@ -23,6 +23,7 @@ VISION_LIBS = [
 | 
			
		|||
    "//mediapipe/tasks/web/vision/hand_landmarker",
 | 
			
		||||
    "//mediapipe/tasks/web/vision/image_classifier",
 | 
			
		||||
    "//mediapipe/tasks/web/vision/image_embedder",
 | 
			
		||||
    "//mediapipe/tasks/web/vision/image_segmenter",
 | 
			
		||||
    "//mediapipe/tasks/web/vision/object_detector",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,6 +39,23 @@ const classifications = imageClassifier.classify(image);
 | 
			
		|||
 | 
			
		||||
For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.
 | 
			
		||||
 | 
			
		||||
## Image Segmentation
 | 
			
		||||
 | 
			
		||||
The MediaPipe Image Segmenter lets you segment an image into categories.
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
const vision = await FilesetResolver.forVisionTasks(
 | 
			
		||||
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
 | 
			
		||||
);
 | 
			
		||||
const imageSegmenter = await ImageSegmenter.createFromModelPath(vision,
 | 
			
		||||
    "model.tflite"
 | 
			
		||||
);
 | 
			
		||||
const image = document.getElementById("image") as HTMLImageElement;
 | 
			
		||||
imageSegmenter.segment(image, (masks, width, height) => {
 | 
			
		||||
  ...
 | 
			
		||||
});
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Gesture Recognition
 | 
			
		||||
 | 
			
		||||
The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										58
									
								
								mediapipe/tasks/web/vision/image_segmenter/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								mediapipe/tasks/web/vision/image_segmenter/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,58 @@
 | 
			
		|||
# This contains the MediaPipe Image Segmenter Task.
 | 
			
		||||
 | 
			
		||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
 | 
			
		||||
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "image_segmenter",
 | 
			
		||||
    srcs = ["image_segmenter.ts"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image_segmenter_types",
 | 
			
		||||
        "//mediapipe/framework:calculator_jspb_proto",
 | 
			
		||||
        "//mediapipe/framework:calculator_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/web/core",
 | 
			
		||||
        "//mediapipe/tasks/web/vision/core:image_processing_options",
 | 
			
		||||
        "//mediapipe/tasks/web/vision/core:vision_task_runner",
 | 
			
		||||
        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
 | 
			
		||||
        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_declaration(
 | 
			
		||||
    name = "image_segmenter_types",
 | 
			
		||||
    srcs = ["image_segmenter_options.d.ts"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/web/core",
 | 
			
		||||
        "//mediapipe/tasks/web/core:classifier_options",
 | 
			
		||||
        "//mediapipe/tasks/web/vision/core:vision_task_options",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
mediapipe_ts_library(
 | 
			
		||||
    name = "image_segmenter_test_lib",
 | 
			
		||||
    testonly = True,
 | 
			
		||||
    srcs = [
 | 
			
		||||
        "image_segmenter_test.ts",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":image_segmenter",
 | 
			
		||||
        ":image_segmenter_types",
 | 
			
		||||
        "//mediapipe/framework:calculator_jspb_proto",
 | 
			
		||||
        "//mediapipe/tasks/web/core",
 | 
			
		||||
        "//mediapipe/tasks/web/core:task_runner_test_utils",
 | 
			
		||||
        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
jasmine_node_test(
 | 
			
		||||
    name = "image_segmenter_test",
 | 
			
		||||
    tags = ["nomsan"],
 | 
			
		||||
    deps = [":image_segmenter_test_lib"],
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										300
									
								
								mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										300
									
								
								mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,300 @@
 | 
			
		|||
/**
 | 
			
		||||
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
			
		||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 | 
			
		||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
			
		||||
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
 | 
			
		||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
			
		||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
			
		||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
			
		||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
			
		||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 | 
			
		||||
// Placeholder for internal dependency on trusted resource url
 | 
			
		||||
 | 
			
		||||
import {ImageSegmenterOptions} from './image_segmenter_options';
 | 
			
		||||
 | 
			
		||||
export * from './image_segmenter_options';
 | 
			
		||||
export {ImageSource};  // Used in the public API
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The ImageSegmenter returns the segmentation result as a Uint8Array (when
 | 
			
		||||
 * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
 | 
			
		||||
 * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
 | 
			
		||||
 * for future usage.
 | 
			
		||||
 */
 | 
			
		||||
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A callback that receives the computed masks from the image segmenter. The
 | 
			
		||||
 * callback either receives a single element array with a category mask (as a
 | 
			
		||||
 * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
 | 
			
		||||
 * The returned data is only valid for the duration of the callback. If
 | 
			
		||||
 * asynchronous processing is needed, all data needs to be copied before the
 | 
			
		||||
 * callback returns.
 | 
			
		||||
 */
 | 
			
		||||
export type SegmentationMaskCallback =
 | 
			
		||||
    (masks: SegmentationMask[], width: number, height: number) => void;
 | 
			
		||||
 | 
			
		||||
const IMAGE_STREAM = 'image_in';
 | 
			
		||||
const NORM_RECT_STREAM = 'norm_rect';
 | 
			
		||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
 | 
			
		||||
const IMAGEA_SEGMENTER_GRAPH =
 | 
			
		||||
    'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 | 
			
		||||
 | 
			
		||||
// The OSS JS API does not support the builder pattern.
 | 
			
		||||
// tslint:disable:jspb-use-builder-pattern
 | 
			
		||||
 | 
			
		||||
/** Performs image segmentation on images. */
 | 
			
		||||
export class ImageSegmenter extends VisionTaskRunner {
 | 
			
		||||
  private userCallback: SegmentationMaskCallback = () => {};
 | 
			
		||||
  private readonly options: ImageSegmenterGraphOptionsProto;
 | 
			
		||||
  private readonly segmenterOptions: SegmenterOptionsProto;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Initializes the Wasm runtime and creates a new image segmenter from the
 | 
			
		||||
   * provided options.
 | 
			
		||||
   * @param wasmFileset A configuration object that provides the location of
 | 
			
		||||
   *     the Wasm binary and its loader.
 | 
			
		||||
   * @param imageSegmenterOptions The options for the Image Segmenter. Note
 | 
			
		||||
   *     that either a path to the model asset or a model buffer needs to be
 | 
			
		||||
   *     provided (via `baseOptions`).
 | 
			
		||||
   */
 | 
			
		||||
  static createFromOptions(
 | 
			
		||||
      wasmFileset: WasmFileset,
 | 
			
		||||
      imageSegmenterOptions: ImageSegmenterOptions): Promise<ImageSegmenter> {
 | 
			
		||||
    return VisionTaskRunner.createInstance(
 | 
			
		||||
        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
 | 
			
		||||
        imageSegmenterOptions);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Initializes the Wasm runtime and creates a new image segmenter based on
 | 
			
		||||
   * the provided model asset buffer.
 | 
			
		||||
   * @param wasmFileset A configuration object that provides the location of
 | 
			
		||||
   *     the Wasm binary and its loader.
 | 
			
		||||
   * @param modelAssetBuffer A binary representation of the model.
 | 
			
		||||
   */
 | 
			
		||||
  static createFromModelBuffer(
 | 
			
		||||
      wasmFileset: WasmFileset,
 | 
			
		||||
      modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
 | 
			
		||||
    return VisionTaskRunner.createInstance(
 | 
			
		||||
        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
 | 
			
		||||
        {baseOptions: {modelAssetBuffer}});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Initializes the Wasm runtime and creates a new image segmenter based on
 | 
			
		||||
   * the path to the model asset.
 | 
			
		||||
   * @param wasmFileset A configuration object that provides the location of
 | 
			
		||||
   *     the Wasm binary and its loader.
 | 
			
		||||
   * @param modelAssetPath The path to the model asset.
 | 
			
		||||
   */
 | 
			
		||||
  static createFromModelPath(
 | 
			
		||||
      wasmFileset: WasmFileset,
 | 
			
		||||
      modelAssetPath: string): Promise<ImageSegmenter> {
 | 
			
		||||
    return VisionTaskRunner.createInstance(
 | 
			
		||||
        ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
 | 
			
		||||
        {baseOptions: {modelAssetPath}});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** @hideconstructor */
 | 
			
		||||
  constructor(
 | 
			
		||||
      wasmModule: WasmModule,
 | 
			
		||||
      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
 | 
			
		||||
    super(
 | 
			
		||||
        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
 | 
			
		||||
        NORM_RECT_STREAM, /* roiAllowed= */ false);
 | 
			
		||||
    this.options = new ImageSegmenterGraphOptionsProto();
 | 
			
		||||
    this.segmenterOptions = new SegmenterOptionsProto();
 | 
			
		||||
    this.options.setSegmenterOptions(this.segmenterOptions);
 | 
			
		||||
    this.options.setBaseOptions(new BaseOptionsProto());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  protected override get baseOptions(): BaseOptionsProto {
 | 
			
		||||
    return this.options.getBaseOptions()!;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  protected override set baseOptions(proto: BaseOptionsProto) {
 | 
			
		||||
    this.options.setBaseOptions(proto);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Sets new options for the image segmenter.
 | 
			
		||||
   *
 | 
			
		||||
   * Calling `setOptions()` with a subset of options only affects those
 | 
			
		||||
   * options. You can reset an option back to its default value by
 | 
			
		||||
   * explicitly setting it to `undefined`.
 | 
			
		||||
   *
 | 
			
		||||
   * @param options The options for the image segmenter.
 | 
			
		||||
   */
 | 
			
		||||
  override setOptions(options: ImageSegmenterOptions): Promise<void> {
 | 
			
		||||
    // Note that we have to support both JSPB and ProtobufJS, hence we
 | 
			
		||||
    // have to expliclity clear the values instead of setting them to
 | 
			
		||||
    // `undefined`.
 | 
			
		||||
    if (options.displayNamesLocale !== undefined) {
 | 
			
		||||
      this.options.setDisplayNamesLocale(options.displayNamesLocale);
 | 
			
		||||
    } else if ('displayNamesLocale' in options) {  // Check for undefined
 | 
			
		||||
      this.options.clearDisplayNamesLocale();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (options.outputType === 'CONFIDENCE_MASK') {
 | 
			
		||||
      this.segmenterOptions.setOutputType(
 | 
			
		||||
          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
 | 
			
		||||
    } else {
 | 
			
		||||
      this.segmenterOptions.setOutputType(
 | 
			
		||||
          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return super.applyOptions(options);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Performs image segmentation on the provided single image and invokes the
 | 
			
		||||
   * callback with the response. The method returns synchronously once the
 | 
			
		||||
   * callback returns. Only use this method when the ImageSegmenter is
 | 
			
		||||
   * created with running mode `image`.
 | 
			
		||||
   *
 | 
			
		||||
   * @param image An image to process.
 | 
			
		||||
   * @param callback The callback that is invoked with the segmented masks. The
 | 
			
		||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
			
		||||
   *    callback.
 | 
			
		||||
   */
 | 
			
		||||
  segment(image: ImageSource, callback: SegmentationMaskCallback): void;
 | 
			
		||||
  /**
 | 
			
		||||
   * Performs image segmentation on the provided single image and invokes the
 | 
			
		||||
   * callback with the response. The method returns synchronously once the
 | 
			
		||||
   * callback returns. Only use this method when the ImageSegmenter is
 | 
			
		||||
   * created with running mode `image`.
 | 
			
		||||
   *
 | 
			
		||||
   * @param image An image to process.
 | 
			
		||||
   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
 | 
			
		||||
   *    to process the input image before running inference.
 | 
			
		||||
   * @param callback The callback that is invoked with the segmented masks. The
 | 
			
		||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
			
		||||
   *    callback.
 | 
			
		||||
   */
 | 
			
		||||
  segment(
 | 
			
		||||
      image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
 | 
			
		||||
      callback: SegmentationMaskCallback): void;
 | 
			
		||||
  segment(
 | 
			
		||||
      image: ImageSource,
 | 
			
		||||
      imageProcessingOptionsOrCallback: ImageProcessingOptions|
 | 
			
		||||
      SegmentationMaskCallback,
 | 
			
		||||
      callback?: SegmentationMaskCallback): void {
 | 
			
		||||
    const imageProcessingOptions =
 | 
			
		||||
        typeof imageProcessingOptionsOrCallback !== 'function' ?
 | 
			
		||||
        imageProcessingOptionsOrCallback :
 | 
			
		||||
        {};
 | 
			
		||||
 | 
			
		||||
    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
 | 
			
		||||
        imageProcessingOptionsOrCallback :
 | 
			
		||||
        callback!;
 | 
			
		||||
    this.processImageData(image, imageProcessingOptions);
 | 
			
		||||
    this.userCallback = () => {};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Performs image segmentation on the provided video frame and invokes the
 | 
			
		||||
   * callback with the response. The method returns synchronously once the
 | 
			
		||||
   * callback returns. Only use this method when the ImageSegmenter is
 | 
			
		||||
   * created with running mode `video`.
 | 
			
		||||
   *
 | 
			
		||||
   * @param videoFrame A video frame to process.
 | 
			
		||||
   * @param timestamp The timestamp of the current frame, in ms.
 | 
			
		||||
   * @param callback The callback that is invoked with the segmented masks. The
 | 
			
		||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
			
		||||
   *    callback.
 | 
			
		||||
   */
 | 
			
		||||
  segmentForVideo(
 | 
			
		||||
      videoFrame: ImageSource, timestamp: number,
 | 
			
		||||
      callback: SegmentationMaskCallback): void;
 | 
			
		||||
  /**
 | 
			
		||||
   * Performs image segmentation on the provided video frame and invokes the
 | 
			
		||||
   * callback with the response. The method returns synchronously once the
 | 
			
		||||
   * callback returns. Only use this method when the ImageSegmenter is
 | 
			
		||||
   * created with running mode `video`.
 | 
			
		||||
   *
 | 
			
		||||
   * @param videoFrame A video frame to process.
 | 
			
		||||
   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
 | 
			
		||||
   *    to process the input image before running inference.
 | 
			
		||||
   * @param timestamp The timestamp of the current frame, in ms.
 | 
			
		||||
   * @param callback The callback that is invoked with the segmented masks. The
 | 
			
		||||
   *    lifetime of the returned data is only guaranteed for the duration of the
 | 
			
		||||
   *    callback.
 | 
			
		||||
   */
 | 
			
		||||
  segmentForVideo(
 | 
			
		||||
      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
 | 
			
		||||
      timestamp: number, callback: SegmentationMaskCallback): void;
 | 
			
		||||
  segmentForVideo(
 | 
			
		||||
      videoFrame: ImageSource,
 | 
			
		||||
      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
 | 
			
		||||
      timestampOrCallback: number|SegmentationMaskCallback,
 | 
			
		||||
      callback?: SegmentationMaskCallback): void {
 | 
			
		||||
    const imageProcessingOptions =
 | 
			
		||||
        typeof timestampOrImageProcessingOptions !== 'number' ?
 | 
			
		||||
        timestampOrImageProcessingOptions :
 | 
			
		||||
        {};
 | 
			
		||||
    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
 | 
			
		||||
        timestampOrImageProcessingOptions :
 | 
			
		||||
        timestampOrCallback as number;
 | 
			
		||||
 | 
			
		||||
    this.userCallback = typeof timestampOrCallback === 'function' ?
 | 
			
		||||
        timestampOrCallback :
 | 
			
		||||
        callback!;
 | 
			
		||||
    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
 | 
			
		||||
    this.userCallback = () => {};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /** Updates the MediaPipe graph configuration. */
 | 
			
		||||
  protected override refreshGraph(): void {
 | 
			
		||||
    const graphConfig = new CalculatorGraphConfig();
 | 
			
		||||
    graphConfig.addInputStream(IMAGE_STREAM);
 | 
			
		||||
    graphConfig.addInputStream(NORM_RECT_STREAM);
 | 
			
		||||
    graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
 | 
			
		||||
 | 
			
		||||
    const calculatorOptions = new CalculatorOptions();
 | 
			
		||||
    calculatorOptions.setExtension(
 | 
			
		||||
        ImageSegmenterGraphOptionsProto.ext, this.options);
 | 
			
		||||
 | 
			
		||||
    const segmenterNode = new CalculatorGraphConfig.Node();
 | 
			
		||||
    segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH);
 | 
			
		||||
    segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
 | 
			
		||||
    segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
 | 
			
		||||
    segmenterNode.addOutputStream(
 | 
			
		||||
        'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
 | 
			
		||||
    segmenterNode.setOptions(calculatorOptions);
 | 
			
		||||
 | 
			
		||||
    graphConfig.addNode(segmenterNode);
 | 
			
		||||
 | 
			
		||||
    this.graphRunner.attachImageVectorListener(
 | 
			
		||||
        GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
 | 
			
		||||
          if (masks.length === 0) {
 | 
			
		||||
            this.userCallback([], 0, 0);
 | 
			
		||||
          } else {
 | 
			
		||||
            this.userCallback(
 | 
			
		||||
                masks.map(m => m.data), masks[0].width, masks[0].height);
 | 
			
		||||
          }
 | 
			
		||||
          this.setLatestOutputTimestamp(timestamp);
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
    const binaryGraph = graphConfig.serializeBinary();
 | 
			
		||||
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										41
									
								
								mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,41 @@
 | 
			
		|||
/**
 | 
			
		||||
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
 | 
			
		||||
 | 
			
		||||
/** Options to configure the MediaPipe Image Segmenter Task */
 | 
			
		||||
export interface ImageSegmenterOptions extends VisionTaskOptions {
 | 
			
		||||
  /**
 | 
			
		||||
   * The locale to use for display names specified through the TFLite Model
 | 
			
		||||
   * Metadata, if any. Defaults to English.
 | 
			
		||||
   */
 | 
			
		||||
  displayNamesLocale?: string|undefined;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * The output type of segmentation results.
 | 
			
		||||
   *
 | 
			
		||||
   * The two supported modes are:
 | 
			
		||||
   * - Category Mask:   Gives a single output mask where each pixel represents
 | 
			
		||||
   *                    the class which the pixel in the original image was
 | 
			
		||||
   *                    predicted to belong to.
 | 
			
		||||
   * - Confidence Mask: Gives a list of output masks (one for each class). For
 | 
			
		||||
   *                    each mask, the pixel represents the prediction
 | 
			
		||||
   *                    confidence, usually in the [0.0, 0.1] range.
 | 
			
		||||
   *
 | 
			
		||||
   * Defaults to `CATEGORY_MASK`.
 | 
			
		||||
   */
 | 
			
		||||
  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,215 @@
 | 
			
		|||
/**
 | 
			
		||||
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import 'jasmine';
 | 
			
		||||
 | 
			
		||||
// Placeholder for internal dependency on encodeByteArray
 | 
			
		||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 | 
			
		||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
 | 
			
		||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
 | 
			
		||||
 | 
			
		||||
import {ImageSegmenter} from './image_segmenter';
 | 
			
		||||
import {ImageSegmenterOptions} from './image_segmenter_options';
 | 
			
		||||
 | 
			
		||||
class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
 | 
			
		||||
  calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 | 
			
		||||
  attachListenerSpies: jasmine.Spy[] = [];
 | 
			
		||||
  graph: CalculatorGraphConfig|undefined;
 | 
			
		||||
 | 
			
		||||
  fakeWasmModule: SpyWasmModule;
 | 
			
		||||
  imageVectorListener:
 | 
			
		||||
      ((images: WasmImage[], timestamp: number) => void)|undefined;
 | 
			
		||||
 | 
			
		||||
  constructor() {
 | 
			
		||||
    super(createSpyWasmModule(), /* glCanvas= */ null);
 | 
			
		||||
    this.fakeWasmModule =
 | 
			
		||||
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
 | 
			
		||||
 | 
			
		||||
    this.attachListenerSpies[0] =
 | 
			
		||||
        spyOn(this.graphRunner, 'attachImageVectorListener')
 | 
			
		||||
            .and.callFake((stream, listener) => {
 | 
			
		||||
              expect(stream).toEqual('segmented_masks');
 | 
			
		||||
              this.imageVectorListener = listener;
 | 
			
		||||
            });
 | 
			
		||||
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
 | 
			
		||||
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
 | 
			
		||||
    });
 | 
			
		||||
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
describe('ImageSegmenter', () => {
 | 
			
		||||
  let imageSegmenter: ImageSegmenterFake;
 | 
			
		||||
 | 
			
		||||
  beforeEach(async () => {
 | 
			
		||||
    addJasmineCustomFloatEqualityTester();
 | 
			
		||||
    imageSegmenter = new ImageSegmenterFake();
 | 
			
		||||
    await imageSegmenter.setOptions(
 | 
			
		||||
        {baseOptions: {modelAssetBuffer: new Uint8Array([])}});
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('initializes graph', async () => {
 | 
			
		||||
    verifyGraph(imageSegmenter);
 | 
			
		||||
    verifyListenersRegistered(imageSegmenter);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('reloads graph when settings are changed', async () => {
 | 
			
		||||
    await imageSegmenter.setOptions({displayNamesLocale: 'en'});
 | 
			
		||||
    verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
 | 
			
		||||
    verifyListenersRegistered(imageSegmenter);
 | 
			
		||||
 | 
			
		||||
    await imageSegmenter.setOptions({displayNamesLocale: 'de'});
 | 
			
		||||
    verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
 | 
			
		||||
    verifyListenersRegistered(imageSegmenter);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('can use custom models', async () => {
 | 
			
		||||
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
 | 
			
		||||
    const newModelBase64 = Buffer.from(newModel).toString('base64');
 | 
			
		||||
    await imageSegmenter.setOptions({
 | 
			
		||||
      baseOptions: {
 | 
			
		||||
        modelAssetBuffer: newModel,
 | 
			
		||||
      }
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    verifyGraph(
 | 
			
		||||
        imageSegmenter,
 | 
			
		||||
        /* expectedCalculatorOptions= */ undefined,
 | 
			
		||||
        /* expectedBaseOptions= */
 | 
			
		||||
        [
 | 
			
		||||
          'modelAsset', {
 | 
			
		||||
            fileContent: newModelBase64,
 | 
			
		||||
            fileName: undefined,
 | 
			
		||||
            fileDescriptorMeta: undefined,
 | 
			
		||||
            filePointerMeta: undefined
 | 
			
		||||
          }
 | 
			
		||||
        ]);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('merges options', async () => {
 | 
			
		||||
    await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
 | 
			
		||||
    await imageSegmenter.setOptions({displayNamesLocale: 'en'});
 | 
			
		||||
    verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
 | 
			
		||||
    verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  describe('setOptions()', () => {
 | 
			
		||||
    interface TestCase {
 | 
			
		||||
      optionName: keyof ImageSegmenterOptions;
 | 
			
		||||
      fieldPath: string[];
 | 
			
		||||
      userValue: unknown;
 | 
			
		||||
      graphValue: unknown;
 | 
			
		||||
      defaultValue: unknown;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const testCases: TestCase[] = [
 | 
			
		||||
      {
 | 
			
		||||
        optionName: 'displayNamesLocale',
 | 
			
		||||
        fieldPath: ['displayNamesLocale'],
 | 
			
		||||
        userValue: 'en',
 | 
			
		||||
        graphValue: 'en',
 | 
			
		||||
        defaultValue: 'en'
 | 
			
		||||
      },
 | 
			
		||||
      {
 | 
			
		||||
        optionName: 'outputType',
 | 
			
		||||
        fieldPath: ['segmenterOptions', 'outputType'],
 | 
			
		||||
        userValue: 'CONFIDENCE_MASK',
 | 
			
		||||
        graphValue: 2,
 | 
			
		||||
        defaultValue: 1
 | 
			
		||||
      },
 | 
			
		||||
    ];
 | 
			
		||||
 | 
			
		||||
    for (const testCase of testCases) {
 | 
			
		||||
      it(`can set ${testCase.optionName}`, async () => {
 | 
			
		||||
        await imageSegmenter.setOptions(
 | 
			
		||||
            {[testCase.optionName]: testCase.userValue});
 | 
			
		||||
        verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
      it(`can clear ${testCase.optionName}`, async () => {
 | 
			
		||||
        await imageSegmenter.setOptions(
 | 
			
		||||
            {[testCase.optionName]: testCase.userValue});
 | 
			
		||||
        verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
 | 
			
		||||
        await imageSegmenter.setOptions({[testCase.optionName]: undefined});
 | 
			
		||||
        verifyGraph(
 | 
			
		||||
            imageSegmenter, [testCase.fieldPath, testCase.defaultValue]);
 | 
			
		||||
      });
 | 
			
		||||
    }
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('doesn\'t support region of interest', () => {
 | 
			
		||||
    expect(() => {
 | 
			
		||||
      imageSegmenter.segment(
 | 
			
		||||
          {} as HTMLImageElement,
 | 
			
		||||
          {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
 | 
			
		||||
    }).toThrowError('This task doesn\'t support region-of-interest.');
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('supports category masks', (done) => {
 | 
			
		||||
    const mask = new Uint8Array([1, 2, 3, 4]);
 | 
			
		||||
 | 
			
		||||
    // Pass the test data to our listener
 | 
			
		||||
    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
			
		||||
      verifyListenersRegistered(imageSegmenter);
 | 
			
		||||
      imageSegmenter.imageVectorListener!(
 | 
			
		||||
          [
 | 
			
		||||
            {data: mask, width: 2, height: 2},
 | 
			
		||||
          ],
 | 
			
		||||
          /* timestamp= */ 1337);
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    // Invoke the image segmenter
 | 
			
		||||
    imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
 | 
			
		||||
      expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
			
		||||
      expect(masks).toHaveSize(1);
 | 
			
		||||
      expect(masks[0]).toEqual(mask);
 | 
			
		||||
      expect(width).toEqual(2);
 | 
			
		||||
      expect(height).toEqual(2);
 | 
			
		||||
      done();
 | 
			
		||||
    });
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  it('supports confidence masks', async () => {
 | 
			
		||||
    const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
 | 
			
		||||
    const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
 | 
			
		||||
 | 
			
		||||
    await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
 | 
			
		||||
 | 
			
		||||
    // Pass the test data to our listener
 | 
			
		||||
    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
 | 
			
		||||
      verifyListenersRegistered(imageSegmenter);
 | 
			
		||||
      imageSegmenter.imageVectorListener!(
 | 
			
		||||
          [
 | 
			
		||||
            {data: mask1, width: 2, height: 2},
 | 
			
		||||
            {data: mask2, width: 2, height: 2},
 | 
			
		||||
          ],
 | 
			
		||||
          1337);
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    return new Promise<void>(resolve => {
 | 
			
		||||
      // Invoke the image segmenter
 | 
			
		||||
      imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
 | 
			
		||||
        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
 | 
			
		||||
        expect(masks).toHaveSize(2);
 | 
			
		||||
        expect(masks[0]).toEqual(mask1);
 | 
			
		||||
        expect(masks[1]).toEqual(mask2);
 | 
			
		||||
        expect(width).toEqual(2);
 | 
			
		||||
        expect(height).toEqual(2);
 | 
			
		||||
        resolve();
 | 
			
		||||
      });
 | 
			
		||||
    });
 | 
			
		||||
  });
 | 
			
		||||
});
 | 
			
		||||
| 
						 | 
				
			
			@ -19,6 +19,7 @@ import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vis
 | 
			
		|||
import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
 | 
			
		||||
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
 | 
			
		||||
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
 | 
			
		||||
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
 | 
			
		||||
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
 | 
			
		||||
 | 
			
		||||
// Declare the variables locally so that Rollup in OSS includes them explicitly
 | 
			
		||||
| 
						 | 
				
			
			@ -28,6 +29,7 @@ const GestureRecognizer = GestureRecognizerImpl;
 | 
			
		|||
const HandLandmarker = HandLandmarkerImpl;
 | 
			
		||||
const ImageClassifier = ImageClassifierImpl;
 | 
			
		||||
const ImageEmbedder = ImageEmbedderImpl;
 | 
			
		||||
const ImageSegmenter = ImageSegementerImpl;
 | 
			
		||||
const ObjectDetector = ObjectDetectorImpl;
 | 
			
		||||
 | 
			
		||||
export {
 | 
			
		||||
| 
						 | 
				
			
			@ -36,5 +38,6 @@ export {
 | 
			
		|||
  HandLandmarker,
 | 
			
		||||
  ImageClassifier,
 | 
			
		||||
  ImageEmbedder,
 | 
			
		||||
  ImageSegmenter,
 | 
			
		||||
  ObjectDetector
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,4 +19,5 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
 | 
			
		|||
export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
 | 
			
		||||
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
 | 
			
		||||
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
 | 
			
		||||
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
 | 
			
		||||
export * from '../../../tasks/web/vision/object_detector/object_detector';
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user