Add quality scores to Segmenter tasks
PiperOrigin-RevId: 534497957
This commit is contained in:
parent
87f525c76b
commit
1fe78180c8
|
@ -33,7 +33,7 @@ struct ImageSegmenterResult {
|
||||||
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
// A category mask of uint8 image in GRAY8 format where each pixel represents
|
||||||
// the class which the pixel in the original image was predicted to belong to.
|
// the class which the pixel in the original image was predicted to belong to.
|
||||||
std::optional<Image> category_mask;
|
std::optional<Image> category_mask;
|
||||||
// The quality scores of the result masks, in the range of [0, 1]. Default to
|
// The quality scores of the result masks, in the range of [0, 1]. Defaults to
|
||||||
// `1` if the model doesn't output quality scores. Each element corresponds to
|
// `1` if the model doesn't output quality scores. Each element corresponds to
|
||||||
// the score of the category in the model outputs.
|
// the score of the category in the model outputs.
|
||||||
std::vector<float> quality_scores;
|
std::vector<float> quality_scores;
|
||||||
|
|
|
@ -34,8 +34,8 @@ public abstract class ImageSegmenterResult implements TaskResult {
|
||||||
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
* @param categoryMask an {@link Optional} MPImage in IMAGE_FORMAT_ALPHA format representing a
|
||||||
* category mask, where each pixel represents the class which the pixel in the original image
|
* category mask, where each pixel represents the class which the pixel in the original image
|
||||||
* was predicted to belong to.
|
* was predicted to belong to.
|
||||||
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Default to
|
* @param qualityScores The quality scores of the result masks, in the range of [0, 1]. Defaults
|
||||||
* `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
* to `1` if the model doesn't output quality scores. Each element corresponds to the score of
|
||||||
* the category in the model outputs.
|
* the category in the model outputs.
|
||||||
* @param timestampMs a timestamp for this result.
|
* @param timestampMs a timestamp for this result.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -39,6 +39,7 @@ const IMAGE_STREAM = 'image_in';
|
||||||
const NORM_RECT_STREAM = 'norm_rect';
|
const NORM_RECT_STREAM = 'norm_rect';
|
||||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||||
|
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||||
const IMAGE_SEGMENTER_GRAPH =
|
const IMAGE_SEGMENTER_GRAPH =
|
||||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
|
@ -61,6 +62,7 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
||||||
export class ImageSegmenter extends VisionTaskRunner {
|
export class ImageSegmenter extends VisionTaskRunner {
|
||||||
private categoryMask?: MPMask;
|
private categoryMask?: MPMask;
|
||||||
private confidenceMasks?: MPMask[];
|
private confidenceMasks?: MPMask[];
|
||||||
|
private qualityScores?: number[];
|
||||||
private labels: string[] = [];
|
private labels: string[] = [];
|
||||||
private userCallback?: ImageSegmenterCallback;
|
private userCallback?: ImageSegmenterCallback;
|
||||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
|
@ -367,12 +369,13 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
private reset(): void {
|
private reset(): void {
|
||||||
this.categoryMask = undefined;
|
this.categoryMask = undefined;
|
||||||
this.confidenceMasks = undefined;
|
this.confidenceMasks = undefined;
|
||||||
|
this.qualityScores = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
private processResults(): ImageSegmenterResult|void {
|
private processResults(): ImageSegmenterResult|void {
|
||||||
try {
|
try {
|
||||||
const result =
|
const result = new ImageSegmenterResult(
|
||||||
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
|
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(result);
|
this.userCallback(result);
|
||||||
} else {
|
} else {
|
||||||
|
@ -442,6 +445,20 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
|
||||||
|
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
|
||||||
|
|
||||||
|
this.graphRunner.attachFloatVectorListener(
|
||||||
|
QUALITY_SCORES_STREAM, (scores, timestamp) => {
|
||||||
|
this.qualityScores = scores;
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
|
QUALITY_SCORES_STREAM, timestamp => {
|
||||||
|
this.qualityScores = undefined;  // Empty QUALITY_SCORES packet: clear the scores, not the category mask.
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
|
||||||
const binaryGraph = graphConfig.serializeBinary();
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,13 @@ export class ImageSegmenterResult {
|
||||||
* `WebGLTexture`-backed `MPMask` where each pixel represents the class
|
* `WebGLTexture`-backed `MPMask` where each pixel represents the class
|
||||||
* which the pixel in the original image was predicted to belong to.
|
* which the pixel in the original image was predicted to belong to.
|
||||||
*/
|
*/
|
||||||
readonly categoryMask?: MPMask) {}
|
readonly categoryMask?: MPMask,
|
||||||
|
/**
|
||||||
|
* The quality scores of the result masks, in the range of [0, 1].
|
||||||
|
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||||
|
* element corresponds to the score of the category in the model outputs.
|
||||||
|
*/
|
||||||
|
readonly qualityScores?: number[]) {}
|
||||||
|
|
||||||
/** Frees the resources held by the category and confidence masks. */
|
/** Frees the resources held by the category and confidence masks. */
|
||||||
close(): void {
|
close(): void {
|
||||||
|
|
|
@ -35,6 +35,8 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
((images: WasmImage, timestamp: number) => void)|undefined;
|
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||||
confidenceMasksListener:
|
confidenceMasksListener:
|
||||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||||
|
qualityScoresListener:
|
||||||
|
((data: number[], timestamp: number) => void)|undefined;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
@ -52,6 +54,12 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
expect(stream).toEqual('confidence_masks');
|
expect(stream).toEqual('confidence_masks');
|
||||||
this.confidenceMasksListener = listener;
|
this.confidenceMasksListener = listener;
|
||||||
});
|
});
|
||||||
|
this.attachListenerSpies[2] =
|
||||||
|
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('quality_scores');
|
||||||
|
this.qualityScoresListener = listener;
|
||||||
|
});
|
||||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
});
|
});
|
||||||
|
@ -266,6 +274,7 @@ describe('ImageSegmenter', () => {
|
||||||
it('invokes listener after masks are available', async () => {
|
it('invokes listener after masks are available', async () => {
|
||||||
const categoryMask = new Uint8Array([1]);
|
const categoryMask = new Uint8Array([1]);
|
||||||
const confidenceMask = new Float32Array([0.0]);
|
const confidenceMask = new Float32Array([0.0]);
|
||||||
|
const qualityScores = [1.0];
|
||||||
let listenerCalled = false;
|
let listenerCalled = false;
|
||||||
|
|
||||||
await imageSegmenter.setOptions(
|
await imageSegmenter.setOptions(
|
||||||
|
@ -283,11 +292,16 @@ describe('ImageSegmenter', () => {
|
||||||
],
|
],
|
||||||
1337);
|
1337);
|
||||||
expect(listenerCalled).toBeFalse();
|
expect(listenerCalled).toBeFalse();
|
||||||
|
imageSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||||
|
expect(listenerCalled).toBeFalse();
|
||||||
});
|
});
|
||||||
|
|
||||||
return new Promise<void>(resolve => {
|
return new Promise<void>(resolve => {
|
||||||
imageSegmenter.segment({} as HTMLImageElement, () => {
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
listenerCalled = true;
|
listenerCalled = true;
|
||||||
|
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||||
|
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||||
|
expect(result.qualityScores).toEqual(qualityScores);
|
||||||
resolve();
|
resolve();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -42,6 +42,7 @@ const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
||||||
const ROI_IN_STREAM = 'roi_in';
|
const ROI_IN_STREAM = 'roi_in';
|
||||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||||
|
const QUALITY_SCORES_STREAM = 'quality_scores';
|
||||||
const IMAGEA_SEGMENTER_GRAPH =
|
const IMAGEA_SEGMENTER_GRAPH =
|
||||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||||
|
@ -86,6 +87,7 @@ export type InteractiveSegmenterCallback =
|
||||||
export class InteractiveSegmenter extends VisionTaskRunner {
|
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
private categoryMask?: MPMask;
|
private categoryMask?: MPMask;
|
||||||
private confidenceMasks?: MPMask[];
|
private confidenceMasks?: MPMask[];
|
||||||
|
private qualityScores?: number[];
|
||||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||||
private userCallback?: InteractiveSegmenterCallback;
|
private userCallback?: InteractiveSegmenterCallback;
|
||||||
|
@ -284,12 +286,13 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
private reset(): void {
|
private reset(): void {
|
||||||
this.confidenceMasks = undefined;
|
this.confidenceMasks = undefined;
|
||||||
this.categoryMask = undefined;
|
this.categoryMask = undefined;
|
||||||
|
this.qualityScores = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
private processResults(): InteractiveSegmenterResult|void {
|
private processResults(): InteractiveSegmenterResult|void {
|
||||||
try {
|
try {
|
||||||
const result = new InteractiveSegmenterResult(
|
const result = new InteractiveSegmenterResult(
|
||||||
this.confidenceMasks, this.categoryMask);
|
this.confidenceMasks, this.categoryMask, this.qualityScores);
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(result);
|
this.userCallback(result);
|
||||||
} else {
|
} else {
|
||||||
|
@ -361,6 +364,20 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
graphConfig.addOutputStream(QUALITY_SCORES_STREAM);
|
||||||
|
segmenterNode.addOutputStream('QUALITY_SCORES:' + QUALITY_SCORES_STREAM);
|
||||||
|
|
||||||
|
this.graphRunner.attachFloatVectorListener(
|
||||||
|
QUALITY_SCORES_STREAM, (scores, timestamp) => {
|
||||||
|
this.qualityScores = scores;
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
|
QUALITY_SCORES_STREAM, timestamp => {
|
||||||
|
this.qualityScores = undefined;  // Empty QUALITY_SCORES packet: clear the scores, not the category mask.
|
||||||
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
|
});
|
||||||
|
|
||||||
const binaryGraph = graphConfig.serializeBinary();
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,13 @@ export class InteractiveSegmenterResult {
|
||||||
* `WebGLTexture`-backed `MPMask` where each pixel represents the class
|
* `WebGLTexture`-backed `MPMask` where each pixel represents the class
|
||||||
* which the pixel in the original image was predicted to belong to.
|
* which the pixel in the original image was predicted to belong to.
|
||||||
*/
|
*/
|
||||||
readonly categoryMask?: MPMask) {}
|
readonly categoryMask?: MPMask,
|
||||||
|
/**
|
||||||
|
* The quality scores of the result masks, in the range of [0, 1].
|
||||||
|
* Defaults to `1` if the model doesn't output quality scores. Each
|
||||||
|
* element corresponds to the score of the category in the model outputs.
|
||||||
|
*/
|
||||||
|
readonly qualityScores?: number[]) {}
|
||||||
|
|
||||||
/** Frees the resources held by the category and confidence masks. */
|
/** Frees the resources held by the category and confidence masks. */
|
||||||
close(): void {
|
close(): void {
|
||||||
|
|
|
@ -46,6 +46,8 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
||||||
((images: WasmImage, timestamp: number) => void)|undefined;
|
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||||
confidenceMasksListener:
|
confidenceMasksListener:
|
||||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||||
|
qualityScoresListener:
|
||||||
|
((data: number[], timestamp: number) => void)|undefined;
|
||||||
lastRoi?: RenderDataProto;
|
lastRoi?: RenderDataProto;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
|
@ -64,6 +66,12 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
||||||
expect(stream).toEqual('confidence_masks');
|
expect(stream).toEqual('confidence_masks');
|
||||||
this.confidenceMasksListener = listener;
|
this.confidenceMasksListener = listener;
|
||||||
});
|
});
|
||||||
|
this.attachListenerSpies[2] =
|
||||||
|
spyOn(this.graphRunner, 'attachFloatVectorListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('quality_scores');
|
||||||
|
this.qualityScoresListener = listener;
|
||||||
|
});
|
||||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
});
|
});
|
||||||
|
@ -277,9 +285,10 @@ describe('InteractiveSegmenter', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('invokes listener after masks are avaiblae', async () => {
|
it('invokes listener after masks are available', async () => {
|
||||||
const categoryMask = new Uint8Array([1]);
|
const categoryMask = new Uint8Array([1]);
|
||||||
const confidenceMask = new Float32Array([0.0]);
|
const confidenceMask = new Float32Array([0.0]);
|
||||||
|
const qualityScores = [1.0];
|
||||||
let listenerCalled = false;
|
let listenerCalled = false;
|
||||||
|
|
||||||
await interactiveSegmenter.setOptions(
|
await interactiveSegmenter.setOptions(
|
||||||
|
@ -297,11 +306,16 @@ describe('InteractiveSegmenter', () => {
|
||||||
],
|
],
|
||||||
1337);
|
1337);
|
||||||
expect(listenerCalled).toBeFalse();
|
expect(listenerCalled).toBeFalse();
|
||||||
|
interactiveSegmenter.qualityScoresListener!(qualityScores, 1337);
|
||||||
|
expect(listenerCalled).toBeFalse();
|
||||||
});
|
});
|
||||||
|
|
||||||
return new Promise<void>(resolve => {
|
return new Promise<void>(resolve => {
|
||||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, () => {
|
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT, result => {
|
||||||
listenerCalled = true;
|
listenerCalled = true;
|
||||||
|
expect(result.categoryMask).toBeInstanceOf(MPMask);
|
||||||
|
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||||
|
expect(result.qualityScores).toEqual(qualityScores);
|
||||||
resolve();
|
resolve();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user