Add quality scores to Segmenter tasks

PiperOrigin-RevId: 534497957
This commit is contained in:
Sebastian Schmidt 2023-05-23 11:33:56 -07:00 committed by Copybara-Service
parent 87f525c76b
commit 1fe78180c8
8 changed files with 85 additions and 11 deletions

View File

@ -33,7 +33,7 @@ struct ImageSegmenterResult {
// A category mask of uint8 image in GRAY8 format where each pixel represents
// the class which the pixel in the original image was predicted to belong to.
std::optional<Image> category_mask;
// The quality scores of the result masks, in the range of [0, 1]. Default to
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
// `1` if the model doesn't output quality scores. Each element corresponds to
// the score of the category in the model outputs.
std::vector<float> quality_scores;

View File

@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
* category mask, where each pixel represents the class which the pixel in the original image
* was predicted to belong to.
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
* the category in the model outputs.
* @param timestampMs a timestamp for this result.
*/

View File

@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
export class ImageSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private labels: string[] = [];
private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
/** Clears every cached segmentation output (masks and quality scores). */
private reset(): void {
  this.qualityScores = undefined;
  this.confidenceMasks = undefined;
  this.categoryMask = undefined;
}
private processResults(): ImageSegmenterResult|void {
try {
const result =
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
const result = new ImageSegmenterResult(
this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) {
this.userCallback(result);
} else {
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
});
}
// Route the segmenter node's QUALITY_SCORES output to a graph output stream
// so that the scores can be observed from TypeScript.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

// Cache the scores delivered for this timestamp; they are handed to the
// result object together with the masks.
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
// An empty packet on the quality-scores stream must clear the cached scores.
// It must not touch the category mask, which is delivered on a different
// stream (CATEGORY_MASK_STREAM).
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}

View File

@ -30,7 +30,13 @@ export class ImageSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to.
*/
readonly categoryMask?: MPMask) {}
readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {

View File

@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
// Intercept `attachFloatVectorListener`: verify that it is registered for the
// 'quality_scores' stream and capture the listener so that tests can invoke
// it directly with fake scores.
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false;
await imageSegmenter.setOptions(
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
],
1337);
expect(listenerCalled).toBeFalse();
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
});
return new Promise<void>(resolve => {
imageSegmenter.segment({} as HTMLImageElement, () => {
imageSegmenter.segment({} as HTMLImageElement, result => {
listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve();
});
});

View File

@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in';
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const QUALITY_SCORES_STREAM = 'quality_scores';
const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
export class InteractiveSegmenter extends VisionTaskRunner {
private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private qualityScores?: number[];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback?: InteractiveSegmenterCallback;
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
/** Clears every cached segmentation output (masks and quality scores). */
private reset(): void {
  this.qualityScores = undefined;
  this.categoryMask = undefined;
  this.confidenceMasks = undefined;
}
private processResults(): InteractiveSegmenterResult|void {
try {
const result = new InteractiveSegmenterResult(
this.confidenceMasks, this.categoryMask);
this.confidenceMasks, this.categoryMask, this.qualityScores);
if (this.userCallback) {
this.userCallback(result);
} else {
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
});
}
// Route the segmenter node's QUALITY_SCORES output to a graph output stream
// so that the scores can be observed from TypeScript.
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);

// Cache the scores delivered for this timestamp; they are handed to the
// result object together with the masks.
this.graphRunner.attachFloatVectorListener(
    QUALITY_SCORES_STREAM, (scores, timestamp) => {
      this.qualityScores = scores;
      this.setLatestOutputTimestamp(timestamp);
    });
// An empty packet on the quality-scores stream must clear the cached scores.
// It must not touch the category mask, which is delivered on a different
// stream (CATEGORY_MASK_STREAM).
this.graphRunner.attachEmptyPacketListener(
    QUALITY_SCORES_STREAM, timestamp => {
      this.qualityScores = undefined;
      this.setLatestOutputTimestamp(timestamp);
    });
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}

View File

@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
* which the pixel in the original image was predicted to belong to.
*/
readonly categoryMask?: MPMask) {}
readonly categoryMask?: MPMask,
/**
* The quality scores of the result masks, in the range of [0, 1].
* Defaults to `1` if the model doesn't output quality scores. Each
* element corresponds to the score of the category in the model outputs.
*/
readonly qualityScores?: number[]) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {

View File

@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
qualityScoresListener:
((data: number[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto;
constructor() {
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
expect(stream).toEqual('confidence_masks');
this.confidenceMasksListener = listener;
});
// Intercept `attachFloatVectorListener`: verify that it is registered for the
// 'quality_scores' stream and capture the listener so that tests can invoke
// it directly with fake scores.
this.attachListenerSpies[2] =
spyOn(this.graphRunner, 'attachFloatVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('quality_scores');
this.qualityScoresListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
});
});
it('invokes listener after masks are avaiblae', async () => {
it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]);
const qualityScores = [1.0];
let listenerCalled = false;
await interactiveSegmenter.setOptions(
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
],
1337);
expect(listenerCalled).toBeFalse();
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
expect(listenerCalled).toBeFalse();
});
return new Promise<void>(resolve => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
listenerCalled = true;
expect(result.categoryMask).toBeInstanceOf(MPMask);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
expect(result.qualityScores).toEqual(qualityScores);
resolve();
});
});