diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 91a38cd44..efeffbb87 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -28,9 +28,16 @@ import {WasmFileset} from './wasm_fileset'; // None of the MP Tasks ship bundle assets. const NO_ASSETS = undefined; +// Internal stream names for temporarily keeping memory alive, then freeing it. +const FREE_MEMORY_STREAM = 'free_memory'; +const UNUSED_STREAM_SUFFIX = '_unused_out'; + // tslint:disable-next-line:enforce-name-casing const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner); +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + /** * An implementation of the GraphRunner that exposes the resource graph * service. @@ -64,6 +71,7 @@ export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; private processingErrors: Error[] = []; private latestOutputTimestamp = 0; + private keepaliveNode?: CalculatorGraphConfig.Node; /** * Creates a new instance of a Mediapipe Task. Determines if SIMD is @@ -177,6 +185,7 @@ export abstract class TaskRunner { this.graphRunner.registerModelResourcesGraphService(); this.graphRunner.setGraph(graphData, isBinary); + this.keepaliveNode = undefined; this.handleErrors(); } @@ -257,8 +266,36 @@ export abstract class TaskRunner { this.baseOptions.setAcceleration(acceleration); } + /** + * Adds a node to the graph to temporarily keep certain streams alive. + * NOTE: To use this call, PassThroughCalculator must be included in your wasm + * dependencies. + */ + protected addKeepaliveNode(graphConfig: CalculatorGraphConfig) { + this.keepaliveNode = new CalculatorGraphConfig.Node(); + this.keepaliveNode.setCalculator('PassThroughCalculator'); + this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM); + this.keepaliveNode.addOutputStream( + FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX); + graphConfig.addInputStream(FREE_MEMORY_STREAM); + graphConfig.addNode(this.keepaliveNode); + } + + /** Adds streams to the keepalive node to be kept alive until callback. */ + protected keepStreamAlive(streamName: string) { + this.keepaliveNode!.addInputStream(streamName); + this.keepaliveNode!.addOutputStream(streamName + UNUSED_STREAM_SUFFIX); + } + + /** Frees any streams being kept alive by the keepStreamAlive callback. */ + protected freeKeepaliveStreams() { + this.graphRunner.addBoolToStream( + true, FREE_MEMORY_STREAM, this.latestOutputTimestamp); + } + /** Closes and cleans up the resources held by this task. */ close(): void { + this.keepaliveNode = undefined; this.graphRunner.closeGraph(); } } diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index edf1d0d32..1532eb2a5 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -37,7 +37,7 @@ export function createSpyWasmModule(): SpyWasmModule { '_attachProtoVectorListener', '_free', '_waitUntilIdle', '_addStringToInputStream', '_registerModelResourcesGraphService', '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig', - '_closeGraph' + '_closeGraph', '_addBoolToInputStream' ]); spyWasmModule._getGraphConfig.and.callFake(() => { (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as @@ -81,7 +81,10 @@ export function verifyGraph( expectedBaseOptions?: FieldPathToValue, ): void { expect(tasksFake.graph).toBeDefined(); - expect(tasksFake.graph!.getNodeList().length).toBe(1); + // Our graphs should have at least one node in them for processing, and + // sometimes one additional one for keeping alive certain streams in memory. + expect(tasksFake.graph!.getNodeList().length).toBeGreaterThanOrEqual(1); + expect(tasksFake.graph!.getNodeList().length).toBeLessThanOrEqual(2); const node = tasksFake.graph!.getNodeList()[0].toObject(); expect(node).toEqual( jasmine.objectContaining({calculator: tasksFake.calculatorName})); diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index 39e57d94e..b12adb0df 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -383,6 +383,10 @@ export class ImageSegmenter extends VisionTaskRunner { if (this.userCallback) { this.userCallback(this.result); + + // Free the image memory, now that we've kept all streams alive long + // enough to be returned in our callbacks. + this.freeKeepaliveStreams(); } } @@ -403,11 +407,13 @@ export class ImageSegmenter extends VisionTaskRunner { segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); + this.addKeepaliveNode(graphConfig); if (this.outputConfidenceMasks) { graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); segmenterNode.addOutputStream( 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + this.keepStreamAlive(CONFIDENCE_MASKS_STREAM); this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { @@ -428,6 +434,7 @@ export class ImageSegmenter extends VisionTaskRunner { if (this.outputCategoryMask) { graphConfig.addOutputStream(CATEGORY_MASK_STREAM); segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + this.keepStreamAlive(CATEGORY_MASK_STREAM); this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index 2a51a5fcf..e3f79d26d 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -297,6 +297,10 @@ export class InteractiveSegmenter extends VisionTaskRunner { if (this.userCallback) { this.userCallback(this.result); + + // Free the image memory, now that we've kept all streams alive long + // enough to be returned in our callbacks. + this.freeKeepaliveStreams(); } } @@ -319,11 +323,13 @@ export class InteractiveSegmenter extends VisionTaskRunner { segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); + this.addKeepaliveNode(graphConfig); if (this.outputConfidenceMasks) { graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); segmenterNode.addOutputStream( 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + this.keepStreamAlive(CONFIDENCE_MASKS_STREAM); this.graphRunner.attachImageVectorListener( CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { @@ -344,6 +350,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { if (this.outputCategoryMask) { graphConfig.addOutputStream(CATEGORY_MASK_STREAM); segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + this.keepStreamAlive(CATEGORY_MASK_STREAM); this.graphRunner.attachImageListener( CATEGORY_MASK_STREAM, (mask, timestamp) => { diff --git a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts index 87fdacbc2..0d3181aa0 100644 --- a/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts +++ b/mediapipe/tasks/web/vision/pose_landmarker/pose_landmarker.ts @@ -376,6 +376,9 @@ export class PoseLandmarker extends VisionTaskRunner { if (this.userCallback) { this.userCallback(this.result as Required); + + // Free the image memory, now that we've finished our callback. + this.freeKeepaliveStreams(); } } @@ -437,6 +440,9 @@ export class PoseLandmarker extends VisionTaskRunner { landmarkerNode.setOptions(calculatorOptions); graphConfig.addNode(landmarkerNode); + // We only need to keep alive the image stream, since the protos are being + // deep-copied anyways via serialization+deserialization. + this.addKeepaliveNode(graphConfig); this.graphRunner.attachProtoVectorListener( NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { @@ -467,6 +473,8 @@ export class PoseLandmarker extends VisionTaskRunner { if (this.outputSegmentationMasks) { landmarkerNode.addOutputStream( 'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM); + this.keepStreamAlive(SEGMENTATION_MASK_STREAM); + this.graphRunner.attachImageVectorListener( SEGMENTATION_MASK_STREAM, (masks, timestamp) => { this.result.segmentationMasks = masks.map(