Invoke callback for ImageSegmenter while C++ Packets are active
PiperOrigin-RevId: 528047220
This commit is contained in:
parent
e15add2475
commit
a9721ae2fb
|
@ -60,6 +60,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
||||||
export class ImageSegmenter extends VisionTaskRunner {
|
export class ImageSegmenter extends VisionTaskRunner {
|
||||||
private result: ImageSegmenterResult = {};
|
private result: ImageSegmenterResult = {};
|
||||||
private labels: string[] = [];
|
private labels: string[] = [];
|
||||||
|
private userCallback: ImageSegmenterCallback = () => {};
|
||||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||||
|
@ -232,14 +233,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||||
imageProcessingOptionsOrCallback :
|
imageProcessingOptionsOrCallback :
|
||||||
{};
|
{};
|
||||||
const userCallback =
|
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||||
typeof imageProcessingOptionsOrCallback === 'function' ?
|
|
||||||
imageProcessingOptionsOrCallback :
|
imageProcessingOptionsOrCallback :
|
||||||
callback!;
|
callback!;
|
||||||
|
|
||||||
this.reset();
|
this.reset();
|
||||||
this.processImageData(image, imageProcessingOptions);
|
this.processImageData(image, imageProcessingOptions);
|
||||||
userCallback(this.result);
|
this.userCallback = () => {};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -286,13 +286,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||||
timestampOrImageProcessingOptions :
|
timestampOrImageProcessingOptions :
|
||||||
timestampOrCallback as number;
|
timestampOrCallback as number;
|
||||||
const userCallback = typeof timestampOrCallback === 'function' ?
|
this.userCallback = typeof timestampOrCallback === 'function' ?
|
||||||
timestampOrCallback :
|
timestampOrCallback :
|
||||||
callback!;
|
callback!;
|
||||||
|
|
||||||
this.reset();
|
this.reset();
|
||||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||||
userCallback(this.result);
|
this.userCallback = () => {};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -314,6 +314,18 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
this.result = {};
|
this.result = {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Invokes the user callback once all data has been received. */
|
||||||
|
private maybeInvokeCallback(): void {
|
||||||
|
if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (this.outputCategoryMask && !('categoryMask' in this.result)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.userCallback(this.result);
|
||||||
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
protected override refreshGraph(): void {
|
protected override refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
|
@ -342,10 +354,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
this.result.confidenceMasks =
|
this.result.confidenceMasks =
|
||||||
masks.map(wasmImage => this.convertToMPImage(wasmImage));
|
masks.map(wasmImage => this.convertToMPImage(wasmImage));
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
this.maybeInvokeCallback();
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
CONFIDENCE_MASKS_STREAM, timestamp => {
|
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||||
|
this.result.confidenceMasks = undefined;
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
this.maybeInvokeCallback();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,10 +372,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
this.result.categoryMask = this.convertToMPImage(mask);
|
this.result.categoryMask = this.convertToMPImage(mask);
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
this.maybeInvokeCallback();
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
CATEGORY_MASK_STREAM, timestamp => {
|
CATEGORY_MASK_STREAM, timestamp => {
|
||||||
|
this.result.categoryMask = undefined;
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
this.maybeInvokeCallback();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -262,4 +262,34 @@ describe('ImageSegmenter', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('invokes listener once masks are avaiblae', async () => {
|
||||||
|
const categoryMask = new Uint8ClampedArray([1]);
|
||||||
|
const confidenceMask = new Float32Array([0.0]);
|
||||||
|
let listenerCalled = false;
|
||||||
|
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{outputCategoryMask: true, outputConfidenceMasks: true});
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
expect(listenerCalled).toBeFalse();
|
||||||
|
imageSegmenter.categoryMaskListener!
|
||||||
|
({data: categoryMask, width: 1, height: 1}, 1337);
|
||||||
|
expect(listenerCalled).toBeFalse();
|
||||||
|
imageSegmenter.confidenceMasksListener!(
|
||||||
|
[
|
||||||
|
{data: confidenceMask, width: 1, height: 1},
|
||||||
|
],
|
||||||
|
1337);
|
||||||
|
expect(listenerCalled).toBeTrue();
|
||||||
|
});
|
||||||
|
|
||||||
|
return new Promise<void>(resolve => {
|
||||||
|
imageSegmenter.segment({} as HTMLImageElement, () => {
|
||||||
|
listenerCalled = true;
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user