Add InteractiveSegmenter Web API
PiperOrigin-RevId: 516654090
This commit is contained in:
parent
6774794d02
commit
ec3cd45d61
|
@ -24,6 +24,7 @@ VISION_LIBS = [
|
|||
"//mediapipe/tasks/web/vision/image_classifier",
|
||||
"//mediapipe/tasks/web/vision/image_embedder",
|
||||
"//mediapipe/tasks/web/vision/image_segmenter",
|
||||
"//mediapipe/tasks/web/vision/interactive_segmenter",
|
||||
"//mediapipe/tasks/web/vision/object_detector",
|
||||
]
|
||||
|
||||
|
|
|
@ -75,6 +75,24 @@ imageSegmenter.segment(image, (masks, width, height) => {
|
|||
});
|
||||
```
|
||||
|
||||
## Interactive Segmentation
|
||||
|
||||
The MediaPipe Interactive Segmenter task lets you select a region of interest in
|
||||
an image and segments the image based on that selection.
|
||||
|
||||
```
|
||||
const vision = await FilesetResolver.forVisionTasks(
|
||||
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
||||
);
|
||||
const interactiveSegmenter = await InteractiveSegmenter.createFromModelPath(
|
||||
vision, "model.tflite"
|
||||
);
|
||||
const image = document.getElementById("image") as HTMLImageElement;
|
||||
interactiveSegmenter.segment(image, { keypoint: { x: 0.1, y: 0.2 } },
|
||||
(masks, width, height) => { ... }
|
||||
);
|
||||
```
|
||||
|
||||
## Object Detection
|
||||
|
||||
The MediaPipe Object Detector task lets you detect the presence and location of
|
||||
|
|
|
@ -20,6 +20,7 @@ import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/ha
|
|||
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||
import {InteractiveSegmenter as InteractiveSegmenterImpl} from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
|
||||
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
|
||||
|
||||
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
||||
|
@ -30,6 +31,7 @@ const HandLandmarker = HandLandmarkerImpl;
|
|||
const ImageClassifier = ImageClassifierImpl;
|
||||
const ImageEmbedder = ImageEmbedderImpl;
|
||||
const ImageSegmenter = ImageSegementerImpl;
|
||||
const InteractiveSegmenter = InteractiveSegmenterImpl;
|
||||
const ObjectDetector = ObjectDetectorImpl;
|
||||
|
||||
export {
|
||||
|
@ -39,5 +41,6 @@ export {
|
|||
ImageClassifier,
|
||||
ImageEmbedder,
|
||||
ImageSegmenter,
|
||||
InteractiveSegmenter,
|
||||
ObjectDetector
|
||||
};
|
||||
|
|
62
mediapipe/tasks/web/vision/interactive_segmenter/BUILD
Normal file
62
mediapipe/tasks/web/vision/interactive_segmenter/BUILD
Normal file
|
@ -0,0 +1,62 @@
|
|||
# This contains the MediaPipe Interactive Segmenter Task.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
||||
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "interactive_segmenter",
|
||||
srcs = ["interactive_segmenter.ts"],
|
||||
deps = [
|
||||
":interactive_segmenter_types",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/containers:keypoint",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/web/vision/core:types",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||
"//mediapipe/util:color_jspb_proto",
|
||||
"//mediapipe/util:render_data_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_declaration(
|
||||
name = "interactive_segmenter_types",
|
||||
srcs = ["interactive_segmenter_options.d.ts"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_options",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "interactive_segmenter_test_lib",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"interactive_segmenter_test.ts",
|
||||
],
|
||||
deps = [
|
||||
":interactive_segmenter",
|
||||
":interactive_segmenter_types",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||
"//mediapipe/util:render_data_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||
],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
name = "interactive_segmenter_test",
|
||||
tags = ["nomsan"],
|
||||
deps = [":interactive_segmenter_test_lib"],
|
||||
)
|
|
@ -0,0 +1,306 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
|
||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {Color as ColorProto} from '../../../../util/color_pb';
|
||||
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
|
||||
|
||||
export * from './interactive_segmenter_options';
|
||||
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
|
||||
export {ImageSource};
|
||||
|
||||
const IMAGE_IN_STREAM = 'image_in';
|
||||
const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
||||
const ROI_IN_STREAM = 'roi_in';
|
||||
const IMAGE_OUT_STREAM = 'image_out';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* Performs interactive segmentation on images.
|
||||
*
|
||||
* Users can represent user interaction through `RegionOfInterest`, which gives
|
||||
* a hint to InteractiveSegmenter to perform segmentation focusing on the given
|
||||
* region of interest.
|
||||
*
|
||||
* The API expects a TFLite model with mandatory TFLite Model Metadata.
|
||||
*
|
||||
* Input tensor:
|
||||
* (kTfLiteUInt8/kTfLiteFloat32)
|
||||
* - image input of size `[batch x height x width x channels]`.
|
||||
* - batch inference is not supported (`batch` is required to be 1).
|
||||
* - only RGB inputs are supported (`channels` is required to be 3).
|
||||
* - if type is kTfLiteFloat32, NormalizationOptions are required to be
|
||||
* attached to the metadata for input normalization.
|
||||
* Output tensors:
|
||||
* (kTfLiteUInt8/kTfLiteFloat32)
|
||||
* - list of segmented masks.
|
||||
* - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
|
||||
* - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
|
||||
* `channels`.
|
||||
* - batch is always 1
|
||||
*/
|
||||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new interactive segmenter from
|
||||
* the provided options.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param interactiveSegmenterOptions The options for the Interactive
|
||||
* Segmenter. Note that either a path to the model asset or a model buffer
|
||||
* needs to be provided (via `baseOptions`).
|
||||
* @return A new `InteractiveSegmenter`.
|
||||
*/
|
||||
static createFromOptions(
|
||||
wasmFileset: WasmFileset,
|
||||
interactiveSegmenterOptions: InteractiveSegmenterOptions):
|
||||
Promise<InteractiveSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
interactiveSegmenterOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new interactive segmenter based
|
||||
* on the provided model asset buffer.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param modelAssetBuffer A binary representation of the model.
|
||||
* @return A new `InteractiveSegmenter`.
|
||||
*/
|
||||
static createFromModelBuffer(
|
||||
wasmFileset: WasmFileset,
|
||||
modelAssetBuffer: Uint8Array): Promise<InteractiveSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
{baseOptions: {modelAssetBuffer}});
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new interactive segmenter based
|
||||
* on the path to the model asset.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param modelAssetPath The path to the model asset.
|
||||
* @return A new `InteractiveSegmenter`.
|
||||
*/
|
||||
static createFromModelPath(
|
||||
wasmFileset: WasmFileset,
|
||||
modelAssetPath: string): Promise<InteractiveSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
{baseOptions: {modelAssetPath}});
|
||||
}
|
||||
|
||||
/** @hideconstructor */
|
||||
constructor(
|
||||
wasmModule: WasmModule,
|
||||
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||
super(
|
||||
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_IN_STREAM,
|
||||
NORM_RECT_IN_STREAM, /* roiAllowed= */ false);
|
||||
this.options = new ImageSegmenterGraphOptionsProto();
|
||||
this.segmenterOptions = new SegmenterOptionsProto();
|
||||
this.options.setSegmenterOptions(this.segmenterOptions);
|
||||
this.options.setBaseOptions(new BaseOptionsProto());
|
||||
}
|
||||
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
return this.options.getBaseOptions()!;
|
||||
}
|
||||
|
||||
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||
this.options.setBaseOptions(proto);
|
||||
}
|
||||
|
||||
/**
 * Sets new options for the interactive segmenter.
 *
 * Calling `setOptions()` with a subset of options only affects those
 * options. You can reset an option back to its default value by
 * explicitly setting it to `undefined`.
 *
 * @param options The options for the interactive segmenter.
 * @return A Promise that resolves when the settings have been applied.
 */
override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
  const OutputType = SegmenterOptionsProto.OutputType;

  if ('outputType' in options) {
    // The caller mentioned `outputType` explicitly. Anything other than
    // 'CONFIDENCE_MASK' (including an explicit `undefined`) selects the
    // documented default, `CATEGORY_MASK`.
    this.segmenterOptions.setOutputType(
        options.outputType === 'CONFIDENCE_MASK' ?
            OutputType.CONFIDENCE_MASK :
            OutputType.CATEGORY_MASK);
  } else if (
      this.segmenterOptions.getOutputType() !== OutputType.CONFIDENCE_MASK) {
    // `outputType` was not part of this call, so a previously selected
    // `CONFIDENCE_MASK` must be left alone. In every other case keep writing
    // the `CATEGORY_MASK` default explicitly, exactly as before, so the graph
    // options always carry a concrete output type.
    this.segmenterOptions.setOutputType(OutputType.CATEGORY_MASK);
  }

  return super.applyOptions(options);
}
|
||||
|
||||
/**
|
||||
* Performs interactive segmentation on the provided single image and invokes
|
||||
* the callback with the response. The `roi` parameter is used to represent a
|
||||
* user's region of interest for segmentation.
|
||||
*
|
||||
* If the output_type is `CATEGORY_MASK`, the callback is invoked with a vector
|
||||
* of images that contains only one category image mask. If the
|
||||
* output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
|
||||
* images that represent per-category confidence masks. The method returns
|
||||
* synchronously once the callback returns.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param roi The region of interest for segmentation.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
/**
|
||||
* Performs interactive segmentation on the provided single image and invokes
|
||||
* the callback with the response. The `roi` parameter is used to represent a
|
||||
* user's region of interest for segmentation.
|
||||
*
|
||||
* The `imageProcessingOptions` parameter can be used to specify the
|
||||
* rotation to apply to the image before performing segmentation, by setting
|
||||
* its 'rotationDegrees' field. Note that specifying a region-of-interest
|
||||
* using the 'regionOfInterest' field is NOT supported and will result in an
|
||||
* error.
|
||||
*
|
||||
* If the output_type is `CATEGORY_MASK`, the callback is invoked with a vector
|
||||
* of images that contains only one category image mask. If the
|
||||
* output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
|
||||
* images that represent per-category confidence masks. The method returns
|
||||
* synchronously once the callback returns.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param roi The region of interest for segmentation.
|
||||
* @param imageProcessingOptions The `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptions: ImageProcessingOptions,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||
SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
const imageProcessingOptions =
|
||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
|
||||
this.processRenderData(roi, this.getSynctheticTimestamp());
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_IN_STREAM);
|
||||
graphConfig.addInputStream(ROI_IN_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_IN_STREAM);
|
||||
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
ImageSegmenterGraphOptionsProto.ext, this.options);
|
||||
|
||||
const segmenterNode = new CalculatorGraphConfig.Node();
|
||||
segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH);
|
||||
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
|
||||
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
|
||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
|
||||
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
|
||||
segmenterNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(segmenterNode);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
IMAGE_OUT_STREAM, (masks, timestamp) => {
|
||||
if (masks.length === 0) {
|
||||
this.userCallback([], 0, 0);
|
||||
} else {
|
||||
this.userCallback(
|
||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||
}
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts the user-facing RegionOfInterest message to the RenderData proto
|
||||
* and sends it to the graph.
|
||||
*/
|
||||
private processRenderData(roi: RegionOfInterest, timestamp: number): void {
|
||||
const renderData = new RenderDataProto();
|
||||
|
||||
const renderAnnotation = new RenderAnnotationProto();
|
||||
|
||||
const color = new ColorProto();
|
||||
color.setR(255);
|
||||
renderAnnotation.setColor(color);
|
||||
|
||||
const point = new RenderAnnotationProto.Point();
|
||||
point.setNormalized(true);
|
||||
point.setX(roi.keypoint.x);
|
||||
point.setY(roi.keypoint.y);
|
||||
renderAnnotation.setPoint(point);
|
||||
|
||||
renderData.addRenderAnnotations(renderAnnotation);
|
||||
|
||||
this.graphRunner.addProtoToStream(
|
||||
renderData.serializeBinary(), 'mediapipe.RenderData', ROI_IN_STREAM,
|
||||
timestamp);
|
||||
}
|
||||
}
|
||||
|
||||
|
36
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts
vendored
Normal file
36
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts
vendored
Normal file
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
|
||||
|
||||
/** Options to configure the MediaPipe Interactive Segmenter Task. */
|
||||
export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
|
||||
/**
|
||||
* The output type of segmentation results.
|
||||
*
|
||||
* The two supported modes are:
|
||||
* - Category Mask: Gives a single output mask where each pixel represents
|
||||
* the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
||||
* each mask, the pixel represents the prediction
|
||||
* confidence, usually in the [0.0, 1.0] range.
|
||||
*
|
||||
* Defaults to `CATEGORY_MASK`.
|
||||
*/
|
||||
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
||||
}
|
|
@ -0,0 +1,214 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||
|
||||
import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';
|
||||
|
||||
|
||||
const ROI: RegionOfInterest = {
|
||||
keypoint: {x: 0.1, y: 0.2}
|
||||
};
|
||||
|
||||
class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
||||
MediapipeTasksFake {
|
||||
calculatorName =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
attachListenerSpies: jasmine.Spy[] = [];
|
||||
graph: CalculatorGraphConfig|undefined;
|
||||
|
||||
fakeWasmModule: SpyWasmModule;
|
||||
imageVectorListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
constructor() {
|
||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||
this.fakeWasmModule =
|
||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||
|
||||
this.attachListenerSpies[0] =
|
||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('image_out');
|
||||
this.imageVectorListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||
|
||||
spyOn(this.graphRunner, 'addProtoToStream')
|
||||
.and.callFake((data, protoName, stream) => {
|
||||
if (stream === 'roi_in') {
|
||||
expect(protoName).toEqual('mediapipe.RenderData');
|
||||
this.lastRoi = RenderDataProto.deserializeBinary(data);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
describe('InteractiveSegmenter', () => {
|
||||
let interactiveSegmenter: InteractiveSegmenterFake;
|
||||
|
||||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
interactiveSegmenter = new InteractiveSegmenterFake();
|
||||
await interactiveSegmenter.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(interactiveSegmenter);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
});
|
||||
|
||||
it('can use custom models', async () => {
|
||||
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||
await interactiveSegmenter.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: newModel,
|
||||
}
|
||||
});
|
||||
|
||||
verifyGraph(
|
||||
interactiveSegmenter,
|
||||
/* expectedCalculatorOptions= */ undefined,
|
||||
/* expectedBaseOptions= */
|
||||
[
|
||||
'modelAsset', {
|
||||
fileContent: newModelBase64,
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined
|
||||
}
|
||||
]);
|
||||
});
|
||||
|
||||
|
||||
describe('setOptions()', () => {
|
||||
const fieldPath = ['segmenterOptions', 'outputType'];
|
||||
|
||||
it(`can set outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
});
|
||||
|
||||
it(`can clear outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
await interactiveSegmenter.setOptions({outputType: undefined});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
|
||||
});
|
||||
});
|
||||
|
||||
it('doesn\'t support region of interest', () => {
|
||||
expect(() => {
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI,
|
||||
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
|
||||
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||
});
|
||||
|
||||
it('sends region-of-interest', (done) => {
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
expect(interactiveSegmenter.lastRoi).toBeDefined();
|
||||
expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0])
|
||||
.toEqual(jasmine.objectContaining({
|
||||
color: {r: 255, b: undefined, g: undefined},
|
||||
}));
|
||||
done();
|
||||
});
|
||||
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
|
||||
});
|
||||
|
||||
it('supports category masks', (done) => {
|
||||
const mask = new Uint8Array([1, 2, 3, 4]);
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask, width: 2, height: 2},
|
||||
],
|
||||
/* timestamp= */ 1337);
|
||||
});
|
||||
|
||||
// Invoke the interactive segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(1);
|
||||
expect(masks[0]).toEqual(mask);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
done();
|
||||
});
|
||||
});
|
||||
|
||||
it('supports confidence masks', async () => {
|
||||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask1, width: 2, height: 2},
|
||||
{data: mask2, width: 2, height: 2},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the interactive segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(2);
|
||||
expect(masks[0]).toEqual(mask1);
|
||||
expect(masks[1]).toEqual(mask2);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -20,4 +20,5 @@ export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
|||
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||
export * from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
|
||||
export * from '../../../tasks/web/vision/object_detector/object_detector';
|
||||
|
|
Loading…
Reference in New Issue
Block a user