Don't inherit from GraphRunner
PiperOrigin-RevId: 492584486
This commit is contained in:
		
							parent
							
								
									da9587033d
								
							
						
					
					
						commit
						e457039fc6
					
				|  | @ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | |||
|   protected override process( | ||||
|       audioData: Float32Array, sampleRate: number, | ||||
|       timestampMs: number): AudioClassifierResult[] { | ||||
|     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); | ||||
|     this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); | ||||
|     this.graphRunner.addDoubleToStream( | ||||
|         sampleRate, SAMPLE_RATE_STREAM, timestampMs); | ||||
|     this.graphRunner.addAudioToStreamWithShape( | ||||
|         audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, | ||||
|         AUDIO_STREAM, timestampMs); | ||||
| 
 | ||||
|     this.classificationResults = []; | ||||
|     this.finishProcessing(); | ||||
|  | @ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> { | |||
| 
 | ||||
|     graphConfig.addNode(classifierNode); | ||||
| 
 | ||||
|     this.attachProtoVectorListener( | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { | ||||
|           this.addJsAudioClassificationResults(binaryProtos); | ||||
|         }); | ||||
|  |  | |||
|  | @ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | |||
|   protected override process( | ||||
|       audioData: Float32Array, sampleRate: number, | ||||
|       timestampMs: number): AudioEmbedderResult[] { | ||||
|     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); | ||||
|     this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); | ||||
|     this.graphRunner.addDoubleToStream( | ||||
|         sampleRate, SAMPLE_RATE_STREAM, timestampMs); | ||||
|     this.graphRunner.addAudioToStreamWithShape( | ||||
|         audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, | ||||
|         AUDIO_STREAM, timestampMs); | ||||
| 
 | ||||
|     this.embeddingResults = []; | ||||
|     this.finishProcessing(); | ||||
|  | @ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> { | |||
| 
 | ||||
|     graphConfig.addNode(embedderNode); | ||||
| 
 | ||||
|     this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { | ||||
|     this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { | ||||
|       const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); | ||||
|       this.embeddingResults.push( | ||||
|           convertFromEmbeddingResultProto(embeddingResult)); | ||||
|     }); | ||||
| 
 | ||||
|     this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { | ||||
|       for (const binaryProto of data) { | ||||
|         const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); | ||||
|         this.embeddingResults.push( | ||||
|             convertFromEmbeddingResultProto(embeddingResult)); | ||||
|       } | ||||
|     }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         TIMESTAMPED_EMBEDDINGS_STREAM, data => { | ||||
|           for (const binaryProto of data) { | ||||
|             const embeddingResult = | ||||
|                 EmbeddingResult.deserializeBinary(binaryProto); | ||||
|             this.embeddingResults.push( | ||||
|                 convertFromEmbeddingResultProto(embeddingResult)); | ||||
|           } | ||||
|         }); | ||||
| 
 | ||||
|     const binaryGraph = graphConfig.serializeBinary(); | ||||
|     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); | ||||
|  |  | |||
|  | @ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset'; | |||
| const NO_ASSETS = undefined; | ||||
| 
 | ||||
| // tslint:disable-next-line:enforce-name-casing
 | ||||
| const WasmMediaPipeImageLib = | ||||
| const GraphRunnerImageLibType = | ||||
|     SupportModelResourcesGraphService(SupportImage(GraphRunner)); | ||||
| /** An implementation of the GraphRunner that supports image operations */ | ||||
| export class GraphRunnerImageLib extends GraphRunnerImageLibType {} | ||||
| 
 | ||||
| /** Base class for all MediaPipe Tasks. */ | ||||
| export abstract class TaskRunner<O extends TaskRunnerOptions> extends | ||||
|     WasmMediaPipeImageLib { | ||||
| export abstract class TaskRunner<O extends TaskRunnerOptions> { | ||||
|   protected abstract baseOptions: BaseOptionsProto; | ||||
|   protected graphRunner: GraphRunnerImageLib; | ||||
|   private processingErrors: Error[] = []; | ||||
| 
 | ||||
|   /** | ||||
|  | @ -67,14 +69,14 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends | |||
|   constructor( | ||||
|       wasmModule: WasmModule, | ||||
|       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { | ||||
|     super(wasmModule, glCanvas); | ||||
|     this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); | ||||
| 
 | ||||
|     // Disables the automatic render-to-screen code, which allows for pure
 | ||||
|     // CPU processing.
 | ||||
|     this.setAutoRenderToScreen(false); | ||||
|     this.graphRunner.setAutoRenderToScreen(false); | ||||
| 
 | ||||
|     // Enables use of our model resource caching graph service.
 | ||||
|     this.registerModelResourcesGraphService(); | ||||
|     this.graphRunner.registerModelResourcesGraphService(); | ||||
|   } | ||||
| 
 | ||||
|   /** Configures the shared options of a MediaPipe Task. */ | ||||
|  | @ -95,11 +97,11 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends | |||
|    * @param isBinary This should be set to true if the graph is in | ||||
|    *     binary format, and false if it is in human-readable text format. | ||||
|    */ | ||||
|   override setGraph(graphData: Uint8Array, isBinary: boolean): void { | ||||
|     this.attachErrorListener((code, message) => { | ||||
|   protected setGraph(graphData: Uint8Array, isBinary: boolean): void { | ||||
|     this.graphRunner.attachErrorListener((code, message) => { | ||||
|       this.processingErrors.push(new Error(message)); | ||||
|     }); | ||||
|     super.setGraph(graphData, isBinary); | ||||
|     this.graphRunner.setGraph(graphData, isBinary); | ||||
|     this.handleErrors(); | ||||
|   } | ||||
| 
 | ||||
|  | @ -108,8 +110,8 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends | |||
|    * far as possible, performing all processing until no more processing can be | ||||
|    * done. | ||||
|    */ | ||||
|   override finishProcessing(): void { | ||||
|     super.finishProcessing(); | ||||
|   protected finishProcessing(): void { | ||||
|     this.graphRunner.finishProcessing(); | ||||
|     this.handleErrors(); | ||||
|   } | ||||
| 
 | ||||
|  |  | |||
|  | @ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> { | |||
|   classify(text: string): TextClassifierResult { | ||||
|     // Get classification result by running our MediaPipe graph.
 | ||||
|     this.classificationResult = {classifications: []}; | ||||
|     this.addStringToStream( | ||||
|     this.graphRunner.addStringToStream( | ||||
|         text, INPUT_STREAM, /* timestamp= */ performance.now()); | ||||
|     this.finishProcessing(); | ||||
|     return this.classificationResult; | ||||
|  | @ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> { | |||
| 
 | ||||
|     graphConfig.addNode(classifierNode); | ||||
| 
 | ||||
|     this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { | ||||
|       this.classificationResult = convertFromClassificationResultProto( | ||||
|           ClassificationResult.deserializeBinary(binaryProto)); | ||||
|     }); | ||||
|     this.graphRunner.attachProtoListener( | ||||
|         CLASSIFICATIONS_STREAM, binaryProto => { | ||||
|           this.classificationResult = convertFromClassificationResultProto( | ||||
|               ClassificationResult.deserializeBinary(binaryProto)); | ||||
|         }); | ||||
| 
 | ||||
|     const binaryGraph = graphConfig.serializeBinary(); | ||||
|     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); | ||||
|  |  | |||
|  | @ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> { | |||
|    */ | ||||
|   embed(text: string): TextEmbedderResult { | ||||
|     // Get text embeddings by running our MediaPipe graph.
 | ||||
|     this.addStringToStream( | ||||
|     this.graphRunner.addStringToStream( | ||||
|         text, INPUT_STREAM, /* timestamp= */ performance.now()); | ||||
|     this.finishProcessing(); | ||||
|     return this.embeddingResult; | ||||
|  | @ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> { | |||
| 
 | ||||
|     graphConfig.addNode(embedderNode); | ||||
| 
 | ||||
|     this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { | ||||
|     this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { | ||||
|       const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); | ||||
|       this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); | ||||
|     }); | ||||
|  |  | |||
|  | @ -257,8 +257,9 @@ export class GestureRecognizer extends | |||
|     this.worldLandmarks = []; | ||||
|     this.handednesses = []; | ||||
| 
 | ||||
|     this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); | ||||
|     this.addProtoToStream( | ||||
|     this.graphRunner.addGpuBufferAsImageToStream( | ||||
|         imageSource, IMAGE_STREAM, timestamp); | ||||
|     this.graphRunner.addProtoToStream( | ||||
|         FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', | ||||
|         NORM_RECT_STREAM, timestamp); | ||||
|     this.finishProcessing(); | ||||
|  | @ -365,18 +366,22 @@ export class GestureRecognizer extends | |||
| 
 | ||||
|     graphConfig.addNode(recognizerNode); | ||||
| 
 | ||||
|     this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { | ||||
|       this.addJsLandmarks(binaryProto); | ||||
|     }); | ||||
|     this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { | ||||
|       this.adddJsWorldLandmarks(binaryProto); | ||||
|     }); | ||||
|     this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { | ||||
|       this.gestures.push(...this.toJsCategories(binaryProto)); | ||||
|     }); | ||||
|     this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { | ||||
|       this.handednesses.push(...this.toJsCategories(binaryProto)); | ||||
|     }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         LANDMARKS_STREAM, binaryProto => { | ||||
|           this.addJsLandmarks(binaryProto); | ||||
|         }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         WORLD_LANDMARKS_STREAM, binaryProto => { | ||||
|           this.adddJsWorldLandmarks(binaryProto); | ||||
|         }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         HAND_GESTURES_STREAM, binaryProto => { | ||||
|           this.gestures.push(...this.toJsCategories(binaryProto)); | ||||
|         }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         HANDEDNESS_STREAM, binaryProto => { | ||||
|           this.handednesses.push(...this.toJsCategories(binaryProto)); | ||||
|         }); | ||||
| 
 | ||||
|     const binaryGraph = graphConfig.serializeBinary(); | ||||
|     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); | ||||
|  |  | |||
|  | @ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | |||
|     this.worldLandmarks = []; | ||||
|     this.handednesses = []; | ||||
| 
 | ||||
|     this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); | ||||
|     this.addProtoToStream( | ||||
|     this.graphRunner.addGpuBufferAsImageToStream( | ||||
|         imageSource, IMAGE_STREAM, timestamp); | ||||
|     this.graphRunner.addProtoToStream( | ||||
|         FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', | ||||
|         NORM_RECT_STREAM, timestamp); | ||||
|     this.finishProcessing(); | ||||
|  | @ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> { | |||
| 
 | ||||
|     graphConfig.addNode(landmarkerNode); | ||||
| 
 | ||||
|     this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { | ||||
|       this.addJsLandmarks(binaryProto); | ||||
|     }); | ||||
|     this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { | ||||
|       this.adddJsWorldLandmarks(binaryProto); | ||||
|     }); | ||||
|     this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { | ||||
|       this.handednesses.push(...this.toJsCategories(binaryProto)); | ||||
|     }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         LANDMARKS_STREAM, binaryProto => { | ||||
|           this.addJsLandmarks(binaryProto); | ||||
|         }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         WORLD_LANDMARKS_STREAM, binaryProto => { | ||||
|           this.adddJsWorldLandmarks(binaryProto); | ||||
|         }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         HANDEDNESS_STREAM, binaryProto => { | ||||
|           this.handednesses.push(...this.toJsCategories(binaryProto)); | ||||
|         }); | ||||
| 
 | ||||
|     const binaryGraph = graphConfig.serializeBinary(); | ||||
|     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); | ||||
|  |  | |||
|  | @ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | |||
|       ImageClassifierResult { | ||||
|     // Get classification result by running our MediaPipe graph.
 | ||||
|     this.classificationResult = {classifications: []}; | ||||
|     this.addGpuBufferAsImageToStream( | ||||
|     this.graphRunner.addGpuBufferAsImageToStream( | ||||
|         imageSource, INPUT_STREAM, timestamp ?? performance.now()); | ||||
|     this.finishProcessing(); | ||||
|     return this.classificationResult; | ||||
|  | @ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> { | |||
| 
 | ||||
|     graphConfig.addNode(classifierNode); | ||||
| 
 | ||||
|     this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { | ||||
|       this.classificationResult = convertFromClassificationResultProto( | ||||
|           ClassificationResult.deserializeBinary(binaryProto)); | ||||
|     }); | ||||
|     this.graphRunner.attachProtoListener( | ||||
|         CLASSIFICATIONS_STREAM, binaryProto => { | ||||
|           this.classificationResult = convertFromClassificationResultProto( | ||||
|               ClassificationResult.deserializeBinary(binaryProto)); | ||||
|         }); | ||||
| 
 | ||||
|     const binaryGraph = graphConfig.serializeBinary(); | ||||
|     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); | ||||
|  |  | |||
|  | @ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | |||
|   protected process(image: ImageSource, timestamp: number): | ||||
|       ImageEmbedderResult { | ||||
|     // Get embeddings by running our MediaPipe graph.
 | ||||
|     this.addGpuBufferAsImageToStream( | ||||
|     this.graphRunner.addGpuBufferAsImageToStream( | ||||
|         image, INPUT_STREAM, timestamp ?? performance.now()); | ||||
|     this.finishProcessing(); | ||||
|     return this.embeddings; | ||||
|  | @ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> { | |||
| 
 | ||||
|     graphConfig.addNode(embedderNode); | ||||
| 
 | ||||
|     this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { | ||||
|     this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { | ||||
|       this.addJsImageEmdedding(binaryProto); | ||||
|     }); | ||||
| 
 | ||||
|  |  | |||
|  | @ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | |||
|       Detection[] { | ||||
|     // Get detections by running our MediaPipe graph.
 | ||||
|     this.detections = []; | ||||
|     this.addGpuBufferAsImageToStream( | ||||
|     this.graphRunner.addGpuBufferAsImageToStream( | ||||
|         imageSource, INPUT_STREAM, timestamp ?? performance.now()); | ||||
|     this.finishProcessing(); | ||||
|     return [...this.detections]; | ||||
|  | @ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> { | |||
| 
 | ||||
|     graphConfig.addNode(detectorNode); | ||||
| 
 | ||||
|     this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { | ||||
|       this.addJsObjectDetections(binaryProto); | ||||
|     }); | ||||
|     this.graphRunner.attachProtoVectorListener( | ||||
|         DETECTIONS_STREAM, binaryProto => { | ||||
|           this.addJsObjectDetections(binaryProto); | ||||
|         }); | ||||
| 
 | ||||
|     const binaryGraph = graphConfig.serializeBinary(); | ||||
|     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); | ||||
|  |  | |||
|  | @ -22,7 +22,7 @@ export declare interface WasmImageModule { | |||
|  * An implementation of GraphRunner that supports binding GPU image data as | ||||
|  * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for | ||||
|  * effective multiple inheritance. Example usage: | ||||
|  * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` | ||||
|  * `const GraphRunnerImageLib = SupportImage(GraphRunner);` | ||||
|  */ | ||||
| // tslint:disable-next-line:enforce-name-casing
 | ||||
| export function SupportImage<TBase extends LibConstructor>(Base: TBase) { | ||||
|  |  | |||
|  | @ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources { | |||
|  * An implementation of GraphRunner that supports registering model | ||||
|  * resources to a cache, in the form of a GraphService C++-side. We implement as | ||||
|  * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: | ||||
|  * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
 | ||||
|  *     GraphRunner);` | ||||
|  * `const GraphRunnerWithModelResourcesLib =
 | ||||
|  *      SupportModelResourcesGraphService(GraphRunner);` | ||||
|  */ | ||||
| // tslint:disable:enforce-name-casing
 | ||||
| export function SupportModelResourcesGraphService<TBase extends LibConstructor>( | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user