Refactor Web code for InteractiveSegmenter
PiperOrigin-RevId: 516254891
This commit is contained in:
		
							parent
							
								
									d6fb7c365e
								
							
						
					
					
						commit
						490d1a7516
					
				| 
						 | 
					@ -36,11 +36,9 @@ export abstract class AudioTaskRunner<T> extends TaskRunner {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Sends a single audio clip to the graph and awaits results. */
 | 
					  /** Sends a single audio clip to the graph and awaits results. */
 | 
				
			||||||
  protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
 | 
					  protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
 | 
				
			||||||
    // Increment the timestamp by 1 millisecond to guarantee that we send
 | 
					 | 
				
			||||||
    // monotonically increasing timestamps to the graph.
 | 
					 | 
				
			||||||
    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
 | 
					 | 
				
			||||||
    return this.process(
 | 
					    return this.process(
 | 
				
			||||||
        audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
 | 
					        audioData, sampleRate ?? this.defaultSampleRate,
 | 
				
			||||||
 | 
					        this.getSynctheticTimestamp());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -175,9 +175,13 @@ export abstract class TaskRunner {
 | 
				
			||||||
        Math.max(this.latestOutputTimestamp, timestamp);
 | 
					        Math.max(this.latestOutputTimestamp, timestamp);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Returns the latest output timestamp. */
 | 
					  /**
 | 
				
			||||||
  protected getLatestOutputTimestamp() {
 | 
					   * Gets a synthetic timestamp in ms that can be used to send data to the
 | 
				
			||||||
    return this.latestOutputTimestamp;
 | 
					   * next packet. The timestamp is one millisecond past the last timestamp
 | 
				
			||||||
 | 
					   * received from the graph.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  protected getSynctheticTimestamp(): number {
 | 
				
			||||||
 | 
					    return this.latestOutputTimestamp + 1;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Throws the error from the error listener if an error was raised. */
 | 
					  /** Throws the error from the error listener if an error was raised. */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner {
 | 
				
			||||||
   * @return The classification result of the text
 | 
					   * @return The classification result of the text
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  classify(text: string): TextClassifierResult {
 | 
					  classify(text: string): TextClassifierResult {
 | 
				
			||||||
    // Increment the timestamp by 1 millisecond to guarantee that we send
 | 
					 | 
				
			||||||
    // monotonically increasing timestamps to the graph.
 | 
					 | 
				
			||||||
    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
 | 
					 | 
				
			||||||
    this.classificationResult = {classifications: []};
 | 
					    this.classificationResult = {classifications: []};
 | 
				
			||||||
    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
 | 
					    this.graphRunner.addStringToStream(
 | 
				
			||||||
 | 
					        text, INPUT_STREAM, this.getSynctheticTimestamp());
 | 
				
			||||||
    this.finishProcessing();
 | 
					    this.finishProcessing();
 | 
				
			||||||
    return this.classificationResult;
 | 
					    return this.classificationResult;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner {
 | 
				
			||||||
   * @return The embedding results of the text
 | 
					   * @return The embedding results of the text
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  embed(text: string): TextEmbedderResult {
 | 
					  embed(text: string): TextEmbedderResult {
 | 
				
			||||||
    // Increment the timestamp by 1 millisecond to guarantee that we send
 | 
					    this.graphRunner.addStringToStream(
 | 
				
			||||||
    // monotonically increasing timestamps to the graph.
 | 
					        text, INPUT_STREAM, this.getSynctheticTimestamp());
 | 
				
			||||||
    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
 | 
					 | 
				
			||||||
    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
 | 
					 | 
				
			||||||
    this.finishProcessing();
 | 
					    this.finishProcessing();
 | 
				
			||||||
    return this.embeddingResult;
 | 
					    return this.embeddingResult;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,6 +21,11 @@ mediapipe_ts_declaration(
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mediapipe_ts_declaration(
 | 
				
			||||||
 | 
					    name = "types",
 | 
				
			||||||
 | 
					    srcs = ["types.d.ts"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
mediapipe_ts_library(
 | 
					mediapipe_ts_library(
 | 
				
			||||||
    name = "vision_task_runner",
 | 
					    name = "vision_task_runner",
 | 
				
			||||||
    srcs = ["vision_task_runner.ts"],
 | 
					    srcs = ["vision_task_runner.ts"],
 | 
				
			||||||
| 
						 | 
					@ -51,6 +56,11 @@ mediapipe_ts_library(
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					mediapipe_ts_library(
 | 
				
			||||||
 | 
					    name = "render_utils",
 | 
				
			||||||
 | 
					    srcs = ["render_utils.ts"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
jasmine_node_test(
 | 
					jasmine_node_test(
 | 
				
			||||||
    name = "vision_task_runner_test",
 | 
					    name = "vision_task_runner_test",
 | 
				
			||||||
    deps = [":vision_task_runner_test_lib"],
 | 
					    deps = [":vision_task_runner_test_lib"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										78
									
								
								mediapipe/tasks/web/vision/core/render_utils.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								mediapipe/tasks/web/vision/core/render_utils.ts
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,78 @@
 | 
				
			||||||
 | 
					/** @fileoverview Utility functions used in the vision demos. */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					 * you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					 * You may obtain a copy of the License at
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					 * See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					 * limitations under the License.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Pre-baked color table for a maximum of 12 classes.
 | 
				
			||||||
 | 
					const CM_ALPHA = 128;
 | 
				
			||||||
 | 
					const COLOR_MAP = [
 | 
				
			||||||
 | 
					  [0, 0, 0, CM_ALPHA],        // class 0 is BG = transparent
 | 
				
			||||||
 | 
					  [255, 0, 0, CM_ALPHA],      // class 1 is red
 | 
				
			||||||
 | 
					  [0, 255, 0, CM_ALPHA],      // class 2 is light green
 | 
				
			||||||
 | 
					  [0, 0, 255, CM_ALPHA],      // class 3 is blue
 | 
				
			||||||
 | 
					  [255, 255, 0, CM_ALPHA],    // class 4 is yellow
 | 
				
			||||||
 | 
					  [255, 0, 255, CM_ALPHA],    // class 5 is light purple / magenta
 | 
				
			||||||
 | 
					  [0, 255, 255, CM_ALPHA],    // class 6 is light blue / aqua
 | 
				
			||||||
 | 
					  [128, 128, 128, CM_ALPHA],  // class 7 is gray
 | 
				
			||||||
 | 
					  [255, 128, 0, CM_ALPHA],    // class 8 is orange
 | 
				
			||||||
 | 
					  [128, 0, 255, CM_ALPHA],    // class 9 is dark purple
 | 
				
			||||||
 | 
					  [0, 128, 0, CM_ALPHA],      // class 10 is dark green
 | 
				
			||||||
 | 
					  [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
 | 
				
			||||||
 | 
					];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Helper function to draw a confidence mask */
 | 
				
			||||||
 | 
					export function drawConfidenceMask(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
 | 
				
			||||||
 | 
					    width: number, height: number): void {
 | 
				
			||||||
 | 
					  const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
 | 
				
			||||||
 | 
					  for (let i = 0; i < image.length; i++) {
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i] = 128;
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i + 1] = 0;
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i + 2] = 0;
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i + 3] = image[i] * 255;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Helper function to draw a category mask. For GPU, we only have F32Arrays
 | 
				
			||||||
 | 
					 * for now.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					export function drawCategoryMask(
 | 
				
			||||||
 | 
					    ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
 | 
				
			||||||
 | 
					    width: number, height: number): void {
 | 
				
			||||||
 | 
					  const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
 | 
				
			||||||
 | 
					  const isFloatArray = image instanceof Float32Array;
 | 
				
			||||||
 | 
					  for (let i = 0; i < image.length; i++) {
 | 
				
			||||||
 | 
					    const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
 | 
				
			||||||
 | 
					    const color = COLOR_MAP[colorIndex];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // When we're given a confidence mask by accident, we just log and return.
 | 
				
			||||||
 | 
					    // TODO: We should fix this.
 | 
				
			||||||
 | 
					    if (!color) {
 | 
				
			||||||
 | 
					      console.warn('No color for ', colorIndex);
 | 
				
			||||||
 | 
					      return;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i] = color[0];
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i + 1] = color[1];
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i + 2] = color[2];
 | 
				
			||||||
 | 
					    uint8ClampedArray[4 * i + 3] = color[3];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										34
									
								
								mediapipe/tasks/web/vision/core/types.d.ts
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								mediapipe/tasks/web/vision/core/types.d.ts
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,34 @@
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					 * you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					 * You may obtain a copy of the License at
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					 * See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					 * limitations under the License.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * The segmentation tasks return the segmentation result as a Uint8Array
 | 
				
			||||||
 | 
					 * (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
 | 
				
			||||||
 | 
					 * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
 | 
				
			||||||
 | 
					 * for future usage.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * A callback that receives the computed masks from the segmentation tasks. The
 | 
				
			||||||
 | 
					 * callback either receives a single element array with a category mask (as a
 | 
				
			||||||
 | 
					 * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
 | 
				
			||||||
 | 
					 * The returned data is only valid for the duration of the callback. If
 | 
				
			||||||
 | 
					 * asynchronous processing is needed, all data needs to be copied before the
 | 
				
			||||||
 | 
					 * callback returns.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					export type SegmentationMaskCallback =
 | 
				
			||||||
 | 
					    (masks: SegmentationMask[], width: number, height: number) => void;
 | 
				
			||||||
| 
						 | 
					@ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
 | 
				
			||||||
          'Task is not initialized with image mode. ' +
 | 
					          'Task is not initialized with image mode. ' +
 | 
				
			||||||
          '\'runningMode\' must be set to \'IMAGE\'.');
 | 
					          '\'runningMode\' must be set to \'IMAGE\'.');
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					    this.process(image, imageProcessingOptions, this.getSynctheticTimestamp());
 | 
				
			||||||
    // Increment the timestamp by 1 millisecond to guarantee that we send
 | 
					 | 
				
			||||||
    // monotonically increasing timestamps to the graph.
 | 
					 | 
				
			||||||
    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
 | 
					 | 
				
			||||||
    this.process(image, imageProcessingOptions, syntheticTimestamp);
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Sends a single video frame to the graph and awaits results. */
 | 
					  /** Sends a single video frame to the graph and awaits results. */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,8 +19,8 @@ mediapipe_ts_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
 | 
					        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
        "//mediapipe/tasks/web/vision/core:image_processing_options",
 | 
					        "//mediapipe/tasks/web/vision/core:image_processing_options",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/web/vision/core:types",
 | 
				
			||||||
        "//mediapipe/tasks/web/vision/core:vision_task_runner",
 | 
					        "//mediapipe/tasks/web/vision/core:vision_task_runner",
 | 
				
			||||||
        "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
 | 
					 | 
				
			||||||
        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
					        "//mediapipe/web/graph_runner:graph_runner_ts",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 | 
				
			||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
					import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 | 
				
			||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
					import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 | 
				
			||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
					import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 | 
				
			||||||
 | 
					import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
 | 
				
			||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
					import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 | 
				
			||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 | 
					import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 | 
				
			||||||
// Placeholder for internal dependency on trusted resource url
 | 
					// Placeholder for internal dependency on trusted resource url
 | 
				
			||||||
| 
						 | 
					@ -28,27 +29,9 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 | 
				
			||||||
import {ImageSegmenterOptions} from './image_segmenter_options';
 | 
					import {ImageSegmenterOptions} from './image_segmenter_options';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
export * from './image_segmenter_options';
 | 
					export * from './image_segmenter_options';
 | 
				
			||||||
 | 
					export {SegmentationMask, SegmentationMaskCallback};
 | 
				
			||||||
export {ImageSource};  // Used in the public API
 | 
					export {ImageSource};  // Used in the public API
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					 | 
				
			||||||
 * The ImageSegmenter returns the segmentation result as a Uint8Array (when
 | 
					 | 
				
			||||||
 * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
 | 
					 | 
				
			||||||
 * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
 | 
					 | 
				
			||||||
 * for future usage.
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/**
 | 
					 | 
				
			||||||
 * A callback that receives the computed masks from the image segmenter. The
 | 
					 | 
				
			||||||
 * callback either receives a single element array with a category mask (as a
 | 
					 | 
				
			||||||
 * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
 | 
					 | 
				
			||||||
 * The returned data is only valid for the duration of the callback. If
 | 
					 | 
				
			||||||
 * asynchronous processing is needed, all data needs to be copied before the
 | 
					 | 
				
			||||||
 * callback returns.
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
export type SegmentationMaskCallback =
 | 
					 | 
				
			||||||
    (masks: SegmentationMask[], width: number, height: number) => void;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const IMAGE_STREAM = 'image_in';
 | 
					const IMAGE_STREAM = 'image_in';
 | 
				
			||||||
const NORM_RECT_STREAM = 'norm_rect';
 | 
					const NORM_RECT_STREAM = 'norm_rect';
 | 
				
			||||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
 | 
					const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user