Do not depend on Image methods in TaskRunner

PiperOrigin-RevId: 500299571
This commit is contained in:
Sebastian Schmidt 2023-01-06 18:15:34 -08:00 committed by Copybara-Service
parent 2cce88080e
commit 9b34a105cf
21 changed files with 61 additions and 52 deletions

View File

@ -27,6 +27,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)

View File

@ -22,6 +22,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -98,7 +99,7 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new CachedGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@ -27,6 +27,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/utils:cosine_similarity",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:embedder_options",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)

View File

@ -24,6 +24,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {CachedGraphRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -100,7 +101,7 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new CachedGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@ -22,7 +22,6 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
],
@ -57,7 +56,6 @@ mediapipe_ts_library(
deps = [
":core",
":task_runner",
":task_runner_test_utils",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts",
],

View File

@ -19,8 +19,7 @@ import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor} from '../../../web/graph_runner/graph_runner';
import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
import {WasmFileset} from './wasm_fileset';
@ -29,10 +28,12 @@ import {WasmFileset} from './wasm_fileset';
const NO_ASSETS = undefined;
// tslint:disable-next-line:enforce-name-casing
const GraphRunnerImageLibType =
SupportModelResourcesGraphService(SupportImage(GraphRunner));
/** An implementation of the GraphRunner that supports image operations */
export class GraphRunnerImageLib extends GraphRunnerImageLibType {}
const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
/**
* An implementation of the GraphRunner that exposes the resource graph
* service.
*/
export class CachedGraphRunner extends CachedGraphRunnerType {}
/**
* Creates a new instance of a Mediapipe Task. Determines if SIMD is
@ -64,7 +65,6 @@ export async function createTaskRunner<T extends TaskRunner>(
/** Base class for all MediaPipe Tasks. */
export abstract class TaskRunner {
protected abstract baseOptions: BaseOptionsProto;
protected graphRunner: GraphRunnerImageLib;
private processingErrors: Error[] = [];
/**
@ -79,12 +79,7 @@ export abstract class TaskRunner {
}
/** @hideconstructor protected */
constructor(
wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null,
graphRunner?: GraphRunnerImageLib) {
this.graphRunner =
graphRunner ?? new GraphRunnerImageLib(wasmModule, glCanvas);
constructor(protected readonly graphRunner: CachedGraphRunner) {
// Disables the automatic render-to-screen code, which allows for pure
// CPU processing.
this.graphRunner.setAutoRenderToScreen(false);

View File

@ -18,11 +18,10 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {TaskRunner} from '../../../tasks/web/core/task_runner';
import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils';
import {ErrorListener} from '../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource URL builder
import {GraphRunnerImageLib} from './task_runner';
import {CachedGraphRunner} from './task_runner';
import {TaskRunnerOptions} from './task_runner_options.d';
class TaskRunnerFake extends TaskRunner {
@ -32,18 +31,15 @@ class TaskRunnerFake extends TaskRunner {
baseOptions = new BaseOptionsProto();
static createFake(): TaskRunnerFake {
const wasmModule = createSpyWasmModule();
return new TaskRunnerFake(wasmModule);
return new TaskRunnerFake();
}
constructor(wasmModuleFake: SpyWasmModule) {
super(
wasmModuleFake, /* glCanvas= */ null,
jasmine.createSpyObj<GraphRunnerImageLib>([
constructor() {
super(jasmine.createSpyObj<CachedGraphRunner>([
'setAutoRenderToScreen', 'setGraph', 'finishProcessing',
'registerModelResourcesGraphService', 'attachErrorListener'
]));
const graphRunner = this.graphRunner as jasmine.SpyObj<GraphRunnerImageLib>;
const graphRunner = this.graphRunner as jasmine.SpyObj<CachedGraphRunner>;
expect(graphRunner.registerModelResourcesGraphService).toHaveBeenCalled();
expect(graphRunner.setAutoRenderToScreen).toHaveBeenCalled();
graphRunner.attachErrorListener.and.callFake(listener => {

View File

@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -96,7 +96,7 @@ export class TextClassifier extends TaskRunner {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new CachedGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@ -23,7 +23,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r
import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {CachedGraphRunner, TaskRunner} from '../../../../tasks/web/core/task_runner';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -100,7 +100,7 @@ export class TextEmbedder extends TaskRunner {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new CachedGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@ -20,7 +20,9 @@ mediapipe_ts_library(
":vision_task_options",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts",
"//mediapipe/web/graph_runner:register_model_resources_graph_service_ts",
],
)

View File

@ -21,13 +21,13 @@ import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_u
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {VisionTaskOptions} from './vision_task_options';
import {VisionTaskRunner} from './vision_task_runner';
import {VisionGraphRunner, VisionTaskRunner} from './vision_task_runner';
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
baseOptions = new BaseOptionsProto();
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
super(new VisionGraphRunner(createSpyWasmModule(), /* glCanvas= */ null));
}
protected override process(): void {}

View File

@ -15,12 +15,25 @@
*/
import {TaskRunner} from '../../../../tasks/web/core/task_runner';
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner';
import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service';
import {VisionTaskOptions} from './vision_task_options';
// tslint:disable-next-line:enforce-name-casing
const GraphRunnerVisionType =
SupportModelResourcesGraphService(SupportImage(GraphRunner));
/** An implementation of the GraphRunner that supports image operations */
export class VisionGraphRunner extends GraphRunnerVisionType {}
/** Base class for all MediaPipe Vision Tasks. */
export abstract class VisionTaskRunner<T> extends TaskRunner {
/** @hideconstructor protected */
constructor(protected override readonly graphRunner: VisionGraphRunner) {
super(graphRunner);
}
/** Configures the shared options of a vision task. */
override applyOptions(options: VisionTaskOptions): Promise<void> {
if ('runningMode' in options) {

View File

@ -67,8 +67,8 @@ mediapipe_ts_library(
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
],
)

View File

@ -30,7 +30,7 @@ import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -131,7 +131,7 @@ export class GestureRecognizer extends
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new VisionGraphRunner(wasmModule, glCanvas));
this.options = new GestureRecognizerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());

View File

@ -18,8 +18,8 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer';
@ -98,7 +98,7 @@ class GestureRecognizerFake extends GestureRecognizer implements
spyOn(this.graphRunner, 'addProtoToStream');
}
getGraphRunner(): GraphRunnerImageLib {
getGraphRunner(): VisionGraphRunner {
return this.graphRunner;
}
}

View File

@ -62,8 +62,8 @@ mediapipe_ts_library(
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
],
)

View File

@ -26,7 +26,7 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han
import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -119,7 +119,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new VisionGraphRunner(wasmModule, glCanvas));
this.options = new HandLandmarkerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());

View File

@ -18,12 +18,13 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {HandLandmarker} from './hand_landmarker';
import {HandLandmarkerOptions} from './hand_landmarker_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
@ -87,7 +88,7 @@ class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
spyOn(this.graphRunner, 'addProtoToStream');
}
getGraphRunner(): GraphRunnerImageLib {
getGraphRunner(): VisionGraphRunner {
return this.graphRunner;
}
}

View File

@ -22,7 +22,7 @@ import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_cla
import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options';
import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -97,7 +97,7 @@ export class ImageClassifier extends VisionTaskRunner<ImageClassifierResult> {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new VisionGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@ -24,7 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr
import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -99,7 +99,7 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new VisionGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}

View File

@ -20,7 +20,7 @@ import {Detection as DetectionProto} from '../../../../framework/formats/detecti
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -96,7 +96,7 @@ export class ObjectDetector extends VisionTaskRunner<Detection[]> {
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(wasmModule, glCanvas);
super(new VisionGraphRunner(wasmModule, glCanvas));
this.options.setBaseOptions(new BaseOptionsProto());
}