Don't inherit from GraphRunner
PiperOrigin-RevId: 492584486
parent da9587033d
commit e457039fc6
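
The refactor swaps inheritance for composition: instead of extending the mixin-built runner class, TaskRunner now owns a graphRunner member and forwards every call to it. A condensed before/after sketch of the shape of the change (class and method names here are illustrative only, imports are omitted, and GraphRunnerImageLib/WasmModule are the types used in the task_runner.ts hunks below; see those hunks for the real declarations):

    // Before: every GraphRunner method was inherited onto the task itself.
    abstract class TaskRunnerBefore extends GraphRunnerImageLib {
      protected run(): void {
        this.finishProcessing();  // calls the inherited GraphRunner method
      }
    }

    // After: the task holds a GraphRunner instance and delegates explicitly.
    abstract class TaskRunnerAfter {
      protected graphRunner: GraphRunnerImageLib;

      constructor(wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|null) {
        this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas);
      }

      protected finishProcessing(): void {
        this.graphRunner.finishProcessing();  // explicit delegation
      }
    }
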
@@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioClassifierResult[] {
-    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
+    this.graphRunner.addDoubleToStream(
+        sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.graphRunner.addAudioToStreamWithShape(
+        audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
+        AUDIO_STREAM, timestampMs);
 
     this.classificationResults = [];
     this.finishProcessing();
@@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
 
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoVectorListener(
+    this.graphRunner.attachProtoVectorListener(
         TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
           this.addJsAudioClassificationResults(binaryProtos);
         });
@@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioEmbedderResult[] {
-    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);
+    this.graphRunner.addDoubleToStream(
+        sampleRate, SAMPLE_RATE_STREAM, timestampMs);
+    this.graphRunner.addAudioToStreamWithShape(
+        audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length,
+        AUDIO_STREAM, timestampMs);
 
     this.embeddingResults = [];
     this.finishProcessing();
@@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
 
     graphConfig.addNode(embedderNode);
 
-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
       this.embeddingResults.push(
          convertFromEmbeddingResultProto(embeddingResult));
     });
 
-    this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => {
-      for (const binaryProto of data) {
-        const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
-        this.embeddingResults.push(
-            convertFromEmbeddingResultProto(embeddingResult));
-      }
-    });
+    this.graphRunner.attachProtoVectorListener(
+        TIMESTAMPED_EMBEDDINGS_STREAM, data => {
+          for (const binaryProto of data) {
+            const embeddingResult =
+                EmbeddingResult.deserializeBinary(binaryProto);
+            this.embeddingResults.push(
+                convertFromEmbeddingResultProto(embeddingResult));
+          }
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset';
 const NO_ASSETS = undefined;
 
 // tslint:disable-next-line:enforce-name-casing
-const WasmMediaPipeImageLib =
+const GraphRunnerImageLibType =
     SupportModelResourcesGraphService(SupportImage(GraphRunner));
+/** An implementation of the GraphRunner that supports image operations */
+export class GraphRunnerImageLib extends GraphRunnerImageLibType {}
 
 /** Base class for all MediaPipe Tasks. */
-export abstract class TaskRunner<O extends TaskRunnerOptions> extends
-    WasmMediaPipeImageLib {
+export abstract class TaskRunner<O extends TaskRunnerOptions> {
   protected abstract baseOptions: BaseOptionsProto;
+  protected graphRunner: GraphRunnerImageLib;
   private processingErrors: Error[] = [];
 
   /**
@@ -67,14 +69,14 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
   constructor(
       wasmModule: WasmModule,
      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
-    super(wasmModule, glCanvas);
+    this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas);
 
     // Disables the automatic render-to-screen code, which allows for pure
     // CPU processing.
-    this.setAutoRenderToScreen(false);
+    this.graphRunner.setAutoRenderToScreen(false);
 
     // Enables use of our model resource caching graph service.
-    this.registerModelResourcesGraphService();
+    this.graphRunner.registerModelResourcesGraphService();
   }
 
   /** Configures the shared options of a MediaPipe Task. */
@@ -95,11 +97,11 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
    * @param isBinary This should be set to true if the graph is in
    *     binary format, and false if it is in human-readable text format.
    */
-  override setGraph(graphData: Uint8Array, isBinary: boolean): void {
-    this.attachErrorListener((code, message) => {
+  protected setGraph(graphData: Uint8Array, isBinary: boolean): void {
+    this.graphRunner.attachErrorListener((code, message) => {
      this.processingErrors.push(new Error(message));
    });
-    super.setGraph(graphData, isBinary);
+    this.graphRunner.setGraph(graphData, isBinary);
     this.handleErrors();
   }
 
@@ -108,8 +110,8 @@ export abstract class TaskRunner<O extends TaskRunnerOptions> extends
    * far as possible, performing all processing until no more processing can be
    * done.
    */
-  override finishProcessing(): void {
-    super.finishProcessing();
+  protected finishProcessing(): void {
+    this.graphRunner.finishProcessing();
     this.handleErrors();
   }
 
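
Both wrappers above funnel graph errors through processingErrors and then call handleErrors(), whose body is not part of this diff. A sketch of what that error funnel plausibly does (the implementation below is an assumption, not code from this commit; only attachErrorListener, processingErrors, and the handleErrors() call appear in the hunks above):

    // Assumed shape of the error funnel in TaskRunner.
    private handleErrors(): void {
      const errors = this.processingErrors;
      this.processingErrors = [];  // reset so the next call starts clean
      if (errors.length === 1) {
        throw errors[0];
      } else if (errors.length > 1) {
        throw new Error(
            'Encountered multiple errors: ' +
            errors.map(e => e.message).join(', '));
      }
    }
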
@@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> {
   classify(text: string): TextClassifierResult {
     // Get classification result by running our MediaPipe graph.
     this.classificationResult = {classifications: []};
-    this.addStringToStream(
+    this.graphRunner.addStringToStream(
         text, INPUT_STREAM, /* timestamp= */ performance.now());
     this.finishProcessing();
     return this.classificationResult;
@@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner<TextClassifierOptions> {
 
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
-      this.classificationResult = convertFromClassificationResultProto(
-          ClassificationResult.deserializeBinary(binaryProto));
-    });
+    this.graphRunner.attachProtoListener(
+        CLASSIFICATIONS_STREAM, binaryProto => {
+          this.classificationResult = convertFromClassificationResultProto(
+              ClassificationResult.deserializeBinary(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
    */
   embed(text: string): TextEmbedderResult {
     // Get text embeddings by running our MediaPipe graph.
-    this.addStringToStream(
+    this.graphRunner.addStringToStream(
         text, INPUT_STREAM, /* timestamp= */ performance.now());
     this.finishProcessing();
     return this.embeddingResult;
@@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner<TextEmbedderOptions> {
 
     graphConfig.addNode(embedderNode);
 
-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto);
       this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult);
     });
@@ -257,8 +257,9 @@ export class GestureRecognizer extends
     this.worldLandmarks = [];
     this.handednesses = [];
 
-    this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
-    this.addProtoToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
+        imageSource, IMAGE_STREAM, timestamp);
+    this.graphRunner.addProtoToStream(
         FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();
@@ -365,18 +366,22 @@ export class GestureRecognizer extends
 
     graphConfig.addNode(recognizerNode);
 
-    this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
-      this.addJsLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
-      this.adddJsWorldLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => {
-      this.gestures.push(...this.toJsCategories(binaryProto));
-    });
-    this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
-      this.handednesses.push(...this.toJsCategories(binaryProto));
-    });
+    this.graphRunner.attachProtoVectorListener(
+        LANDMARKS_STREAM, binaryProto => {
+          this.addJsLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        WORLD_LANDMARKS_STREAM, binaryProto => {
+          this.adddJsWorldLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        HAND_GESTURES_STREAM, binaryProto => {
+          this.gestures.push(...this.toJsCategories(binaryProto));
+        });
+    this.graphRunner.attachProtoVectorListener(
+        HANDEDNESS_STREAM, binaryProto => {
+          this.handednesses.push(...this.toJsCategories(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
     this.worldLandmarks = [];
     this.handednesses = [];
 
-    this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp);
-    this.addProtoToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
+        imageSource, IMAGE_STREAM, timestamp);
+    this.graphRunner.addProtoToStream(
         FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect',
         NORM_RECT_STREAM, timestamp);
     this.finishProcessing();
@@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
 
     graphConfig.addNode(landmarkerNode);
 
-    this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => {
-      this.addJsLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => {
-      this.adddJsWorldLandmarks(binaryProto);
-    });
-    this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => {
-      this.handednesses.push(...this.toJsCategories(binaryProto));
-    });
+    this.graphRunner.attachProtoVectorListener(
+        LANDMARKS_STREAM, binaryProto => {
+          this.addJsLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        WORLD_LANDMARKS_STREAM, binaryProto => {
+          this.adddJsWorldLandmarks(binaryProto);
+        });
+    this.graphRunner.attachProtoVectorListener(
+        HANDEDNESS_STREAM, binaryProto => {
+          this.handednesses.push(...this.toJsCategories(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
       ImageClassifierResult {
     // Get classification result by running our MediaPipe graph.
     this.classificationResult = {classifications: []};
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
         imageSource, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return this.classificationResult;
@@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
 
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
-      this.classificationResult = convertFromClassificationResultProto(
-          ClassificationResult.deserializeBinary(binaryProto));
-    });
+    this.graphRunner.attachProtoListener(
+        CLASSIFICATIONS_STREAM, binaryProto => {
+          this.classificationResult = convertFromClassificationResultProto(
+              ClassificationResult.deserializeBinary(binaryProto));
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
   protected process(image: ImageSource, timestamp: number):
       ImageEmbedderResult {
     // Get embeddings by running our MediaPipe graph.
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
         image, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return this.embeddings;
@@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
 
     graphConfig.addNode(embedderNode);
 
-    this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
+    this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => {
       this.addJsImageEmdedding(binaryProto);
     });
 
@@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
       Detection[] {
     // Get detections by running our MediaPipe graph.
     this.detections = [];
-    this.addGpuBufferAsImageToStream(
+    this.graphRunner.addGpuBufferAsImageToStream(
         imageSource, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
     return [...this.detections];
@@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
 
     graphConfig.addNode(detectorNode);
 
-    this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => {
-      this.addJsObjectDetections(binaryProto);
-    });
+    this.graphRunner.attachProtoVectorListener(
+        DETECTIONS_STREAM, binaryProto => {
+          this.addJsObjectDetections(binaryProto);
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -22,7 +22,7 @@ export declare interface WasmImageModule {
  * An implementation of GraphRunner that supports binding GPU image data as
  * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for
  * effective multiple inheritance. Example usage:
- * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);`
+ * `const GraphRunnerImageLib = SupportImage(GraphRunner);`
  */
 // tslint:disable-next-line:enforce-name-casing
 export function SupportImage<TBase extends LibConstructor>(Base: TBase) {
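
The comment above describes the standard TypeScript mixin pattern: a function that accepts a constructor and returns an anonymous class extending it. A minimal sketch of that pattern (SupportImageSketch and addSomeImageMethod are placeholder names, not the library's API, and the LibConstructor alias here is an assumption about its shape):

    // Any constructor that produces a GraphRunner can be mixed into.
    type LibConstructor = new (...args: any[]) => GraphRunner;

    function SupportImageSketch<TBase extends LibConstructor>(Base: TBase) {
      return class extends Base {
        // Placeholder method standing in for the real image bindings.
        addSomeImageMethod(imageName: string): void {
          console.log(`would bind ${imageName} as mediapipe::Image`);
        }
      };
    }

    // Usage mirrors the doc comment: the mixin output is itself a class.
    const GraphRunnerImageLib = SupportImageSketch(GraphRunner);
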
@@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources {
  * An implementation of GraphRunner that supports registering model
  * resources to a cache, in the form of a GraphService C++-side. We implement as
  * a proper TS mixin, to allow for effective multiple inheritance. Sample usage:
- * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService(
- *     GraphRunner);`
+ * `const GraphRunnerWithModelResourcesLib =
+ *     SupportModelResourcesGraphService(GraphRunner);`
  */
 // tslint:disable:enforce-name-casing
 export function SupportModelResourcesGraphService<TBase extends LibConstructor>(
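
Because each mixin returns a class, the two services compose by nesting, which is exactly how task_runner.ts builds its runner type in the hunk earlier in this commit. A brief usage sketch (the wasmModule value and the null canvas are assumptions about the call site; the method calls themselves appear in the task_runner.ts constructor above):

    // Compose both mixins into one runner type, as task_runner.ts does.
    const GraphRunnerImageLibType =
        SupportModelResourcesGraphService(SupportImage(GraphRunner));

    // Instantiate with an already-loaded wasm module; no GL canvas is needed
    // for pure CPU processing.
    const runner = new GraphRunnerImageLibType(wasmModule, /* glCanvas= */ null);
    runner.setAutoRenderToScreen(false);
    runner.registerModelResourcesGraphService();
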