From e457039fc6350fbd2e75aa2d034f9b68af6d3410 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 2 Dec 2022 16:16:34 -0800 Subject: [PATCH] Don't inherit from GraphRunner PiperOrigin-RevId: 492584486 --- .../audio_classifier/audio_classifier.ts | 9 +++-- .../audio/audio_embedder/audio_embedder.ts | 25 ++++++++------ mediapipe/tasks/web/core/task_runner.ts | 24 +++++++------- .../text/text_classifier/text_classifier.ts | 11 ++++--- .../web/text/text_embedder/text_embedder.ts | 4 +-- .../gesture_recognizer/gesture_recognizer.ts | 33 +++++++++++-------- .../vision/hand_landmarker/hand_landmarker.ts | 26 ++++++++------- .../image_classifier/image_classifier.ts | 11 ++++--- .../vision/image_embedder/image_embedder.ts | 4 +-- .../vision/object_detector/object_detector.ts | 9 ++--- .../graph_runner/graph_runner_image_lib.ts | 2 +- .../register_model_resources_graph_service.ts | 4 +-- 12 files changed, 92 insertions(+), 70 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 4e12780d2..265ba2b33 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioClassifierResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); @@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoVectorListener( + this.graphRunner.attachProtoVectorListener( TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { this.addJsAudioClassificationResults(binaryProtos); }); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index d08eb4791..445dd5172 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.embeddingResults = []; this.finishProcessing(); @@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResults.push( convertFromEmbeddingResultProto(embeddingResult)); }); - this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { - for (const binaryProto of data) { - const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); - this.embeddingResults.push( - convertFromEmbeddingResultProto(embeddingResult)); - } - }); + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c2691fc76..d769139bc 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset'; const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing -const WasmMediaPipeImageLib = +const GraphRunnerImageLibType = SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class GraphRunnerImageLib extends GraphRunnerImageLibType {} /** Base class for all MediaPipe Tasks. */ -export abstract class TaskRunner extends - WasmMediaPipeImageLib { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; + protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; /** @@ -67,14 +69,14 @@ export abstract class TaskRunner extends constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. - this.setAutoRenderToScreen(false); + this.graphRunner.setAutoRenderToScreen(false); // Enables use of our model resource caching graph service. - this.registerModelResourcesGraphService(); + this.graphRunner.registerModelResourcesGraphService(); } /** Configures the shared options of a MediaPipe Task. */ @@ -95,11 +97,11 @@ export abstract class TaskRunner extends * @param isBinary This should be set to true if the graph is in * binary format, and false if it is in human-readable text format. */ - override setGraph(graphData: Uint8Array, isBinary: boolean): void { - this.attachErrorListener((code, message) => { + protected setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.graphRunner.attachErrorListener((code, message) => { this.processingErrors.push(new Error(message)); }); - super.setGraph(graphData, isBinary); + this.graphRunner.setGraph(graphData, isBinary); this.handleErrors(); } @@ -108,8 +110,8 @@ export abstract class TaskRunner extends * far as possible, performing all processing until no more processing can be * done. */ - override finishProcessing(): void { - super.finishProcessing(); + protected finishProcessing(): void { + this.graphRunner.finishProcessing(); this.handleErrors(); } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index bd2a207ce..8810d4b42 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner { classify(text: string): TextClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.classificationResult; @@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index d2899fbe2..62f9b06db 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner { */ embed(text: string): TextEmbedderResult { // Get text embeddings by running our MediaPipe graph. - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.embeddingResult; @@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); }); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8baee5ce3..69a8118a6 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -257,8 +257,9 @@ export class GestureRecognizer extends this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -365,18 +366,22 @@ export class GestureRecognizer extends graphConfig.addNode(recognizerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 263ed4b48..9a0823f23 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner { this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner { graphConfig.addNode(landmarkerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 90dbf9798..40e8b5099 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner { ImageClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return this.classificationResult; @@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 559332650..f8b0204ee 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner { protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { // Get embeddings by running our MediaPipe graph. - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( image, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return this.embeddings; @@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { this.addJsImageEmdedding(binaryProto); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 03171003f..e2cfe0575 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner { Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return [...this.detections]; @@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner { graphConfig.addNode(detectorNode); - this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { - this.addJsObjectDetections(binaryProto); - }); + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, binaryProto => { + this.addJsObjectDetections(binaryProto); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index e886999cb..7a4ea09e2 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -22,7 +22,7 @@ export declare interface WasmImageModule { * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. Example usage: - * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` + * `const GraphRunnerImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index bc9c93e8a..9f2791d80 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources { * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: - * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * GraphRunner);` + * `const GraphRunnerWithModelResourcesLib = + * SupportModelResourcesGraphService(GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService(