When returning multiple output streams together, keep them alive until callback.
PiperOrigin-RevId: 530771884
This commit is contained in:
		
							parent
							
								
									e391c76433
								
							
						
					
					
						commit
						f824424700
					
				| 
						 | 
					@ -28,9 +28,16 @@ import {WasmFileset} from './wasm_fileset';
 | 
				
			||||||
// None of the MP Tasks ship bundle assets.
 | 
					// None of the MP Tasks ship bundle assets.
 | 
				
			||||||
const NO_ASSETS = undefined;
 | 
					const NO_ASSETS = undefined;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Internal stream names for temporarily keeping memory alive, then freeing it.
 | 
				
			||||||
 | 
					const FREE_MEMORY_STREAM = 'free_memory';
 | 
				
			||||||
 | 
					const UNUSED_STREAM_SUFFIX = '_unused_out';
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// tslint:disable-next-line:enforce-name-casing
 | 
					// tslint:disable-next-line:enforce-name-casing
 | 
				
			||||||
const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
 | 
					const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// The OSS JS API does not support the builder pattern.
 | 
				
			||||||
 | 
					// tslint:disable:jspb-use-builder-pattern
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * An implementation of the GraphRunner that exposes the resource graph
 | 
					 * An implementation of the GraphRunner that exposes the resource graph
 | 
				
			||||||
 * service.
 | 
					 * service.
 | 
				
			||||||
| 
						 | 
					@ -64,6 +71,7 @@ export abstract class TaskRunner {
 | 
				
			||||||
  protected abstract baseOptions: BaseOptionsProto;
 | 
					  protected abstract baseOptions: BaseOptionsProto;
 | 
				
			||||||
  private processingErrors: Error[] = [];
 | 
					  private processingErrors: Error[] = [];
 | 
				
			||||||
  private latestOutputTimestamp = 0;
 | 
					  private latestOutputTimestamp = 0;
 | 
				
			||||||
 | 
					  private keepaliveNode?: CalculatorGraphConfig.Node;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Creates a new instance of a Mediapipe Task. Determines if SIMD is
 | 
					   * Creates a new instance of a Mediapipe Task. Determines if SIMD is
 | 
				
			||||||
| 
						 | 
					@ -177,6 +185,7 @@ export abstract class TaskRunner {
 | 
				
			||||||
    this.graphRunner.registerModelResourcesGraphService();
 | 
					    this.graphRunner.registerModelResourcesGraphService();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    this.graphRunner.setGraph(graphData, isBinary);
 | 
					    this.graphRunner.setGraph(graphData, isBinary);
 | 
				
			||||||
 | 
					    this.keepaliveNode = undefined;
 | 
				
			||||||
    this.handleErrors();
 | 
					    this.handleErrors();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -257,8 +266,36 @@ export abstract class TaskRunner {
 | 
				
			||||||
    this.baseOptions.setAcceleration(acceleration);
 | 
					    this.baseOptions.setAcceleration(acceleration);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Adds a node to the graph to temporarily keep certain streams alive.
 | 
				
			||||||
 | 
					   * NOTE: To use this call, PassThroughCalculator must be included in your wasm
 | 
				
			||||||
 | 
					   *     dependencies.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  protected addKeepaliveNode(graphConfig: CalculatorGraphConfig) {
 | 
				
			||||||
 | 
					    this.keepaliveNode = new CalculatorGraphConfig.Node();
 | 
				
			||||||
 | 
					    this.keepaliveNode.setCalculator('PassThroughCalculator');
 | 
				
			||||||
 | 
					    this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM);
 | 
				
			||||||
 | 
					    this.keepaliveNode.addOutputStream(
 | 
				
			||||||
 | 
					        FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX);
 | 
				
			||||||
 | 
					    graphConfig.addInputStream(FREE_MEMORY_STREAM);
 | 
				
			||||||
 | 
					    graphConfig.addNode(this.keepaliveNode);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Adds streams to the keepalive node to be kept alive until callback. */
 | 
				
			||||||
 | 
					  protected keepStreamAlive(streamName: string) {
 | 
				
			||||||
 | 
					    this.keepaliveNode!.addInputStream(streamName);
 | 
				
			||||||
 | 
					    this.keepaliveNode!.addOutputStream(streamName + UNUSED_STREAM_SUFFIX);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Frees any streams being kept alive by the keepStreamAlive callback. */
 | 
				
			||||||
 | 
					  protected freeKeepaliveStreams() {
 | 
				
			||||||
 | 
					    this.graphRunner.addBoolToStream(
 | 
				
			||||||
 | 
					        true, FREE_MEMORY_STREAM, this.latestOutputTimestamp);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Closes and cleans up the resources held by this task. */
 | 
					  /** Closes and cleans up the resources held by this task. */
 | 
				
			||||||
  close(): void {
 | 
					  close(): void {
 | 
				
			||||||
 | 
					    this.keepaliveNode = undefined;
 | 
				
			||||||
    this.graphRunner.closeGraph();
 | 
					    this.graphRunner.closeGraph();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,7 +37,7 @@ export function createSpyWasmModule(): SpyWasmModule {
 | 
				
			||||||
    '_attachProtoVectorListener', '_free', '_waitUntilIdle',
 | 
					    '_attachProtoVectorListener', '_free', '_waitUntilIdle',
 | 
				
			||||||
    '_addStringToInputStream', '_registerModelResourcesGraphService',
 | 
					    '_addStringToInputStream', '_registerModelResourcesGraphService',
 | 
				
			||||||
    '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
 | 
					    '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
 | 
				
			||||||
    '_closeGraph'
 | 
					    '_closeGraph', '_addBoolToInputStream'
 | 
				
			||||||
  ]);
 | 
					  ]);
 | 
				
			||||||
  spyWasmModule._getGraphConfig.and.callFake(() => {
 | 
					  spyWasmModule._getGraphConfig.and.callFake(() => {
 | 
				
			||||||
    (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
 | 
					    (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
 | 
				
			||||||
| 
						 | 
					@ -81,7 +81,10 @@ export function verifyGraph(
 | 
				
			||||||
    expectedBaseOptions?: FieldPathToValue,
 | 
					    expectedBaseOptions?: FieldPathToValue,
 | 
				
			||||||
    ): void {
 | 
					    ): void {
 | 
				
			||||||
  expect(tasksFake.graph).toBeDefined();
 | 
					  expect(tasksFake.graph).toBeDefined();
 | 
				
			||||||
  expect(tasksFake.graph!.getNodeList().length).toBe(1);
 | 
					  // Our graphs should have at least one node in them for processing, and
 | 
				
			||||||
 | 
					  // sometimes one additional one for keeping alive certain streams in memory.
 | 
				
			||||||
 | 
					  expect(tasksFake.graph!.getNodeList().length).toBeGreaterThanOrEqual(1);
 | 
				
			||||||
 | 
					  expect(tasksFake.graph!.getNodeList().length).toBeLessThanOrEqual(2);
 | 
				
			||||||
  const node = tasksFake.graph!.getNodeList()[0].toObject();
 | 
					  const node = tasksFake.graph!.getNodeList()[0].toObject();
 | 
				
			||||||
  expect(node).toEqual(
 | 
					  expect(node).toEqual(
 | 
				
			||||||
      jasmine.objectContaining({calculator: tasksFake.calculatorName}));
 | 
					      jasmine.objectContaining({calculator: tasksFake.calculatorName}));
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -383,6 +383,10 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (this.userCallback) {
 | 
					    if (this.userCallback) {
 | 
				
			||||||
      this.userCallback(this.result);
 | 
					      this.userCallback(this.result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Free the image memory, now that we've kept all streams alive long
 | 
				
			||||||
 | 
					      // enough to be returned in our callbacks.
 | 
				
			||||||
 | 
					      this.freeKeepaliveStreams();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -403,11 +407,13 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    segmenterNode.setOptions(calculatorOptions);
 | 
					    segmenterNode.setOptions(calculatorOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    graphConfig.addNode(segmenterNode);
 | 
					    graphConfig.addNode(segmenterNode);
 | 
				
			||||||
 | 
					    this.addKeepaliveNode(graphConfig);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (this.outputConfidenceMasks) {
 | 
					    if (this.outputConfidenceMasks) {
 | 
				
			||||||
      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
 | 
					      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
      segmenterNode.addOutputStream(
 | 
					      segmenterNode.addOutputStream(
 | 
				
			||||||
          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
 | 
					          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
 | 
					      this.keepStreamAlive(CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageVectorListener(
 | 
					      this.graphRunner.attachImageVectorListener(
 | 
				
			||||||
          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
					          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
				
			||||||
| 
						 | 
					@ -428,6 +434,7 @@ export class ImageSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    if (this.outputCategoryMask) {
 | 
					    if (this.outputCategoryMask) {
 | 
				
			||||||
      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
 | 
					      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
 | 
				
			||||||
      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
 | 
					      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
 | 
				
			||||||
 | 
					      this.keepStreamAlive(CATEGORY_MASK_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageListener(
 | 
					      this.graphRunner.attachImageListener(
 | 
				
			||||||
          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
					          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -297,6 +297,10 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (this.userCallback) {
 | 
					    if (this.userCallback) {
 | 
				
			||||||
      this.userCallback(this.result);
 | 
					      this.userCallback(this.result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Free the image memory, now that we've kept all streams alive long
 | 
				
			||||||
 | 
					      // enough to be returned in our callbacks.
 | 
				
			||||||
 | 
					      this.freeKeepaliveStreams();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -319,11 +323,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    segmenterNode.setOptions(calculatorOptions);
 | 
					    segmenterNode.setOptions(calculatorOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    graphConfig.addNode(segmenterNode);
 | 
					    graphConfig.addNode(segmenterNode);
 | 
				
			||||||
 | 
					    this.addKeepaliveNode(graphConfig);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (this.outputConfidenceMasks) {
 | 
					    if (this.outputConfidenceMasks) {
 | 
				
			||||||
      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
 | 
					      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
      segmenterNode.addOutputStream(
 | 
					      segmenterNode.addOutputStream(
 | 
				
			||||||
          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
 | 
					          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
 | 
					      this.keepStreamAlive(CONFIDENCE_MASKS_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageVectorListener(
 | 
					      this.graphRunner.attachImageVectorListener(
 | 
				
			||||||
          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
					          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
 | 
				
			||||||
| 
						 | 
					@ -344,6 +350,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
 | 
				
			||||||
    if (this.outputCategoryMask) {
 | 
					    if (this.outputCategoryMask) {
 | 
				
			||||||
      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
 | 
					      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
 | 
				
			||||||
      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
 | 
					      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
 | 
				
			||||||
 | 
					      this.keepStreamAlive(CATEGORY_MASK_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageListener(
 | 
					      this.graphRunner.attachImageListener(
 | 
				
			||||||
          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
					          CATEGORY_MASK_STREAM, (mask, timestamp) => {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -376,6 +376,9 @@ export class PoseLandmarker extends VisionTaskRunner {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (this.userCallback) {
 | 
					    if (this.userCallback) {
 | 
				
			||||||
      this.userCallback(this.result as Required<PoseLandmarkerResult>);
 | 
					      this.userCallback(this.result as Required<PoseLandmarkerResult>);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Free the image memory, now that we've finished our callback.
 | 
				
			||||||
 | 
					      this.freeKeepaliveStreams();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -437,6 +440,9 @@ export class PoseLandmarker extends VisionTaskRunner {
 | 
				
			||||||
    landmarkerNode.setOptions(calculatorOptions);
 | 
					    landmarkerNode.setOptions(calculatorOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    graphConfig.addNode(landmarkerNode);
 | 
					    graphConfig.addNode(landmarkerNode);
 | 
				
			||||||
 | 
					    // We only need to keep alive the image stream, since the protos are being
 | 
				
			||||||
 | 
					    // deep-copied anyways via serialization+deserialization.
 | 
				
			||||||
 | 
					    this.addKeepaliveNode(graphConfig);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    this.graphRunner.attachProtoVectorListener(
 | 
					    this.graphRunner.attachProtoVectorListener(
 | 
				
			||||||
        NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
 | 
					        NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
 | 
				
			||||||
| 
						 | 
					@ -467,6 +473,8 @@ export class PoseLandmarker extends VisionTaskRunner {
 | 
				
			||||||
    if (this.outputSegmentationMasks) {
 | 
					    if (this.outputSegmentationMasks) {
 | 
				
			||||||
      landmarkerNode.addOutputStream(
 | 
					      landmarkerNode.addOutputStream(
 | 
				
			||||||
          'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);
 | 
					          'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);
 | 
				
			||||||
 | 
					      this.keepStreamAlive(SEGMENTATION_MASK_STREAM);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      this.graphRunner.attachImageVectorListener(
 | 
					      this.graphRunner.attachImageVectorListener(
 | 
				
			||||||
          SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
 | 
					          SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
 | 
				
			||||||
            this.result.segmentationMasks = masks.map(
 | 
					            this.result.segmentationMasks = masks.map(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user