When returning multiple output streams together, keep them alive until callback.
PiperOrigin-RevId: 530771884
This commit is contained in:
parent
e391c76433
commit
f824424700
|
@ -28,9 +28,16 @@ import {WasmFileset} from './wasm_fileset';
|
||||||
// None of the MP Tasks ship bundle assets.
|
// None of the MP Tasks ship bundle assets.
|
||||||
const NO_ASSETS = undefined;
|
const NO_ASSETS = undefined;
|
||||||
|
|
||||||
|
// Internal stream names for temporarily keeping memory alive, then freeing it.
|
||||||
|
const FREE_MEMORY_STREAM = 'free_memory';
|
||||||
|
const UNUSED_STREAM_SUFFIX = '_unused_out';
|
||||||
|
|
||||||
// tslint:disable-next-line:enforce-name-casing
|
// tslint:disable-next-line:enforce-name-casing
|
||||||
const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
|
const CachedGraphRunnerType = SupportModelResourcesGraphService(GraphRunner);
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An implementation of the GraphRunner that exposes the resource graph
|
* An implementation of the GraphRunner that exposes the resource graph
|
||||||
* service.
|
* service.
|
||||||
|
@ -64,6 +71,7 @@ export abstract class TaskRunner {
|
||||||
protected abstract baseOptions: BaseOptionsProto;
|
protected abstract baseOptions: BaseOptionsProto;
|
||||||
private processingErrors: Error[] = [];
|
private processingErrors: Error[] = [];
|
||||||
private latestOutputTimestamp = 0;
|
private latestOutputTimestamp = 0;
|
||||||
|
private keepaliveNode?: CalculatorGraphConfig.Node;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new instance of a Mediapipe Task. Determines if SIMD is
|
* Creates a new instance of a Mediapipe Task. Determines if SIMD is
|
||||||
|
@ -177,6 +185,7 @@ export abstract class TaskRunner {
|
||||||
this.graphRunner.registerModelResourcesGraphService();
|
this.graphRunner.registerModelResourcesGraphService();
|
||||||
|
|
||||||
this.graphRunner.setGraph(graphData, isBinary);
|
this.graphRunner.setGraph(graphData, isBinary);
|
||||||
|
this.keepaliveNode = undefined;
|
||||||
this.handleErrors();
|
this.handleErrors();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,8 +266,36 @@ export abstract class TaskRunner {
|
||||||
this.baseOptions.setAcceleration(acceleration);
|
this.baseOptions.setAcceleration(acceleration);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a node to the graph to temporarily keep certain streams alive.
|
||||||
|
* NOTE: To use this call, PassThroughCalculator must be included in your wasm
|
||||||
|
* dependencies.
|
||||||
|
*/
|
||||||
|
protected addKeepaliveNode(graphConfig: CalculatorGraphConfig) {
|
||||||
|
this.keepaliveNode = new CalculatorGraphConfig.Node();
|
||||||
|
this.keepaliveNode.setCalculator('PassThroughCalculator');
|
||||||
|
this.keepaliveNode.addInputStream(FREE_MEMORY_STREAM);
|
||||||
|
this.keepaliveNode.addOutputStream(
|
||||||
|
FREE_MEMORY_STREAM + UNUSED_STREAM_SUFFIX);
|
||||||
|
graphConfig.addInputStream(FREE_MEMORY_STREAM);
|
||||||
|
graphConfig.addNode(this.keepaliveNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Adds streams to the keepalive node to be kept alive until callback. */
|
||||||
|
protected keepStreamAlive(streamName: string) {
|
||||||
|
this.keepaliveNode!.addInputStream(streamName);
|
||||||
|
this.keepaliveNode!.addOutputStream(streamName + UNUSED_STREAM_SUFFIX);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Frees any streams being kept alive by the keepStreamAlive callback. */
|
||||||
|
protected freeKeepaliveStreams() {
|
||||||
|
this.graphRunner.addBoolToStream(
|
||||||
|
true, FREE_MEMORY_STREAM, this.latestOutputTimestamp);
|
||||||
|
}
|
||||||
|
|
||||||
/** Closes and cleans up the resources held by this task. */
|
/** Closes and cleans up the resources held by this task. */
|
||||||
close(): void {
|
close(): void {
|
||||||
|
this.keepaliveNode = undefined;
|
||||||
this.graphRunner.closeGraph();
|
this.graphRunner.closeGraph();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ export function createSpyWasmModule(): SpyWasmModule {
|
||||||
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
|
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
|
||||||
'_addStringToInputStream', '_registerModelResourcesGraphService',
|
'_addStringToInputStream', '_registerModelResourcesGraphService',
|
||||||
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
|
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig',
|
||||||
'_closeGraph'
|
'_closeGraph', '_addBoolToInputStream'
|
||||||
]);
|
]);
|
||||||
spyWasmModule._getGraphConfig.and.callFake(() => {
|
spyWasmModule._getGraphConfig.and.callFake(() => {
|
||||||
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
|
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
|
||||||
|
@ -81,7 +81,10 @@ export function verifyGraph(
|
||||||
expectedBaseOptions?: FieldPathToValue,
|
expectedBaseOptions?: FieldPathToValue,
|
||||||
): void {
|
): void {
|
||||||
expect(tasksFake.graph).toBeDefined();
|
expect(tasksFake.graph).toBeDefined();
|
||||||
expect(tasksFake.graph!.getNodeList().length).toBe(1);
|
// Our graphs should have at least one node in them for processing, and
|
||||||
|
// sometimes one additional one for keeping alive certain streams in memory.
|
||||||
|
expect(tasksFake.graph!.getNodeList().length).toBeGreaterThanOrEqual(1);
|
||||||
|
expect(tasksFake.graph!.getNodeList().length).toBeLessThanOrEqual(2);
|
||||||
const node = tasksFake.graph!.getNodeList()[0].toObject();
|
const node = tasksFake.graph!.getNodeList()[0].toObject();
|
||||||
expect(node).toEqual(
|
expect(node).toEqual(
|
||||||
jasmine.objectContaining({calculator: tasksFake.calculatorName}));
|
jasmine.objectContaining({calculator: tasksFake.calculatorName}));
|
||||||
|
|
|
@ -383,6 +383,10 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(this.result);
|
this.userCallback(this.result);
|
||||||
|
|
||||||
|
// Free the image memory, now that we've kept all streams alive long
|
||||||
|
// enough to be returned in our callbacks.
|
||||||
|
this.freeKeepaliveStreams();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -403,11 +407,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
segmenterNode.setOptions(calculatorOptions);
|
segmenterNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(segmenterNode);
|
graphConfig.addNode(segmenterNode);
|
||||||
|
this.addKeepaliveNode(graphConfig);
|
||||||
|
|
||||||
if (this.outputConfidenceMasks) {
|
if (this.outputConfidenceMasks) {
|
||||||
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
||||||
segmenterNode.addOutputStream(
|
segmenterNode.addOutputStream(
|
||||||
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
||||||
|
this.keepStreamAlive(CONFIDENCE_MASKS_STREAM);
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||||
|
@ -428,6 +434,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
if (this.outputCategoryMask) {
|
if (this.outputCategoryMask) {
|
||||||
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
||||||
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
||||||
|
this.keepStreamAlive(CATEGORY_MASK_STREAM);
|
||||||
|
|
||||||
this.graphRunner.attachImageListener(
|
this.graphRunner.attachImageListener(
|
||||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
|
|
|
@ -297,6 +297,10 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(this.result);
|
this.userCallback(this.result);
|
||||||
|
|
||||||
|
// Free the image memory, now that we've kept all streams alive long
|
||||||
|
// enough to be returned in our callbacks.
|
||||||
|
this.freeKeepaliveStreams();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,11 +323,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
segmenterNode.setOptions(calculatorOptions);
|
segmenterNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(segmenterNode);
|
graphConfig.addNode(segmenterNode);
|
||||||
|
this.addKeepaliveNode(graphConfig);
|
||||||
|
|
||||||
if (this.outputConfidenceMasks) {
|
if (this.outputConfidenceMasks) {
|
||||||
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
||||||
segmenterNode.addOutputStream(
|
segmenterNode.addOutputStream(
|
||||||
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
||||||
|
this.keepStreamAlive(CONFIDENCE_MASKS_STREAM);
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||||
|
@ -344,6 +350,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
if (this.outputCategoryMask) {
|
if (this.outputCategoryMask) {
|
||||||
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
||||||
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
||||||
|
this.keepStreamAlive(CATEGORY_MASK_STREAM);
|
||||||
|
|
||||||
this.graphRunner.attachImageListener(
|
this.graphRunner.attachImageListener(
|
||||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
|
|
|
@ -376,6 +376,9 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
|
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(this.result as Required<PoseLandmarkerResult>);
|
this.userCallback(this.result as Required<PoseLandmarkerResult>);
|
||||||
|
|
||||||
|
// Free the image memory, now that we've finished our callback.
|
||||||
|
this.freeKeepaliveStreams();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,6 +440,9 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
landmarkerNode.setOptions(calculatorOptions);
|
landmarkerNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(landmarkerNode);
|
graphConfig.addNode(landmarkerNode);
|
||||||
|
// We only need to keep alive the image stream, since the protos are being
|
||||||
|
// deep-copied anyways via serialization+deserialization.
|
||||||
|
this.addKeepaliveNode(graphConfig);
|
||||||
|
|
||||||
this.graphRunner.attachProtoVectorListener(
|
this.graphRunner.attachProtoVectorListener(
|
||||||
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||||
|
@ -467,6 +473,8 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
if (this.outputSegmentationMasks) {
|
if (this.outputSegmentationMasks) {
|
||||||
landmarkerNode.addOutputStream(
|
landmarkerNode.addOutputStream(
|
||||||
'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);
|
'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);
|
||||||
|
this.keepStreamAlive(SEGMENTATION_MASK_STREAM);
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
|
SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
|
||||||
this.result.segmentationMasks = masks.map(
|
this.result.segmentationMasks = masks.map(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user