Refactor Web code for InteractiveSegmenter

PiperOrigin-RevId: 516254891

commit 490d1a7516
parent d6fb7c365e

@@ -36,11 +36,9 @@ export abstract class AudioTaskRunner<T> extends TaskRunner {
   /** Sends a single audio clip to the graph and awaits results. */
   protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     return this.process(
-        audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
+        audioData, sampleRate ?? this.defaultSampleRate,
+        this.getSynctheticTimestamp());
   }
 }

@@ -175,9 +175,13 @@ export abstract class TaskRunner {
         Math.max(this.latestOutputTimestamp, timestamp);
   }

-  /** Returns the latest output timestamp. */
-  protected getLatestOutputTimestamp() {
-    return this.latestOutputTimestamp;
+  /**
+   * Gets a syncthethic timestamp in ms that can be used to send data to the
+   * next packet. The timestamp is one millisecond past the last timestamp
+   * received from the graph.
+   */
+  protected getSynctheticTimestamp(): number {
+    return this.latestOutputTimestamp + 1;
   }

   /** Throws the error from the error listener if an error was raised. */
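
The hunk above replaces the ad-hoc "latest output timestamp plus one" computation with a single protected helper (spelled getSynctheticTimestamp in the commit) that every task runner can share. A minimal, self-contained sketch of the pattern follows; DemoRunner and its send() method are illustrative stand-ins rather than MediaPipe classes, and only the timestamp logic mirrors the refactored TaskRunner.

// Self-contained sketch of the synthetic-timestamp pattern (illustrative only).
class DemoRunner {
  private latestOutputTimestamp = 0;

  /** Records the newest timestamp seen on any output stream. */
  protected setLatestOutputTimestamp(timestamp: number): void {
    this.latestOutputTimestamp =
        Math.max(this.latestOutputTimestamp, timestamp);
  }

  /**
   * Returns a timestamp one millisecond past the last output timestamp, so
   * successive inputs are guaranteed to be monotonically increasing.
   */
  protected getSyntheticTimestamp(): number {
    return this.latestOutputTimestamp + 1;
  }

  /** Simulates sending one input packet and observing the graph's response. */
  send(payload: string): void {
    const timestamp = this.getSyntheticTimestamp();
    console.log(`sending "${payload}" at t=${timestamp}ms`);
    // Pretend the graph produced its output at the same timestamp.
    this.setLatestOutputTimestamp(timestamp);
  }
}

const runner = new DemoRunner();
runner.send('first');   // sending "first" at t=1ms
runner.send('second');  // sending "second" at t=2ms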

@@ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner {
    * @return The classification result of the text
    */
   classify(text: string): TextClassifierResult {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     this.classificationResult = {classifications: []};
-    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
     this.finishProcessing();
     return this.classificationResult;
   }

@@ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner {
    * @return The embedding resuls of the text
    */
   embed(text: string): TextEmbedderResult {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
-    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
     this.finishProcessing();
     return this.embeddingResult;
   }

@@ -21,6 +21,11 @@ mediapipe_ts_declaration(
     ],
 )

+mediapipe_ts_declaration(
+    name = "types",
+    srcs = ["types.d.ts"],
+)
+
 mediapipe_ts_library(
     name = "vision_task_runner",
     srcs = ["vision_task_runner.ts"],
@@ -51,6 +56,11 @@ mediapipe_ts_library(
     ],
 )

+mediapipe_ts_library(
+    name = "render_utils",
+    srcs = ["render_utils.ts"],
+)
+
 jasmine_node_test(
     name = "vision_task_runner_test",
     deps = [":vision_task_runner_test_lib"],

mediapipe/tasks/web/vision/core/render_utils.ts (new file, 78 lines)
@@ -0,0 +1,78 @@
+/** @fileoverview Utility functions used in the vision demos. */
+
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Pre-baked color table for a maximum of 12 classes.
+const CM_ALPHA = 128;
+const COLOR_MAP = [
+  [0, 0, 0, CM_ALPHA],        // class 0 is BG = transparent
+  [255, 0, 0, CM_ALPHA],      // class 1 is red
+  [0, 255, 0, CM_ALPHA],      // class 2 is light green
+  [0, 0, 255, CM_ALPHA],      // class 3 is blue
+  [255, 255, 0, CM_ALPHA],    // class 4 is yellow
+  [255, 0, 255, CM_ALPHA],    // class 5 is light purple / magenta
+  [0, 255, 255, CM_ALPHA],    // class 6 is light blue / aqua
+  [128, 128, 128, CM_ALPHA],  // class 7 is gray
+  [255, 128, 0, CM_ALPHA],    // class 8 is orange
+  [128, 0, 255, CM_ALPHA],    // class 9 is dark purple
+  [0, 128, 0, CM_ALPHA],      // class 10 is dark green
+  [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
+];
+
+/** Helper function to draw a confidence mask */
+export function drawConfidenceMask(
+    ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
+    width: number, height: number): void {
+  const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
+  for (let i = 0; i < image.length; i++) {
+    uint8ClampedArray[4 * i] = 128;
+    uint8ClampedArray[4 * i + 1] = 0;
+    uint8ClampedArray[4 * i + 2] = 0;
+    uint8ClampedArray[4 * i + 3] = image[i] * 255;
+  }
+  ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
+}
+
+/**
+ * Helper function to draw a category mask. For GPU, we only have F32Arrays
+ * for now.
+ */
+export function drawCategoryMask(
+    ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
+    width: number, height: number): void {
+  const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
+  const isFloatArray = image instanceof Float32Array;
+  for (let i = 0; i < image.length; i++) {
+    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
+    const color = COLOR_MAP[colorIndex];
+
+    // When we're given a confidence mask by accident, we just log and return.
+    // TODO: We should fix this.
+    if (!color) {
+      console.warn('No color for ', colorIndex);
+      return;
+    }
+
+    uint8ClampedArray[4 * i] = color[0];
+    uint8ClampedArray[4 * i + 1] = color[1];
+    uint8ClampedArray[4 * i + 2] = color[2];
+    uint8ClampedArray[4 * i + 3] = color[3];
+  }
+  ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
+}
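
Since render_utils.ts is demo-only plumbing, a short usage sketch may help. It assumes a browser environment; the relative import path and the renderMask wrapper are illustrative, not part of the commit.

// Hedged usage sketch for the new helpers; assumes a browser DOM is available.
import {drawCategoryMask, drawConfidenceMask} from './render_utils';  // path is illustrative

function renderMask(
    mask: Uint8Array|Float32Array, width: number, height: number): void {
  const canvas = document.createElement('canvas');
  canvas.width = width;
  canvas.height = height;
  const ctx = canvas.getContext('2d');
  if (!ctx) return;
  if (mask instanceof Float32Array) {
    // Confidence mask: drawn as dark red with per-pixel alpha = confidence.
    drawConfidenceMask(ctx, mask, width, height);
  } else {
    // Category mask: each class index is mapped through the 12-entry COLOR_MAP.
    drawCategoryMask(ctx, mask, width, height);
  }
  document.body.appendChild(canvas);
}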

mediapipe/tasks/web/vision/core/types.d.ts (new file, 34 lines)
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * The segmentation tasks return the segmentation result as a Uint8Array
+ * (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
+ * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
+ * for future usage.
+ */
+export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
+
+/**
+ * A callback that receives the computed masks from the segmentation tasks. The
+ * callback either receives a single element array with a category mask (as a
+ * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
+ * The returned data is only valid for the duration of the callback. If
+ * asynchronous processing is needed, all data needs to be copied before the
+ * callback returns.
+ */
+export type SegmentationMaskCallback =
+    (masks: SegmentationMask[], width: number, height: number) => void;
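
The doc comment stresses that the masks are only valid while the callback runs. Below is a hedged sketch of a conforming callback that copies the data it wants to keep; the import path and the keptMasks bookkeeping array are illustrative.

// Sketch of a SegmentationMaskCallback that copies mask data before returning.
import {SegmentationMask, SegmentationMaskCallback} from './types';  // path is illustrative

const keptMasks:
    Array<{data: Uint8Array|Float32Array, width: number, height: number}> = [];

const maskCallback: SegmentationMaskCallback =
    (masks: SegmentationMask[], width: number, height: number) => {
      for (const mask of masks) {
        // WebGLTexture output is reserved for future use, so skip it here.
        if (mask instanceof Uint8Array || mask instanceof Float32Array) {
          // slice() copies the buffer so the data outlives the callback.
          keptMasks.push({data: mask.slice(), width, height});
        }
      }
    };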

@@ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
           'Task is not initialized with image mode. ' +
           '\'runningMode\' must be set to \'IMAGE\'.');
     }
-
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
-    this.process(image, imageProcessingOptions, syntheticTimestamp);
+    this.process(image, imageProcessingOptions, this.getSynctheticTimestamp());
   }

   /** Sends a single video frame to the graph and awaits results. */

@@ -19,8 +19,8 @@ mediapipe_ts_library(
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/vision/core:image_processing_options",
+        "//mediapipe/tasks/web/vision/core:types",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
-        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
 )

@@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
+import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -28,27 +29,9 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 import {ImageSegmenterOptions} from './image_segmenter_options';

 export * from './image_segmenter_options';
+export {SegmentationMask, SegmentationMaskCallback};
 export {ImageSource};  // Used in the public API

-/**
- * The ImageSegmenter returns the segmentation result as a Uint8Array (when
- * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
- * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
- * for future usage.
- */
-export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
-
-/**
- * A callback that receives the computed masks from the image segmenter. The
- * callback either receives a single element array with a category mask (as a
- * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
- * The returned data is only valid for the duration of the callback. If
- * asynchronous processing is needed, all data needs to be copied before the
- * callback returns.
- */
-export type SegmentationMaskCallback =
-    (masks: SegmentationMask[], width: number, height: number) => void;
-
 const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
 const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
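
Because image_segmenter.ts re-exports the moved types, code that imported them from the segmenter module should keep compiling unchanged. A brief sketch, with the relative import path as an illustrative assumption:

// The re-export keeps the public surface unchanged; path is illustrative only.
import {SegmentationMaskCallback} from './image_segmenter';

const logMaskShape: SegmentationMaskCallback = (masks, width, height) => {
  console.log(`received ${masks.length} mask(s) at ${width}x${height}`);
};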