Invoke callback for InteractiveSegmenter while C++ Packets are active

PiperOrigin-RevId: 528053621
This commit is contained in:
Sebastian Schmidt 2023-04-28 20:31:57 -07:00 committed by Copybara-Service
parent d5c5457d25
commit 253f13ad62
2 changed files with 51 additions and 3 deletions

View File

@ -86,6 +86,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
private result: InteractiveSegmenterResult = {}; private result: InteractiveSegmenterResult = {};
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback: InteractiveSegmenterCallback = () => {};
private readonly options: ImageSegmenterGraphOptionsProto; private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto;
@ -241,21 +242,32 @@ export class InteractiveSegmenter extends VisionTaskRunner {
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
const userCallback = this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback!;
this.reset(); this.reset();
this.processRenderData(roi, this.getSynctheticTimestamp()); this.processRenderData(roi, this.getSynctheticTimestamp());
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
userCallback(this.result); this.userCallback = () => {};
} }
private reset(): void { private reset(): void {
this.result = {}; this.result = {};
} }
/** Invokes the user callback once all data has been received. */
private maybeInvokeCallback(): void {
if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) {
return;
}
if (this.outputCategoryMask && !('categoryMask' in this.result)) {
return;
}
this.userCallback(this.result);
}
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void { protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
@ -286,10 +298,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.result.confidenceMasks = this.result.confidenceMasks =
masks.map(wasmImage => this.convertToMPImage(wasmImage)); masks.map(wasmImage => this.convertToMPImage(wasmImage));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => { CONFIDENCE_MASKS_STREAM, timestamp => {
this.result.confidenceMasks = undefined;
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }
@ -301,10 +316,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPImage(mask); this.result.categoryMask = this.convertToMPImage(mask);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => { CATEGORY_MASK_STREAM, timestamp => {
this.result.categoryMask = undefined;
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }

View File

@ -252,4 +252,34 @@ describe('InteractiveSegmenter', () => {
}); });
}); });
}); });
it('invokes listener once masks are avaiblae', async () => {
const categoryMask = new Uint8ClampedArray([1]);
const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false;
await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: true});
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(listenerCalled).toBeFalse();
interactiveSegmenter.categoryMaskListener!
({data: categoryMask, width: 1, height: 1}, 1337);
expect(listenerCalled).toBeFalse();
interactiveSegmenter.confidenceMasksListener!(
[
{data: confidenceMask, width: 1, height: 1},
],
1337);
expect(listenerCalled).toBeTrue();
});
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {
listenerCalled = true;
resolve();
});
});
});
}); });