Add MediaPipe Image Segmenter task for Web
PiperOrigin-RevId: 504912518
This commit is contained in:
parent
29001234d5
commit
4d38557f11
|
@ -23,6 +23,7 @@ VISION_LIBS = [
|
||||||
"//mediapipe/tasks/web/vision/hand_landmarker",
|
"//mediapipe/tasks/web/vision/hand_landmarker",
|
||||||
"//mediapipe/tasks/web/vision/image_classifier",
|
"//mediapipe/tasks/web/vision/image_classifier",
|
||||||
"//mediapipe/tasks/web/vision/image_embedder",
|
"//mediapipe/tasks/web/vision/image_embedder",
|
||||||
|
"//mediapipe/tasks/web/vision/image_segmenter",
|
||||||
"//mediapipe/tasks/web/vision/object_detector",
|
"//mediapipe/tasks/web/vision/object_detector",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,23 @@ const classifications = imageClassifier.classify(image);
|
||||||
|
|
||||||
For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.
|
For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.
|
||||||
|
|
||||||
|
## Image Segmentation
|
||||||
|
|
||||||
|
The MediaPipe Image Segmenter lets you segment an image into categories.
|
||||||
|
|
||||||
|
```
|
||||||
|
const vision = await FilesetResolver.forVisionTasks(
|
||||||
|
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
||||||
|
);
|
||||||
|
const imageSegmenter = await ImageSegmenter.createFromModelPath(vision,
|
||||||
|
"model.tflite"
|
||||||
|
);
|
||||||
|
const image = document.getElementById("image") as HTMLImageElement;
|
||||||
|
imageSegmenter.segment(image, (masks, width, height) => {
|
||||||
|
...
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
## Gesture Recognition
|
## Gesture Recognition
|
||||||
|
|
||||||
The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
|
The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
|
||||||
|
|
58
mediapipe/tasks/web/vision/image_segmenter/BUILD
Normal file
58
mediapipe/tasks/web/vision/image_segmenter/BUILD
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# This contains the MediaPipe Image Segmenter Task.
|
||||||
|
|
||||||
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
||||||
|
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "image_segmenter",
|
||||||
|
srcs = ["image_segmenter.ts"],
|
||||||
|
deps = [
|
||||||
|
":image_segmenter_types",
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||||
|
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_declaration(
|
||||||
|
name = "image_segmenter_types",
|
||||||
|
srcs = ["image_segmenter_options.d.ts"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:classifier_options",
|
||||||
|
"//mediapipe/tasks/web/vision/core:vision_task_options",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "image_segmenter_test_lib",
|
||||||
|
testonly = True,
|
||||||
|
srcs = [
|
||||||
|
"image_segmenter_test.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":image_segmenter",
|
||||||
|
":image_segmenter_types",
|
||||||
|
"//mediapipe/framework:calculator_jspb_proto",
|
||||||
|
"//mediapipe/tasks/web/core",
|
||||||
|
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||||
|
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
jasmine_node_test(
|
||||||
|
name = "image_segmenter_test",
|
||||||
|
tags = ["nomsan"],
|
||||||
|
deps = [":image_segmenter_test_lib"],
|
||||||
|
)
|
300
mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
Normal file
300
mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts
Normal file
|
@ -0,0 +1,300 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||||
|
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||||
|
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
|
||||||
|
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||||
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
|
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||||
|
|
||||||
|
export * from './image_segmenter_options';
|
||||||
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The ImageSegmenter returns the segmentation result as a Uint8Array (when
|
||||||
|
* the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
|
||||||
|
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
|
||||||
|
* for future usage.
|
||||||
|
*/
|
||||||
|
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A callback that receives the computed masks from the image segmenter. The
|
||||||
|
* callback either receives a single element array with a category mask (as a
|
||||||
|
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
|
||||||
|
* The returned data is only valid for the duration of the callback. If
|
||||||
|
* asynchronous processing is needed, all data needs to be copied before the
|
||||||
|
* callback returns.
|
||||||
|
*/
|
||||||
|
export type SegmentationMaskCallback =
|
||||||
|
(masks: SegmentationMask[], width: number, height: number) => void;
|
||||||
|
|
||||||
|
const IMAGE_STREAM = 'image_in';
|
||||||
|
const NORM_RECT_STREAM = 'norm_rect';
|
||||||
|
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
||||||
|
const IMAGEA_SEGMENTER_GRAPH =
|
||||||
|
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
/** Performs image segmentation on images. */
|
||||||
|
export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
private userCallback: SegmentationMaskCallback = () => {};
|
||||||
|
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||||
|
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new image segmenter from the
|
||||||
|
* provided options.
|
||||||
|
* @param wasmFileset A configuration object that provides the location of
|
||||||
|
* the Wasm binary and its loader.
|
||||||
|
* @param imageSegmenterOptions The options for the Image Segmenter. Note
|
||||||
|
* that either a path to the model asset or a model buffer needs to be
|
||||||
|
* provided (via `baseOptions`).
|
||||||
|
*/
|
||||||
|
static createFromOptions(
|
||||||
|
wasmFileset: WasmFileset,
|
||||||
|
imageSegmenterOptions: ImageSegmenterOptions): Promise<ImageSegmenter> {
|
||||||
|
return VisionTaskRunner.createInstance(
|
||||||
|
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
imageSegmenterOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new image segmenter based on
|
||||||
|
* the provided model asset buffer.
|
||||||
|
* @param wasmFileset A configuration object that provides the location of
|
||||||
|
* the Wasm binary and its loader.
|
||||||
|
* @param modelAssetBuffer A binary representation of the model.
|
||||||
|
*/
|
||||||
|
static createFromModelBuffer(
|
||||||
|
wasmFileset: WasmFileset,
|
||||||
|
modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
|
||||||
|
return VisionTaskRunner.createInstance(
|
||||||
|
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetBuffer}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes the Wasm runtime and creates a new image segmenter based on
|
||||||
|
* the path to the model asset.
|
||||||
|
* @param wasmFileset A configuration object that provides the location of
|
||||||
|
* the Wasm binary and its loader.
|
||||||
|
* @param modelAssetPath The path to the model asset.
|
||||||
|
*/
|
||||||
|
static createFromModelPath(
|
||||||
|
wasmFileset: WasmFileset,
|
||||||
|
modelAssetPath: string): Promise<ImageSegmenter> {
|
||||||
|
return VisionTaskRunner.createInstance(
|
||||||
|
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||||
|
{baseOptions: {modelAssetPath}});
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @hideconstructor */
|
||||||
|
constructor(
|
||||||
|
wasmModule: WasmModule,
|
||||||
|
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||||
|
super(
|
||||||
|
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
|
||||||
|
NORM_RECT_STREAM, /* roiAllowed= */ false);
|
||||||
|
this.options = new ImageSegmenterGraphOptionsProto();
|
||||||
|
this.segmenterOptions = new SegmenterOptionsProto();
|
||||||
|
this.options.setSegmenterOptions(this.segmenterOptions);
|
||||||
|
this.options.setBaseOptions(new BaseOptionsProto());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected override get baseOptions(): BaseOptionsProto {
|
||||||
|
return this.options.getBaseOptions()!;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||||
|
this.options.setBaseOptions(proto);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets new options for the image segmenter.
|
||||||
|
*
|
||||||
|
* Calling `setOptions()` with a subset of options only affects those
|
||||||
|
* options. You can reset an option back to its default value by
|
||||||
|
* explicitly setting it to `undefined`.
|
||||||
|
*
|
||||||
|
* @param options The options for the image segmenter.
|
||||||
|
*/
|
||||||
|
override setOptions(options: ImageSegmenterOptions): Promise<void> {
|
||||||
|
// Note that we have to support both JSPB and ProtobufJS, hence we
|
||||||
|
// have to expliclity clear the values instead of setting them to
|
||||||
|
// `undefined`.
|
||||||
|
if (options.displayNamesLocale !== undefined) {
|
||||||
|
this.options.setDisplayNamesLocale(options.displayNamesLocale);
|
||||||
|
} else if ('displayNamesLocale' in options) { // Check for undefined
|
||||||
|
this.options.clearDisplayNamesLocale();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.outputType === 'CONFIDENCE_MASK') {
|
||||||
|
this.segmenterOptions.setOutputType(
|
||||||
|
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
||||||
|
} else {
|
||||||
|
this.segmenterOptions.setOutputType(
|
||||||
|
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
||||||
|
}
|
||||||
|
|
||||||
|
return super.applyOptions(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs image segmentation on the provided single image and invokes the
|
||||||
|
* callback with the response. The method returns synchronously once the
|
||||||
|
* callback returns. Only use this method when the ImageSegmenter is
|
||||||
|
* created with running mode `image`.
|
||||||
|
*
|
||||||
|
* @param image An image to process.
|
||||||
|
* @param callback The callback that is invoked with the segmented masks. The
|
||||||
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
segment(image: ImageSource, callback: SegmentationMaskCallback): void;
|
||||||
|
/**
|
||||||
|
* Performs image segmentation on the provided single image and invokes the
|
||||||
|
* callback with the response. The method returns synchronously once the
|
||||||
|
* callback returns. Only use this method when the ImageSegmenter is
|
||||||
|
* created with running mode `image`.
|
||||||
|
*
|
||||||
|
* @param image An image to process.
|
||||||
|
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||||
|
* to process the input image before running inference.
|
||||||
|
* @param callback The callback that is invoked with the segmented masks. The
|
||||||
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
segment(
|
||||||
|
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||||
|
callback: SegmentationMaskCallback): void;
|
||||||
|
segment(
|
||||||
|
image: ImageSource,
|
||||||
|
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||||
|
SegmentationMaskCallback,
|
||||||
|
callback?: SegmentationMaskCallback): void {
|
||||||
|
const imageProcessingOptions =
|
||||||
|
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||||
|
imageProcessingOptionsOrCallback :
|
||||||
|
{};
|
||||||
|
|
||||||
|
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||||
|
imageProcessingOptionsOrCallback :
|
||||||
|
callback!;
|
||||||
|
this.processImageData(image, imageProcessingOptions);
|
||||||
|
this.userCallback = () => {};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs image segmentation on the provided video frame and invokes the
|
||||||
|
* callback with the response. The method returns synchronously once the
|
||||||
|
* callback returns. Only use this method when the ImageSegmenter is
|
||||||
|
* created with running mode `video`.
|
||||||
|
*
|
||||||
|
* @param videoFrame A video frame to process.
|
||||||
|
* @param timestamp The timestamp of the current frame, in ms.
|
||||||
|
* @param callback The callback that is invoked with the segmented masks. The
|
||||||
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
segmentForVideo(
|
||||||
|
videoFrame: ImageSource, timestamp: number,
|
||||||
|
callback: SegmentationMaskCallback): void;
|
||||||
|
/**
|
||||||
|
* Performs image segmentation on the provided video frame and invokes the
|
||||||
|
* callback with the response. The method returns synchronously once the
|
||||||
|
* callback returns. Only use this method when the ImageSegmenter is
|
||||||
|
* created with running mode `video`.
|
||||||
|
*
|
||||||
|
* @param videoFrame A video frame to process.
|
||||||
|
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||||
|
* to process the input image before running inference.
|
||||||
|
* @param timestamp The timestamp of the current frame, in ms.
|
||||||
|
* @param callback The callback that is invoked with the segmented masks. The
|
||||||
|
* lifetime of the returned data is only guaranteed for the duration of the
|
||||||
|
* callback.
|
||||||
|
*/
|
||||||
|
segmentForVideo(
|
||||||
|
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||||
|
timestamp: number, callback: SegmentationMaskCallback): void;
|
||||||
|
segmentForVideo(
|
||||||
|
videoFrame: ImageSource,
|
||||||
|
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
|
||||||
|
timestampOrCallback: number|SegmentationMaskCallback,
|
||||||
|
callback?: SegmentationMaskCallback): void {
|
||||||
|
const imageProcessingOptions =
|
||||||
|
typeof timestampOrImageProcessingOptions !== 'number' ?
|
||||||
|
timestampOrImageProcessingOptions :
|
||||||
|
{};
|
||||||
|
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||||
|
timestampOrImageProcessingOptions :
|
||||||
|
timestampOrCallback as number;
|
||||||
|
|
||||||
|
this.userCallback = typeof timestampOrCallback === 'function' ?
|
||||||
|
timestampOrCallback :
|
||||||
|
callback!;
|
||||||
|
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||||
|
this.userCallback = () => {};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
protected override refreshGraph(): void {
|
||||||
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
|
graphConfig.addInputStream(IMAGE_STREAM);
|
||||||
|
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||||
|
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
|
||||||
|
|
||||||
|
const calculatorOptions = new CalculatorOptions();
|
||||||
|
calculatorOptions.setExtension(
|
||||||
|
ImageSegmenterGraphOptionsProto.ext, this.options);
|
||||||
|
|
||||||
|
const segmenterNode = new CalculatorGraphConfig.Node();
|
||||||
|
segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH);
|
||||||
|
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
|
||||||
|
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
|
||||||
|
segmenterNode.addOutputStream(
|
||||||
|
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
|
||||||
|
segmenterNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
|
graphConfig.addNode(segmenterNode);
|
||||||
|
|
||||||
|
this.graphRunner.attachImageVectorListener(
|
||||||
|
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
|
||||||
|
if (masks.length === 0) {
|
||||||
|
this.userCallback([], 0, 0);
|
||||||
|
} else {
|
||||||
|
this.userCallback(
|
||||||
|
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||||
|
}
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
|
||||||
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
41
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
vendored
Normal file
41
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts
vendored
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
|
||||||
|
|
||||||
|
/** Options to configure the MediaPipe Image Segmenter Task */
|
||||||
|
export interface ImageSegmenterOptions extends VisionTaskOptions {
|
||||||
|
/**
|
||||||
|
* The locale to use for display names specified through the TFLite Model
|
||||||
|
* Metadata, if any. Defaults to English.
|
||||||
|
*/
|
||||||
|
displayNamesLocale?: string|undefined;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The output type of segmentation results.
|
||||||
|
*
|
||||||
|
* The two supported modes are:
|
||||||
|
* - Category Mask: Gives a single output mask where each pixel represents
|
||||||
|
* the class which the pixel in the original image was
|
||||||
|
* predicted to belong to.
|
||||||
|
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
||||||
|
* each mask, the pixel represents the prediction
|
||||||
|
* confidence, usually in the [0.0, 0.1] range.
|
||||||
|
*
|
||||||
|
* Defaults to `CATEGORY_MASK`.
|
||||||
|
*/
|
||||||
|
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
||||||
|
}
|
|
@ -0,0 +1,215 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||||
|
|
||||||
|
import {ImageSegmenter} from './image_segmenter';
|
||||||
|
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||||
|
|
||||||
|
class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
|
calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||||
|
attachListenerSpies: jasmine.Spy[] = [];
|
||||||
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
|
||||||
|
fakeWasmModule: SpyWasmModule;
|
||||||
|
imageVectorListener:
|
||||||
|
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
this.fakeWasmModule =
|
||||||
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
|
this.attachListenerSpies[0] =
|
||||||
|
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('segmented_masks');
|
||||||
|
this.imageVectorListener = listener;
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('ImageSegmenter', () => {
|
||||||
|
let imageSegmenter: ImageSegmenterFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
imageSegmenter = new ImageSegmenterFake();
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(imageSegmenter);
|
||||||
|
verifyListenersRegistered(imageSegmenter);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||||
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||||
|
verifyListenersRegistered(imageSegmenter);
|
||||||
|
|
||||||
|
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
||||||
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
||||||
|
verifyListenersRegistered(imageSegmenter);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can use custom models', async () => {
|
||||||
|
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||||
|
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||||
|
await imageSegmenter.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: newModel,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
verifyGraph(
|
||||||
|
imageSegmenter,
|
||||||
|
/* expectedCalculatorOptions= */ undefined,
|
||||||
|
/* expectedBaseOptions= */
|
||||||
|
[
|
||||||
|
'modelAsset', {
|
||||||
|
fileContent: newModelBase64,
|
||||||
|
fileName: undefined,
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges options', async () => {
|
||||||
|
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||||
|
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||||
|
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||||
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setOptions()', () => {
|
||||||
|
interface TestCase {
|
||||||
|
optionName: keyof ImageSegmenterOptions;
|
||||||
|
fieldPath: string[];
|
||||||
|
userValue: unknown;
|
||||||
|
graphValue: unknown;
|
||||||
|
defaultValue: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
const testCases: TestCase[] = [
|
||||||
|
{
|
||||||
|
optionName: 'displayNamesLocale',
|
||||||
|
fieldPath: ['displayNamesLocale'],
|
||||||
|
userValue: 'en',
|
||||||
|
graphValue: 'en',
|
||||||
|
defaultValue: 'en'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionName: 'outputType',
|
||||||
|
fieldPath: ['segmenterOptions', 'outputType'],
|
||||||
|
userValue: 'CONFIDENCE_MASK',
|
||||||
|
graphValue: 2,
|
||||||
|
defaultValue: 1
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
for (const testCase of testCases) {
|
||||||
|
it(`can set ${testCase.optionName}`, async () => {
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{[testCase.optionName]: testCase.userValue});
|
||||||
|
verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it(`can clear ${testCase.optionName}`, async () => {
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{[testCase.optionName]: testCase.userValue});
|
||||||
|
verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
|
||||||
|
await imageSegmenter.setOptions({[testCase.optionName]: undefined});
|
||||||
|
verifyGraph(
|
||||||
|
imageSegmenter, [testCase.fieldPath, testCase.defaultValue]);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('doesn\'t support region of interest', () => {
|
||||||
|
expect(() => {
|
||||||
|
imageSegmenter.segment(
|
||||||
|
{} as HTMLImageElement,
|
||||||
|
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
|
||||||
|
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('supports category masks', (done) => {
|
||||||
|
const mask = new Uint8Array([1, 2, 3, 4]);
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(imageSegmenter);
|
||||||
|
imageSegmenter.imageVectorListener!(
|
||||||
|
[
|
||||||
|
{data: mask, width: 2, height: 2},
|
||||||
|
],
|
||||||
|
/* timestamp= */ 1337);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the image segmenter
|
||||||
|
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
||||||
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(masks).toHaveSize(1);
|
||||||
|
expect(masks[0]).toEqual(mask);
|
||||||
|
expect(width).toEqual(2);
|
||||||
|
expect(height).toEqual(2);
|
||||||
|
done();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('supports confidence masks', async () => {
|
||||||
|
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||||
|
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||||
|
|
||||||
|
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(imageSegmenter);
|
||||||
|
imageSegmenter.imageVectorListener!(
|
||||||
|
[
|
||||||
|
{data: mask1, width: 2, height: 2},
|
||||||
|
{data: mask2, width: 2, height: 2},
|
||||||
|
],
|
||||||
|
1337);
|
||||||
|
});
|
||||||
|
|
||||||
|
return new Promise<void>(resolve => {
|
||||||
|
// Invoke the image segmenter
|
||||||
|
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
||||||
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(masks).toHaveSize(2);
|
||||||
|
expect(masks[0]).toEqual(mask1);
|
||||||
|
expect(masks[1]).toEqual(mask2);
|
||||||
|
expect(width).toEqual(2);
|
||||||
|
expect(height).toEqual(2);
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
|
@ -19,6 +19,7 @@ import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vis
|
||||||
import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
||||||
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
|
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||||
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
|
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||||
|
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||||
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
|
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
|
||||||
|
|
||||||
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
||||||
|
@ -28,6 +29,7 @@ const GestureRecognizer = GestureRecognizerImpl;
|
||||||
const HandLandmarker = HandLandmarkerImpl;
|
const HandLandmarker = HandLandmarkerImpl;
|
||||||
const ImageClassifier = ImageClassifierImpl;
|
const ImageClassifier = ImageClassifierImpl;
|
||||||
const ImageEmbedder = ImageEmbedderImpl;
|
const ImageEmbedder = ImageEmbedderImpl;
|
||||||
|
const ImageSegmenter = ImageSegementerImpl;
|
||||||
const ObjectDetector = ObjectDetectorImpl;
|
const ObjectDetector = ObjectDetectorImpl;
|
||||||
|
|
||||||
export {
|
export {
|
||||||
|
@ -36,5 +38,6 @@ export {
|
||||||
HandLandmarker,
|
HandLandmarker,
|
||||||
ImageClassifier,
|
ImageClassifier,
|
||||||
ImageEmbedder,
|
ImageEmbedder,
|
||||||
|
ImageSegmenter,
|
||||||
ObjectDetector
|
ObjectDetector
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,4 +19,5 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
|
||||||
export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
||||||
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
|
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||||
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
|
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||||
|
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||||
export * from '../../../tasks/web/vision/object_detector/object_detector';
|
export * from '../../../tasks/web/vision/object_detector/object_detector';
|
||||||
|
|
Loading…
Reference in New Issue
Block a user