When returning multiple output streams together, keep them alive until callback.
PiperOrigin-RevId: 530771884
This commit is contained in:
		
							parent
							
								
									e391c76433
								
							
						
					
					
						commit
						f824424700
					
				|  | @ -28,9 +28,16 @@ import {WasmFileset} from './wasm_fileset'; | ||||||
| // None of the MP Tasks ship bundle assets.
 | // None of the MP Tasks ship bundle assets.
 | ||||||
| const NO_ASSETS = undefined; | const NO_ASSETS = undefined; | ||||||
| 
 | 
 | ||||||
|  | // Internal stream names for temporarily keeping memory alive, then freeing it.
 | ||||||
|  | const FREE_MEMORY_STREAM = 'free_memory'; | ||||||
|  | const UNUSED_STREAM_SUFFIX = '_unused_out'; | ||||||
|  | 
 | ||||||
| // tslint:disable-next-line:enforce-name-casing
 | // tslint:disable-next-line:enforce-name-casing
 | ||||||
| const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner); | const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner); | ||||||
| 
 | 
 | ||||||
|  | // The OSS JS API does not support the builder pattern.
 | ||||||
|  | // tslint:disable:jspb-use-builder-pattern
 | ||||||
|  | 
 | ||||||
| /** | /** | ||||||
|  * An implementation of the GraphRunner that exposes the resource graph |  * An implementation of the GraphRunner that exposes the resource graph | ||||||
|  * service. |  * service. | ||||||
|  | @ -64,6 +71,7 @@ export abstract class TaskRunner { | ||||||
|   protected abstract baseOptions: BaseOptionsProto; |   protected abstract baseOptions: BaseOptionsProto; | ||||||
|   private processingErrors: Error[] = []; |   private processingErrors: Error[] = []; | ||||||
|   private latestOutputTimestamp = 0; |   private latestOutputTimestamp = 0; | ||||||
|  |   private keepaliveNode?: CalculatorGraphConfig.Node; | ||||||
| 
 | 
 | ||||||
|   /** |   /** | ||||||
|    * Creates a new instance of a Mediapipe Task. Determines if SIMD is |    * Creates a new instance of a Mediapipe Task. Determines if SIMD is | ||||||
|  | @ -177,6 +185,7 @@ export abstract class TaskRunner { | ||||||
|     this.graphRunner.registerModelResourcesGraphService(); |     this.graphRunner.registerModelResourcesGraphService(); | ||||||
| 
 | 
 | ||||||
|     this.graphRunner.setGraph(graphData, isBinary); |     this.graphRunner.setGraph(graphData, isBinary); | ||||||
|  |     this.keepaliveNode = undefined; | ||||||
|     this.handleErrors(); |     this.handleErrors(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -257,8 +266,36 @@ export abstract class TaskRunner { | ||||||
|     this.baseOptions.setAcceleration(acceleration); |     this.baseOptions.setAcceleration(acceleration); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /** | ||||||
|  |    * Adds a node to the graph to temporarily keep certain streams alive. | ||||||
|  |    * NOTE: To use this call, PassThroughCalculator must be included in your wasm | ||||||
|  |    *     dependencies. | ||||||
|  |    */ | ||||||
|  |   protected addKeepaliveNode(graphConfig: CalculatorGraphConfig) { | ||||||
|  |     this.keepaliveNode = new CalculatorGraphConfig.Node(); | ||||||
|  |     this.keepaliveNode.setCalculator('PassThroughCalculator'); | ||||||
|  |     this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM); | ||||||
|  |     this.keepaliveNode.addOutputStream( | ||||||
|  |         FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX); | ||||||
|  |     graphConfig.addInputStream(FREE_MEMORY_STREAM); | ||||||
|  |     graphConfig.addNode(this.keepaliveNode); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Adds streams to the keepalive node to be kept alive until callback. */ | ||||||
|  |   protected keepStreamAlive(streamName: string) { | ||||||
|  |     this.keepaliveNode!.addInputStream(streamName); | ||||||
|  |     this.keepaliveNode!.addOutputStream(streamName + UNUSED_STREAM_SUFFIX); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Frees any streams being kept alive by the keepStreamAlive callback. */ | ||||||
|  |   protected freeKeepaliveStreams() { | ||||||
|  |     this.graphRunner.addBoolToStream( | ||||||
|  |         true, FREE_MEMORY_STREAM, this.latestOutputTimestamp); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** Closes and cleans up the resources held by this task. */ |   /** Closes and cleans up the resources held by this task. */ | ||||||
|   close(): void { |   close(): void { | ||||||
|  |     this.keepaliveNode = undefined; | ||||||
|     this.graphRunner.closeGraph(); |     this.graphRunner.closeGraph(); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -37,7 +37,7 @@ export function createSpyWasmModule(): SpyWasmModule { | ||||||
|     '_attachProtoVectorListener', '_free', '_waitUntilIdle', |     '_attachProtoVectorListener', '_free', '_waitUntilIdle', | ||||||
|     '_addStringToInputStream', '_registerModelResourcesGraphService', |     '_addStringToInputStream', '_registerModelResourcesGraphService', | ||||||
|     '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig', |     '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig', | ||||||
|     '_closeGraph' |     '_closeGraph', '_addBoolToInputStream' | ||||||
|   ]); |   ]); | ||||||
|   spyWasmModule._getGraphConfig.and.callFake(() => { |   spyWasmModule._getGraphConfig.and.callFake(() => { | ||||||
|     (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as |     (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as | ||||||
|  | @ -81,7 +81,10 @@ export function verifyGraph( | ||||||
|     expectedBaseOptions?: FieldPathToValue, |     expectedBaseOptions?: FieldPathToValue, | ||||||
|     ): void { |     ): void { | ||||||
|   expect(tasksFake.graph).toBeDefined(); |   expect(tasksFake.graph).toBeDefined(); | ||||||
|   expect(tasksFake.graph!.getNodeList().length).toBe(1); |   // Our graphs should have at least one node in them for processing, and
 | ||||||
|  |   // sometimes one additional one for keeping alive certain streams in memory.
 | ||||||
|  |   expect(tasksFake.graph!.getNodeList().length).toBeGreaterThanOrEqual(1); | ||||||
|  |   expect(tasksFake.graph!.getNodeList().length).toBeLessThanOrEqual(2); | ||||||
|   const node = tasksFake.graph!.getNodeList()[0].toObject(); |   const node = tasksFake.graph!.getNodeList()[0].toObject(); | ||||||
|   expect(node).toEqual( |   expect(node).toEqual( | ||||||
|       jasmine.objectContaining({calculator: tasksFake.calculatorName})); |       jasmine.objectContaining({calculator: tasksFake.calculatorName})); | ||||||
|  |  | ||||||
|  | @ -383,6 +383,10 @@ export class ImageSegmenter extends VisionTaskRunner { | ||||||
| 
 | 
 | ||||||
|     if (this.userCallback) { |     if (this.userCallback) { | ||||||
|       this.userCallback(this.result); |       this.userCallback(this.result); | ||||||
|  | 
 | ||||||
|  |       // Free the image memory, now that we've kept all streams alive long
 | ||||||
|  |       // enough to be returned in our callbacks.
 | ||||||
|  |       this.freeKeepaliveStreams(); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -403,11 +407,13 @@ export class ImageSegmenter extends VisionTaskRunner { | ||||||
|     segmenterNode.setOptions(calculatorOptions); |     segmenterNode.setOptions(calculatorOptions); | ||||||
| 
 | 
 | ||||||
|     graphConfig.addNode(segmenterNode); |     graphConfig.addNode(segmenterNode); | ||||||
|  |     this.addKeepaliveNode(graphConfig); | ||||||
| 
 | 
 | ||||||
|     if (this.outputConfidenceMasks) { |     if (this.outputConfidenceMasks) { | ||||||
|       graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); |       graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); | ||||||
|       segmenterNode.addOutputStream( |       segmenterNode.addOutputStream( | ||||||
|           'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); |           'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); | ||||||
|  |       this.keepStreamAlive(CONFIDENCE_MASKS_STREAM); | ||||||
| 
 | 
 | ||||||
|       this.graphRunner.attachImageVectorListener( |       this.graphRunner.attachImageVectorListener( | ||||||
|           CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { |           CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { | ||||||
|  | @ -428,6 +434,7 @@ export class ImageSegmenter extends VisionTaskRunner { | ||||||
|     if (this.outputCategoryMask) { |     if (this.outputCategoryMask) { | ||||||
|       graphConfig.addOutputStream(CATEGORY_MASK_STREAM); |       graphConfig.addOutputStream(CATEGORY_MASK_STREAM); | ||||||
|       segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); |       segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); | ||||||
|  |       this.keepStreamAlive(CATEGORY_MASK_STREAM); | ||||||
| 
 | 
 | ||||||
|       this.graphRunner.attachImageListener( |       this.graphRunner.attachImageListener( | ||||||
|           CATEGORY_MASK_STREAM, (mask, timestamp) => { |           CATEGORY_MASK_STREAM, (mask, timestamp) => { | ||||||
|  |  | ||||||
|  | @ -297,6 +297,10 @@ export class InteractiveSegmenter extends VisionTaskRunner { | ||||||
| 
 | 
 | ||||||
|     if (this.userCallback) { |     if (this.userCallback) { | ||||||
|       this.userCallback(this.result); |       this.userCallback(this.result); | ||||||
|  | 
 | ||||||
|  |       // Free the image memory, now that we've kept all streams alive long
 | ||||||
|  |       // enough to be returned in our callbacks.
 | ||||||
|  |       this.freeKeepaliveStreams(); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -319,11 +323,13 @@ export class InteractiveSegmenter extends VisionTaskRunner { | ||||||
|     segmenterNode.setOptions(calculatorOptions); |     segmenterNode.setOptions(calculatorOptions); | ||||||
| 
 | 
 | ||||||
|     graphConfig.addNode(segmenterNode); |     graphConfig.addNode(segmenterNode); | ||||||
|  |     this.addKeepaliveNode(graphConfig); | ||||||
| 
 | 
 | ||||||
|     if (this.outputConfidenceMasks) { |     if (this.outputConfidenceMasks) { | ||||||
|       graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); |       graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); | ||||||
|       segmenterNode.addOutputStream( |       segmenterNode.addOutputStream( | ||||||
|           'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); |           'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); | ||||||
|  |       this.keepStreamAlive(CONFIDENCE_MASKS_STREAM); | ||||||
| 
 | 
 | ||||||
|       this.graphRunner.attachImageVectorListener( |       this.graphRunner.attachImageVectorListener( | ||||||
|           CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { |           CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { | ||||||
|  | @ -344,6 +350,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { | ||||||
|     if (this.outputCategoryMask) { |     if (this.outputCategoryMask) { | ||||||
|       graphConfig.addOutputStream(CATEGORY_MASK_STREAM); |       graphConfig.addOutputStream(CATEGORY_MASK_STREAM); | ||||||
|       segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); |       segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); | ||||||
|  |       this.keepStreamAlive(CATEGORY_MASK_STREAM); | ||||||
| 
 | 
 | ||||||
|       this.graphRunner.attachImageListener( |       this.graphRunner.attachImageListener( | ||||||
|           CATEGORY_MASK_STREAM, (mask, timestamp) => { |           CATEGORY_MASK_STREAM, (mask, timestamp) => { | ||||||
|  |  | ||||||
|  | @ -376,6 +376,9 @@ export class PoseLandmarker extends VisionTaskRunner { | ||||||
| 
 | 
 | ||||||
|     if (this.userCallback) { |     if (this.userCallback) { | ||||||
|       this.userCallback(this.result as Required<PoseLandmarkerResult>); |       this.userCallback(this.result as Required<PoseLandmarkerResult>); | ||||||
|  | 
 | ||||||
|  |       // Free the image memory, now that we've finished our callback.
 | ||||||
|  |       this.freeKeepaliveStreams(); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -437,6 +440,9 @@ export class PoseLandmarker extends VisionTaskRunner { | ||||||
|     landmarkerNode.setOptions(calculatorOptions); |     landmarkerNode.setOptions(calculatorOptions); | ||||||
| 
 | 
 | ||||||
|     graphConfig.addNode(landmarkerNode); |     graphConfig.addNode(landmarkerNode); | ||||||
|  |     // We only need to keep alive the image stream, since the protos are being
 | ||||||
|  |     // deep-copied anyways via serialization+deserialization.
 | ||||||
|  |     this.addKeepaliveNode(graphConfig); | ||||||
| 
 | 
 | ||||||
|     this.graphRunner.attachProtoVectorListener( |     this.graphRunner.attachProtoVectorListener( | ||||||
|         NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { |         NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { | ||||||
|  | @ -467,6 +473,8 @@ export class PoseLandmarker extends VisionTaskRunner { | ||||||
|     if (this.outputSegmentationMasks) { |     if (this.outputSegmentationMasks) { | ||||||
|       landmarkerNode.addOutputStream( |       landmarkerNode.addOutputStream( | ||||||
|           'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM); |           'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM); | ||||||
|  |       this.keepStreamAlive(SEGMENTATION_MASK_STREAM); | ||||||
|  | 
 | ||||||
|       this.graphRunner.attachImageVectorListener( |       this.graphRunner.attachImageVectorListener( | ||||||
|           SEGMENTATION_MASK_STREAM, (masks, timestamp) => { |           SEGMENTATION_MASK_STREAM, (masks, timestamp) => { | ||||||
|             this.result.segmentationMasks = masks.map( |             this.result.segmentationMasks = masks.map( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user