Don't inherit from GraphRunner

PiperOrigin-RevId: 492584486
This commit is contained in:
Sebastian Schmidt 2022-12-02 16:16:34 -08:00 committed by Copybara-Service
parent da9587033d
commit e457039fc6
12 changed files with 92 additions and 70 deletions

View File

@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
protected override process(
audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioClassifierResult[] {
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
this.graphRunner.addDoubleToStream(
sampleRate, SAMPLE_RATE_STREAM, timestampMs);
this.graphRunner.addAudioToStreamWithShape(
audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
AUDIO_STREAM, timestampMs);
this.classificationResults = [];
this.finishProcessing();
@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
graphConfig.addNode(classifierNode);
this.attachProtoVectorListener(
this.graphRunner.attachProtoVectorListener(
TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
this.addJsAudioClassificationResults(binaryProtos);
});

View File

@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
protected override process(
audioData: Float32Array, sampleRate: number,
timestampMs: number): AudioEmbedderResult[] {
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
this.graphRunner.addDoubleToStream(
sampleRate, SAMPLE_RATE_STREAM, timestampMs);
this.graphRunner.addAudioToStreamWithShape(
audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
AUDIO_STREAM, timestampMs);
this.embeddingResults = [];
this.finishProcessing();
@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
graphConfig.addNode(embedderNode);
this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
this.embeddingResults.push(
convertFromEmbeddingResultProto(embeddingResult));
});
this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => {
for (const binaryProto of data) {
const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
this.embeddingResults.push(
convertFromEmbeddingResultProto(embeddingResult));
}
});
this.graphRunner.attachProtoVectorListener(
TIMESTAMPED_EMBEDDINGS_STREAM, data => {
for (const binaryProto of data) {
const embeddingResult =
EmbeddingResult.deserializeBinary(binaryProto);
this.embeddingResults.push(
convertFromEmbeddingResultProto(embeddingResult));
}
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset';
const NO_ASSETS = undefined;
// tslint:disable-next-line:enforce-name-casing
const WasmMediaPipeImageLib =
const GraphRunnerImageLibType =
SupportModelResourcesGraphService(SupportImage(GraphRunner));
/**
 * An implementation of the GraphRunner that supports image operations and
 * the model resources graph service. It adds no members of its own; it gives
 * the mixin composition `GraphRunnerImageLibType` a class name that can be
 * used both as a type and as a constructor.
 */
export class GraphRunnerImageLib extends GraphRunnerImageLibType {}
/** Base class for all MediaPipe Tasks. */
export abstract class TaskRunner<O extends TaskRunnerOptions> extends
WasmMediaPipeImageLib {
export abstract class TaskRunner<O extends TaskRunnerOptions> {
protected abstract baseOptions: BaseOptionsProto;
protected graphRunner: GraphRunnerImageLib;
private processingErrors: Error[] = [];
/**
@ -67,14 +69,14 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas);
// Disables the automatic render-to-screen code, which allows for pure
// CPU processing.
this.setAutoRenderToScreen(false);
this.graphRunner.setAutoRenderToScreen(false);
// Enables use of our model resource caching graph service.
this.registerModelResourcesGraphService();
this.graphRunner.registerModelResourcesGraphService();
}
/** Configures the shared options of a MediaPipe Task. */
@ -95,11 +97,11 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
* @param isBinary This should be set to true if the graph is in
* binary format, and false if it is in human-readable text format.
*/
override setGraph(graphData: Uint8Array, isBinary: boolean): void {
this.attachErrorListener((code, message) => {
protected setGraph(graphData: Uint8Array, isBinary: boolean): void {
this.graphRunner.attachErrorListener((code, message) => {
this.processingErrors.push(new Error(message));
});
super.setGraph(graphData, isBinary);
this.graphRunner.setGraph(graphData, isBinary);
this.handleErrors();
}
@ -108,8 +110,8 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
* far as possible, performing all processing until no more processing can be
* done.
*/
override finishProcessing(): void {
super.finishProcessing();
protected finishProcessing(): void {
this.graphRunner.finishProcessing();
this.handleErrors();
}

View File

@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> {
classify(text: string): TextClassifierResult {
// Get classification result by running our MediaPipe graph.
this.classificationResult = {classifications: []};
this.addStringToStream(
this.graphRunner.addStringToStream(
text, INPUT_STREAM, /* timestamp= */ performance.now());
this.finishProcessing();
return this.classificationResult;
@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> {
graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
});
this.graphRunner.attachProtoListener(
CLASSIFICATIONS_STREAM, binaryProto => {
this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
*/
embed(text: string): TextEmbedderResult {
// Get text embeddings by running our MediaPipe graph.
this.addStringToStream(
this.graphRunner.addStringToStream(
text, INPUT_STREAM, /* timestamp= */ performance.now());
this.finishProcessing();
return this.embeddingResult;
@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
graphConfig.addNode(embedderNode);
this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult);
});

View File

@ -257,8 +257,9 @@ export class GestureRecognizer extends
this.worldLandmarks = [];
this.handednesses = [];
this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
this.addProtoToStream(
this.graphRunner.addGpuBufferAsImageToStream(
imageSource, IMAGE_STREAM, timestamp);
this.graphRunner.addProtoToStream(
FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
NORM_RECT_STREAM, timestamp);
this.finishProcessing();
@ -365,18 +366,22 @@ export class GestureRecognizer extends
graphConfig.addNode(recognizerNode);
this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
this.addJsLandmarks(binaryProto);
});
this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
this.adddJsWorldLandmarks(binaryProto);
});
this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => {
this.gestures.push(...this.toJsCategories(binaryProto));
});
this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
this.handednesses.push(...this.toJsCategories(binaryProto));
});
this.graphRunner.attachProtoVectorListener(
LANDMARKS_STREAM, binaryProto => {
this.addJsLandmarks(binaryProto);
});
this.graphRunner.attachProtoVectorListener(
WORLD_LANDMARKS_STREAM, binaryProto => {
this.adddJsWorldLandmarks(binaryProto);
});
this.graphRunner.attachProtoVectorListener(
HAND_GESTURES_STREAM, binaryProto => {
this.gestures.push(...this.toJsCategories(binaryProto));
});
this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, binaryProto => {
this.handednesses.push(...this.toJsCategories(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
this.worldLandmarks = [];
this.handednesses = [];
this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
this.addProtoToStream(
this.graphRunner.addGpuBufferAsImageToStream(
imageSource, IMAGE_STREAM, timestamp);
this.graphRunner.addProtoToStream(
FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
NORM_RECT_STREAM, timestamp);
this.finishProcessing();
@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
graphConfig.addNode(landmarkerNode);
this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
this.addJsLandmarks(binaryProto);
});
this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
this.adddJsWorldLandmarks(binaryProto);
});
this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
this.handednesses.push(...this.toJsCategories(binaryProto));
});
this.graphRunner.attachProtoVectorListener(
LANDMARKS_STREAM, binaryProto => {
this.addJsLandmarks(binaryProto);
});
this.graphRunner.attachProtoVectorListener(
WORLD_LANDMARKS_STREAM, binaryProto => {
this.adddJsWorldLandmarks(binaryProto);
});
this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, binaryProto => {
this.handednesses.push(...this.toJsCategories(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
ImageClassifierResult {
// Get classification result by running our MediaPipe graph.
this.classificationResult = {classifications: []};
this.addGpuBufferAsImageToStream(
this.graphRunner.addGpuBufferAsImageToStream(
imageSource, INPUT_STREAM, timestamp ?? performance.now());
this.finishProcessing();
return this.classificationResult;
@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
});
this.graphRunner.attachProtoListener(
CLASSIFICATIONS_STREAM, binaryProto => {
this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
protected process(image: ImageSource, timestamp: number):
ImageEmbedderResult {
// Get embeddings by running our MediaPipe graph.
this.addGpuBufferAsImageToStream(
this.graphRunner.addGpuBufferAsImageToStream(
image, INPUT_STREAM, timestamp ?? performance.now());
this.finishProcessing();
return this.embeddings;
@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
graphConfig.addNode(embedderNode);
this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
this.addJsImageEmdedding(binaryProto);
});

View File

@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
Detection[] {
// Get detections by running our MediaPipe graph.
this.detections = [];
this.addGpuBufferAsImageToStream(
this.graphRunner.addGpuBufferAsImageToStream(
imageSource, INPUT_STREAM, timestamp ?? performance.now());
this.finishProcessing();
return [...this.detections];
@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
graphConfig.addNode(detectorNode);
this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
this.addJsObjectDetections(binaryProto);
});
this.graphRunner.attachProtoVectorListener(
DETECTIONS_STREAM, binaryProto => {
this.addJsObjectDetections(binaryProto);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -22,7 +22,7 @@ export declare interface WasmImageModule {
* An implementation of GraphRunner that supports binding GPU image data as
* `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for
* effective multiple inheritance. Example usage:
* `const WasmMediaPipeImageLib = SupportImage(GraphRunner);`
* `const GraphRunnerImageLib = SupportImage(GraphRunner);`
*/
// tslint:disable-next-line:enforce-name-casing
export function SupportImage<TBase extends LibConstructor>(Base: TBase) {

View File

@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources {
* An implementation of GraphRunner that supports registering model
* resources to a cache, in the form of a GraphService C++-side. We implement as
* a proper TS mixin, to allow for effective multiple inheritance. Sample usage:
* `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
* GraphRunner);`
* `const GraphRunnerWithModelResourcesLib =
* SupportModelResourcesGraphService(GraphRunner);`
*/
// tslint:disable:enforce-name-casing
export function SupportModelResourcesGraphService<TBase extends LibConstructor>(