Add quality scores to Segmenter tasks
PiperOrigin-RevId: 534497957
This commit is contained in:
parent
87f525c76b
commit
1fe78180c8
|
@ -33,7 +33,7 @@ struct ImageSegmenterResult {
|
|||
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
||||
// the class which the pixel in the original image was predicted to belong to.
|
||||
std::optional<Image> category_mask;
|
||||
// The quality scores of the result masks, in the range of [0, 1]. Default to
|
||||
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
|
||||
// `1` if the model doesn't output quality scores. Each element corresponds to
|
||||
// the score of the category in the model outputs.
|
||||
std::vector<float> quality_scores;
|
||||
|
|
|
@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
|||
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||
* category mask, where each pixel represents the class which the pixel in the original image
|
||||
* was predicted to belong to.
|
||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
|
||||
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
|
||||
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||
* the category in the model outputs.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
|
|
|
@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
|
|||
const NORM_RECT_STREAM = 'norm_rect';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||
const IMAGE_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
|||
export class ImageSegmenter extends VisionTaskRunner {
|
||||
private categoryMask?: MPMask;
|
||||
private confidenceMasks?: MPMask[];
|
||||
private qualityScores?: number[];
|
||||
private labels: string[] = [];
|
||||
private userCallback?: ImageSegmenterCallback;
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
|
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
/**
 * Clears the cached category mask, confidence masks and quality scores by
 * setting each of the three fields back to `undefined`.
 */
private reset(): void {
  this.qualityScores = undefined;
  this.confidenceMasks = undefined;
  this.categoryMask = undefined;
}
|
||||
|
||||
private processResults(): ImageSegmenterResult|void {
|
||||
try {
|
||||
const result =
|
||||
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
|
||||
const result = new ImageSegmenterResult(
|
||||
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||
if (this.userCallback) {
|
||||
this.userCallback(result);
|
||||
} else {
|
||||
|
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
});
|
||||
}
|
||||
|
||||
// Expose the segmenter node's QUALITY_SCORES output as a graph output stream.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

// Store the scores delivered on the stream; `processResults()` passes them
// to the `ImageSegmenterResult` constructor.
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
// An empty packet on the quality-scores stream must clear the quality scores.
// Clearing `categoryMask` here (as the code did before) would drop a mask that
// was already received and leave the scores untouched.
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,13 @@ export class ImageSegmenterResult {
|
|||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||
* which the pixel in the original image was predicted to belong to.
|
||||
*/
|
||||
readonly categoryMask?: MPMask) {}
|
||||
readonly categoryMask?: MPMask,
|
||||
/**
|
||||
* The quality scores of the result masks, in the range of [0, 1].
|
||||
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||
* element corresponds to the score of the category in the model outputs.
|
||||
*/
|
||||
readonly qualityScores?: number[]) {}
|
||||
|
||||
/** Frees the resources held by the category and confidence masks. */
|
||||
close(): void {
|
||||
|
|
|
@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
qualityScoresListener:
|
||||
((data: number[], timestamp: number) => void)|undefined;
|
||||
|
||||
constructor() {
|
||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||
|
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[2] =
|
||||
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('quality_scores');
|
||||
this.qualityScoresListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
|
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
|
|||
it('invokes listener after masks are available', async () => {
|
||||
const categoryMask = new Uint8Array([1]);
|
||||
const confidenceMask = new Float32Array([0.0]);
|
||||
const qualityScores = [1.0];
|
||||
let listenerCalled = false;
|
||||
|
||||
await imageSegmenter.setOptions(
|
||||
|
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
|
|||
],
|
||||
1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, () => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||
listenerCalled = true;
|
||||
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||
expect(result.qualityScores).toEqual(qualityScores);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
|||
const ROI_IN_STREAM = 'roi_in';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||
|
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
|
|||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private categoryMask?: MPMask;
|
||||
private confidenceMasks?: MPMask[];
|
||||
private qualityScores?: number[];
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
private userCallback?: InteractiveSegmenterCallback;
|
||||
|
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
/**
 * Clears the cached confidence masks, category mask and quality scores by
 * setting each of the three fields back to `undefined`.
 */
private reset(): void {
  this.qualityScores = undefined;
  this.categoryMask = undefined;
  this.confidenceMasks = undefined;
}
|
||||
|
||||
private processResults(): InteractiveSegmenterResult|void {
|
||||
try {
|
||||
const result = new InteractiveSegmenterResult(
|
||||
this.confidenceMasks, this.categoryMask);
|
||||
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||
if (this.userCallback) {
|
||||
this.userCallback(result);
|
||||
} else {
|
||||
|
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
});
|
||||
}
|
||||
|
||||
// Expose the segmenter node's QUALITY_SCORES output as a graph output stream.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

// Store the scores delivered on the stream; `processResults()` passes them
// to the `InteractiveSegmenterResult` constructor.
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
// An empty packet on the quality-scores stream must clear the quality scores.
// Clearing `categoryMask` here (as the code did before) would drop a mask that
// was already received and leave the scores untouched.
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
|
|||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||
* which the pixel in the original image was predicted to belong to.
|
||||
*/
|
||||
readonly categoryMask?: MPMask) {}
|
||||
readonly categoryMask?: MPMask,
|
||||
/**
|
||||
* The quality scores of the result masks, in the range of [0, 1].
|
||||
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||
* element corresponds to the score of the category in the model outputs.
|
||||
*/
|
||||
readonly qualityScores?: number[]) {}
|
||||
|
||||
/** Frees the resources held by the category and confidence masks. */
|
||||
close(): void {
|
||||
|
|
|
@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
qualityScoresListener:
|
||||
((data: number[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
constructor() {
|
||||
|
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[2] =
|
||||
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('quality_scores');
|
||||
this.qualityScoresListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
|
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('invokes listener after masks are avaiblae', async () => {
|
||||
it('invokes listener after masks are available', async () => {
|
||||
const categoryMask = new Uint8Array([1]);
|
||||
const confidenceMask = new Float32Array([0.0]);
|
||||
const qualityScores = [1.0];
|
||||
let listenerCalled = false;
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
|
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
|
|||
],
|
||||
1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
|
||||
listenerCalled = true;
|
||||
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||
expect(result.qualityScores).toEqual(qualityScores);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user