Don't inherit from GraphRunner

PiperOrigin-RevId: 492584486
Sebastian Schmidt 2022-12-02 16:16:34 -08:00 committed by Copybara-Service
parent da9587033d
commit e457039fc6
12 changed files with 92 additions and 70 deletions
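
This change swaps inheritance for composition: the task classes no longer extend GraphRunner (through the image and model-resource mixins) but instead hold a protected graphRunner member and forward calls to it, so GraphRunner's methods stop leaking into every task's public surface. A minimal sketch of the resulting shape, with a stand-in interface instead of the real GraphRunner API (GraphRunnerLike and TaskRunnerSketch are illustrative names, not part of the commit):

// Stand-in for the handful of GraphRunner methods used below; illustrative only.
interface GraphRunnerLike {
  setAutoRenderToScreen(enabled: boolean): void;
  setGraph(graphData: Uint8Array, isBinary: boolean): void;
  finishProcessing(): void;
}

// Before: `class TaskRunner extends GraphRunnerImageLib { ... }`.
// After: the runner is a protected member, and subclasses reach MediaPipe
// through `this.graphRunner.*` or the thin protected wrappers below.
abstract class TaskRunnerSketch {
  protected graphRunner: GraphRunnerLike;

  constructor(runner: GraphRunnerLike) {
    this.graphRunner = runner;
    // Disable automatic render-to-screen so pure CPU processing stays possible.
    this.graphRunner.setAutoRenderToScreen(false);
  }

  // The former `override` methods become protected wrappers that delegate.
  protected setGraph(graphData: Uint8Array, isBinary: boolean): void {
    this.graphRunner.setGraph(graphData, isBinary);
  }

  protected finishProcessing(): void {
    this.graphRunner.finishProcessing();
  }
}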

View File

@@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioClassifierResult[] {
-    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
+    this.graphRunner.addDoubleToStream(
+        sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.graphRunner.addAudioToStreamWithShape(
+        audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
+        AUDIO_STREAM, timestampMs);
 
     this.classificationResults = [];
     this.finishProcessing();
@@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoVectorListener(
+    this.graphRunner.attachProtoVectorListener(
         TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
          this.addJsAudioClassificationResults(binaryProtos);
        });

View File

@@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioEmbedderResult[] {
-    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
+    this.graphRunner.addDoubleToStream(
+        sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.graphRunner.addAudioToStreamWithShape(
+        audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
+        AUDIO_STREAM, timestampMs);
 
     this.embeddingResults = [];
     this.finishProcessing();
@@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
     graphConfig.addNode(embedderNode);
 
-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
       this.embeddingResults.push(
           convertFromEmbeddingResultProto(embeddingResult));
     });
 
-    this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => {
-      for (const binaryProto of data) {
-        const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
-        this.embeddingResults.push(
-            convertFromEmbeddingResultProto(embeddingResult));
-      }
-    });
+    this.graphRunner.attachProtoVectorListener(
+        TIMESTAMPED_EMBEDDINGS_STREAM, data => {
+          for (const binaryProto of data) {
+            const embeddingResult =
+                EmbeddingResult.deserializeBinary(binaryProto);
+            this.embeddingResults.push(
+                convertFromEmbeddingResultProto(embeddingResult));
+          }
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset';
 const NO_ASSETS = undefined;
 
 // tslint:disable-next-line:enforce-name-casing
-const WasmMediaPipeImageLib =
+const GraphRunnerImageLibType =
     SupportModelResourcesGraphService(SupportImage(GraphRunner));
+/** An implementation of the GraphRunner that supports image operations */
+export class GraphRunnerImageLib extends GraphRunnerImageLibType {}
 
 /** Base class for all MediaPipe Tasks. */
-export abstract class TaskRunner<O extends TaskRunnerOptions> extends
-    WasmMediaPipeImageLib {
+export abstract class TaskRunner<O extends TaskRunnerOptions> {
   protected abstract baseOptions: BaseOptionsProto;
+  protected graphRunner: GraphRunnerImageLib;
   private processingErrors: Error[] = [];
 
   /**
@@ -67,14 +69,14 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
   constructor(
       wasmModule: WasmModule,
       glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(wasmModule, glCanvas);
+    this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas);
 
     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
-    this.setAutoRenderToScreen(false);
+    this.graphRunner.setAutoRenderToScreen(false);
 
     // Enables use of our model resource caching graph service.
-    this.registerModelResourcesGraphService();
+    this.graphRunner.registerModelResourcesGraphService();
   }
 
   /** Configures the shared options of a MediaPipe Task. */
@@ -95,11 +97,11 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
    * @param isBinary This should be set to true if the graph is in
    *     binary format, and false if it is in human-readable text format.
    */
-  override setGraph(graphData: Uint8Array, isBinary: boolean): void {
-    this.attachErrorListener((code, message) => {
+  protected setGraph(graphData: Uint8Array, isBinary: boolean): void {
+    this.graphRunner.attachErrorListener((code, message) => {
       this.processingErrors.push(new Error(message));
     });
-    super.setGraph(graphData, isBinary);
+    this.graphRunner.setGraph(graphData, isBinary);
     this.handleErrors();
   }
 
@@ -108,8 +110,8 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
    * far as possible, performing all processing until no more processing can be
    * done.
    */
-  override finishProcessing(): void {
-    super.finishProcessing();
+  protected finishProcessing(): void {
+    this.graphRunner.finishProcessing();
     this.handleErrors();
   }
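
setGraph() and finishProcessing() keep the existing collect-then-report error handling: graph errors arrive asynchronously through attachErrorListener, are buffered in processingErrors, and handleErrors() is invoked after each operation. handleErrors() itself is not part of this diff, so the sketch below is one plausible shape of that pattern, not the actual implementation:

type ErrorListener = (code: number, message: string) => void;

class GraphErrorBuffer {
  private processingErrors: Error[] = [];

  /** Registered as the graph's error listener, e.g. via attachErrorListener. */
  readonly listener: ErrorListener = (_code, message) => {
    this.processingErrors.push(new Error(message));
  };

  /** Called after each graph operation to surface buffered errors. */
  handleErrors(): void {
    const errors = this.processingErrors;
    this.processingErrors = [];
    if (errors.length === 1) {
      throw errors[0];
    }
    if (errors.length > 1) {
      throw new Error(
          'Encountered multiple errors: ' +
          errors.map(e => e.message).join(', '));
    }
  }
}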

View File

@@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> {
   classify(text: string): TextClassifierResult {
     // Get classification result by running our MediaPipe graph.
     this.classificationResult = {classifications: []};
-    this.addStringToStream(
+    this.graphRunner.addStringToStream(
         text, INPUT_STREAM, /* timestamp= */ performance.now());
     this.finishProcessing();
     return this.classificationResult;
@@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> {
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
-      this.classificationResult = convertFromClassificationResultProto(
-          ClassificationResult.deserializeBinary(binaryProto));
-    });
+    this.graphRunner.attachProtoListener(
+        CLASSIFICATIONS_STREAM, binaryProto => {
+          this.classificationResult = convertFromClassificationResultProto(
+              ClassificationResult.deserializeBinary(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
    */
   embed(text: string): TextEmbedderResult {
     // Get text embeddings by running our MediaPipe graph.
-    this.addStringToStream(
+    this.graphRunner.addStringToStream(
         text, INPUT_STREAM, /* timestamp= */ performance.now());
     this.finishProcessing();
     return this.embeddingResult;
@@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
     graphConfig.addNode(embedderNode);
 
-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
       this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult);
     });

View File

@@ -257,8 +257,9 @@ export class GestureRecognizer extends
     this.worldLandmarks = [];
     this.handednesses = [];
 
-    this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
-    this.addProtoToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
+        imageSource, IMAGE_STREAM, timestamp);
+    this.graphRunner.addProtoToStream(
         FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();
@@ -365,18 +366,22 @@ export class GestureRecognizer extends
     graphConfig.addNode(recognizerNode);
 
-    this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
-      this.addJsLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
-      this.adddJsWorldLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => {
-      this.gestures.push(...this.toJsCategories(binaryProto));
-    });
-    this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
-      this.handednesses.push(...this.toJsCategories(binaryProto));
-    });
+    this.graphRunner.attachProtoVectorListener(
+        LANDMARKS_STREAM, binaryProto => {
+          this.addJsLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        WORLD_LANDMARKS_STREAM, binaryProto => {
+          this.adddJsWorldLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        HAND_GESTURES_STREAM, binaryProto => {
+          this.gestures.push(...this.toJsCategories(binaryProto));
+        });
+    this.graphRunner.attachProtoVectorListener(
+        HANDEDNESS_STREAM, binaryProto => {
+          this.handednesses.push(...this.toJsCategories(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
     this.worldLandmarks = [];
     this.handednesses = [];
 
-    this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
-    this.addProtoToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
+        imageSource, IMAGE_STREAM, timestamp);
+    this.graphRunner.addProtoToStream(
         FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();
@@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
     graphConfig.addNode(landmarkerNode);
 
-    this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
-      this.addJsLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
-      this.adddJsWorldLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
-      this.handednesses.push(...this.toJsCategories(binaryProto));
-    });
+    this.graphRunner.attachProtoVectorListener(
+        LANDMARKS_STREAM, binaryProto => {
+          this.addJsLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        WORLD_LANDMARKS_STREAM, binaryProto => {
+          this.adddJsWorldLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        HANDEDNESS_STREAM, binaryProto => {
+          this.handednesses.push(...this.toJsCategories(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
       ImageClassifierResult {
     // Get classification result by running our MediaPipe graph.
     this.classificationResult = {classifications: []};
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
         imageSource, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return this.classificationResult;
@@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
-      this.classificationResult = convertFromClassificationResultProto(
-          ClassificationResult.deserializeBinary(binaryProto));
-    });
+    this.graphRunner.attachProtoListener(
+        CLASSIFICATIONS_STREAM, binaryProto => {
+          this.classificationResult = convertFromClassificationResultProto(
+              ClassificationResult.deserializeBinary(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
   protected process(image: ImageSource, timestamp: number):
       ImageEmbedderResult {
     // Get embeddings by running our MediaPipe graph.
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
         image, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return this.embeddings;
@@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
     graphConfig.addNode(embedderNode);
 
-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       this.addJsImageEmdedding(binaryProto);
     });

View File

@@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
       Detection[] {
     // Get detections by running our MediaPipe graph.
     this.detections = [];
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return [...this.detections];
@@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
     graphConfig.addNode(detectorNode);
 
-    this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
-      this.addJsObjectDetections(binaryProto);
-    });
+    this.graphRunner.attachProtoVectorListener(
+        DETECTIONS_STREAM, binaryProto => {
+          this.addJsObjectDetections(binaryProto);
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@@ -22,7 +22,7 @@ export declare interface WasmImageModule {
  * An implementation of GraphRunner that supports binding GPU image data as
  * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for
  * effective multiple inheritance. Example usage:
- * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);`
+ * `const GraphRunnerImageLib = SupportImage(GraphRunner);`
  */
 // tslint:disable-next-line:enforce-name-casing
 export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
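
SupportImage follows the standard TypeScript mixin pattern described in the comment above: a function that accepts a constructor and returns a class extending it, so capabilities can be stacked onto GraphRunner. A generic sketch of the pattern (SupportLogging and its log method are invented for illustration; only the LibConstructor-constrained signature mirrors the real helpers):

type LibConstructor = new (...args: any[]) => {};

// tslint:disable-next-line:enforce-name-casing
function SupportLogging<TBase extends LibConstructor>(Base: TBase) {
  return class extends Base {
    // The returned class layers new members on top of whatever Base provides.
    log(message: string): void {
      console.log(message);
    }
  };
}

// Usage mirrors the doc comment: `const GraphRunnerImageLib = SupportImage(GraphRunner);`
const LoggingBase = SupportLogging(class {});
new LoggingBase().log('mixin applied');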

View File

@@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources {
  * An implementation of GraphRunner that supports registering model
  * resources to a cache, in the form of a GraphService C++-side. We implement as
  * a proper TS mixin, to allow for effective multiple inheritance. Sample usage:
- * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
- *     GraphRunner);`
+ * `const GraphRunnerWithModelResourcesLib =
+ *     SupportModelResourcesGraphService(GraphRunner);`
  */
 // tslint:disable:enforce-name-casing
 export function SupportModelResourcesGraphService<TBase extends LibConstructor>(
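
task_runner.ts stacks both mixins and wraps the result in the named GraphRunnerImageLib class that the tasks now instantiate instead of inheriting from. A self-contained sketch of that composition with stand-in mixins (the stub classes and their methods are illustrative; only the nesting and the name-then-subclass-then-instantiate shape mirror the real code):

type LibConstructor = new (...args: any[]) => {};

class GraphRunnerStub {
  constructor(readonly wasmModule: object) {}
  setAutoRenderToScreen(_enabled: boolean): void {}
}

function SupportImageStub<TBase extends LibConstructor>(Base: TBase) {
  return class extends Base {
    addGpuBufferAsImageToStream(): void {}
  };
}

function SupportModelResourcesStub<TBase extends LibConstructor>(Base: TBase) {
  return class extends Base {
    registerModelResourcesGraphService(): void {}
  };
}

// Name the composed constructor once, give it a proper class so it can be
// used as a field type, then instantiate it rather than extending it.
const GraphRunnerImageLibType =
    SupportModelResourcesStub(SupportImageStub(GraphRunnerStub));
class GraphRunnerImageLib extends GraphRunnerImageLibType {}

const runner = new GraphRunnerImageLib({});
runner.setAutoRenderToScreen(false);
runner.registerModelResourcesGraphService();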