Refactor Web code for InteractiveSegmenter

PiperOrigin-RevId: 516254891
Sebastian Schmidt 2023-03-13 10:41:33 -07:00 committed by Copybara-Service
parent d6fb7c365e
commit 490d1a7516
10 changed files with 139 additions and 40 deletions

View File

@@ -36,11 +36,9 @@ export abstract class AudioTaskRunner<T> extends TaskRunner {
 
   /** Sends a single audio clip to the graph and awaits results. */
   protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     return this.process(
-        audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
+        audioData, sampleRate ?? this.defaultSampleRate,
+        this.getSynctheticTimestamp());
   }
 }

View File

@@ -175,9 +175,13 @@ export abstract class TaskRunner {
         Math.max(this.latestOutputTimestamp, timestamp);
   }
 
-  /** Returns the latest output timestamp. */
-  protected getLatestOutputTimestamp() {
-    return this.latestOutputTimestamp;
+  /**
+   * Gets a syncthethic timestamp in ms that can be used to send data to the
+   * next packet. The timestamp is one millisecond past the last timestamp
+   * received from the graph.
+   */
+  protected getSynctheticTimestamp(): number {
+    return this.latestOutputTimestamp + 1;
   }
 
   /** Throws the error from the error listener if an error was raised. */
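For reference, this new helper centralizes the timestamp pattern that each call site below previously repeated inline. A minimal standalone sketch of that pattern (illustrative only; `lastOutputTimestampMs`, `noteOutputTimestamp`, and `sendPacket` are hypothetical names, not MediaPipe APIs):

// Sketch of the synthetic-timestamp pattern; not part of the commit.
class TimestampedSender {
  // Highest timestamp seen from the graph so far, in milliseconds.
  private lastOutputTimestampMs = 0;

  // Called whenever an output listener reports a packet timestamp.
  noteOutputTimestamp(timestampMs: number): void {
    this.lastOutputTimestampMs = Math.max(this.lastOutputTimestampMs, timestampMs);
  }

  // Mirrors getSynctheticTimestamp(): one millisecond past the latest output,
  // so every value handed to the graph is strictly increasing.
  private getSyntheticTimestampMs(): number {
    return this.lastOutputTimestampMs + 1;
  }

  // Sends a payload through a caller-supplied packet function.
  send(payload: string, sendPacket: (data: string, timestampMs: number) => void): void {
    sendPacket(payload, this.getSyntheticTimestampMs());
  }
}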

View File

@@ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner {
    * @return The classification result of the text
    */
   classify(text: string): TextClassifierResult {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     this.classificationResult = {classifications: []};
-    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
     this.finishProcessing();
     return this.classificationResult;
   }

View File

@@ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner {
    * @return The embedding resuls of the text
    */
   embed(text: string): TextEmbedderResult {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
-    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
     this.finishProcessing();
     return this.embeddingResult;
   }

View File

@@ -21,6 +21,11 @@ mediapipe_ts_declaration(
     ],
 )
 
+mediapipe_ts_declaration(
+    name = "types",
+    srcs = ["types.d.ts"],
+)
+
 mediapipe_ts_library(
     name = "vision_task_runner",
     srcs = ["vision_task_runner.ts"],
@@ -51,6 +56,11 @@ mediapipe_ts_library(
     ],
 )
 
+mediapipe_ts_library(
+    name = "render_utils",
+    srcs = ["render_utils.ts"],
+)
+
 jasmine_node_test(
     name = "vision_task_runner_test",
     deps = [":vision_task_runner_test_lib"],

View File

@@ -0,0 +1,78 @@
/** @fileoverview Utility functions used in the vision demos. */
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Pre-baked color table for a maximum of 12 classes.
const CM_ALPHA = 128;
const COLOR_MAP = [
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
[255, 0, 0, CM_ALPHA], // class 1 is red
[0, 255, 0, CM_ALPHA], // class 2 is light green
[0, 0, 255, CM_ALPHA], // class 3 is blue
[255, 255, 0, CM_ALPHA], // class 4 is yellow
[255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta
[0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua
[128, 128, 128, CM_ALPHA], // class 7 is gray
[255, 128, 0, CM_ALPHA], // class 8 is orange
[128, 0, 255, CM_ALPHA], // class 9 is dark purple
[0, 128, 0, CM_ALPHA], // class 10 is dark green
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
];
/** Helper function to draw a confidence mask */
export function drawConfidenceMask(
ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
width: number, height: number): void {
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < image.length; i++) {
uint8ClampedArray[4 * i] = 128;
uint8ClampedArray[4 * i + 1] = 0;
uint8ClampedArray[4 * i + 2] = 0;
uint8ClampedArray[4 * i + 3] = image[i] * 255;
}
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
}
/**
* Helper function to draw a category mask. For GPU, we only have F32Arrays
* for now.
*/
export function drawCategoryMask(
ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
width: number, height: number): void {
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex];
// When we're given a confidence mask by accident, we just log and return.
// TODO: We should fix this.
if (!color) {
console.warn('No color for ', colorIndex);
return;
}
uint8ClampedArray[4 * i] = color[0];
uint8ClampedArray[4 * i + 1] = color[1];
uint8ClampedArray[4 * i + 2] = color[2];
uint8ClampedArray[4 * i + 3] = color[3];
}
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
}
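A possible way to wire these helpers to a segmentation callback, assuming the masks arrive with matching width and height; the import path and canvas setup are illustrative, not part of the commit:

import {drawCategoryMask} from './render_utils';  // relative path is an assumption

// Draw a category mask onto an on-page canvas whenever a result arrives.
const canvas = document.createElement('canvas');
document.body.appendChild(canvas);
const ctx = canvas.getContext('2d')!;

function onSegmentationResult(
    masks: Array<Uint8Array|Float32Array>, width: number, height: number): void {
  // Resize the canvas to the mask dimensions before drawing.
  canvas.width = width;
  canvas.height = height;
  // A category mask is delivered as a single-element array.
  drawCategoryMask(ctx, masks[0], width, height);
}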

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* The segmentation tasks return the segmentation result as a Uint8Array
* (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
* for future usage.
*/
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
/**
* A callback that receives the computed masks from the segmentation tasks. The
* callback either receives a single element array with a category mask (as a
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type SegmentationMaskCallback =
(masks: SegmentationMask[], width: number, height: number) => void;
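To illustrate the lifetime note in the callback's doc comment, here is a hedged sketch of a `SegmentationMaskCallback` that copies the masks before any asynchronous use (the relative import path is an assumption):

import {SegmentationMask, SegmentationMaskCallback} from './types';

// The arrays passed to the callback are only valid while it runs, so take
// copies before deferring any work.
const copyAndLogMasks: SegmentationMaskCallback =
    (masks: SegmentationMask[], width: number, height: number) => {
      const copies = masks.map(mask => {
        if (mask instanceof Uint8Array) return new Uint8Array(mask);
        if (mask instanceof Float32Array) return new Float32Array(mask);
        return mask;  // WebGLTexture output is reserved for future use.
      });
      // `copies` stays valid after the callback returns.
      setTimeout(() => {
        console.log(`Received ${copies.length} mask(s) of size ${width}x${height}`);
      }, 0);
    };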

View File

@@ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
           'Task is not initialized with image mode. ' +
           '\'runningMode\' must be set to \'IMAGE\'.');
     }
-
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
-    this.process(image, imageProcessingOptions, syntheticTimestamp);
+    this.process(image, imageProcessingOptions, this.getSynctheticTimestamp());
   }
 
   /** Sends a single video frame to the graph and awaits results. */

View File

@@ -19,8 +19,8 @@ mediapipe_ts_library(
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/vision/core:image_processing_options",
+        "//mediapipe/tasks/web/vision/core:types",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
-        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
 )

View File

@@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
+import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -28,27 +29,9 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 import {ImageSegmenterOptions} from './image_segmenter_options';
 
 export * from './image_segmenter_options';
+export {SegmentationMask, SegmentationMaskCallback};
 export {ImageSource};  // Used in the public API
 
-/**
- * The ImageSegmenter returns the segmentation result as a Uint8Array (when
- * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
- * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
- * for future usage.
- */
-export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
-
-/**
- * A callback that receives the computed masks from the image segmenter. The
- * callback either receives a single element array with a category mask (as a
- * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
- * The returned data is only valid for the duration of the callback. If
- * asynchronous processing is needed, all data needs to be copied before the
- * callback returns.
- */
-export type SegmentationMaskCallback =
-    (masks: SegmentationMask[], width: number, height: number) => void;
-
 const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
 const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
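The added `export {SegmentationMask, SegmentationMaskCallback}` keeps the moved types visible from the segmenter module, so both of the following imports resolve to the same declarations (paths shown are relative and illustrative):

// Existing callers keep importing from the segmenter module...
import {SegmentationMaskCallback as FromSegmenter} from './image_segmenter';
// ...while new code can import from the shared location directly.
import {SegmentationMaskCallback as FromSharedTypes} from '../core/types';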