Add .close() method to ImageSegmenterResult/InteractiveSegmenterResult/PoseLandmarkerResult

PiperOrigin-RevId: 530973944
This commit is contained in:
Sebastian Schmidt 2023-05-10 12:17:50 -07:00 committed by Copybara-Service
parent a7ede9235c
commit 1666f3ed80
9 changed files with 133 additions and 138 deletions

View File

@ -22,6 +22,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {LabelMapItem} from '../../../../util/label_map_pb'; import {LabelMapItem} from '../../../../util/label_map_pb';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
@ -58,7 +59,8 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
/** Performs image segmentation on images. */ /** Performs image segmentation on images. */
export class ImageSegmenter extends VisionTaskRunner { export class ImageSegmenter extends VisionTaskRunner {
private result: ImageSegmenterResult = {}; private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private labels: string[] = []; private labels: string[] = [];
private userCallback?: ImageSegmenterCallback; private userCallback?: ImageSegmenterCallback;
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
@ -265,10 +267,7 @@ export class ImageSegmenter extends VisionTaskRunner {
this.reset(); this.reset();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
return this.processResults();
if (!this.userCallback) {
return this.result;
}
} }
/** /**
@ -347,10 +346,7 @@ export class ImageSegmenter extends VisionTaskRunner {
this.reset(); this.reset();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp); this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
return this.processResults();
if (!this.userCallback) {
return this.result;
}
} }
/** /**
@ -369,21 +365,20 @@ export class ImageSegmenter extends VisionTaskRunner {
} }
private reset(): void { private reset(): void {
this.result = {}; this.categoryMask = undefined;
} this.confidenceMasks = undefined;
/** Invokes the user callback once all data has been received. */
private maybeInvokeCallback(): void {
if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) {
return;
}
if (this.outputCategoryMask && !('categoryMask' in this.result)) {
return;
} }
private processResults(): ImageSegmenterResult|void {
try {
const result =
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(this.result); this.userCallback(result);
} else {
return result;
}
} finally {
// Free the image memory, now that we've kept all streams alive long // Free the image memory, now that we've kept all streams alive long
// enough to be returned in our callbacks. // enough to be returned in our callbacks.
this.freeKeepaliveStreams(); this.freeKeepaliveStreams();
@ -417,17 +412,15 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map( this.confidenceMasks = masks.map(
wasmImage => this.convertToMPMask( wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback)); wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => { CONFIDENCE_MASKS_STREAM, timestamp => {
this.result.confidenceMasks = undefined; this.confidenceMasks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }
@ -438,16 +431,14 @@ export class ImageSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPMask( this.categoryMask = this.convertToMPMask(
mask, /* shouldCopyData= */ !this.userCallback); mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => { CATEGORY_MASK_STREAM, timestamp => {
this.result.categoryMask = undefined; this.categoryMask = undefined;
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }

View File

@ -17,18 +17,26 @@
import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
/** The output result of ImageSegmenter. */ /** The output result of ImageSegmenter. */
export declare interface ImageSegmenterResult { export class ImageSegmenterResult {
constructor(
/** /**
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
* `MPImage`s where, for each mask, each pixel represents the prediction * `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range. * confidence, usually in the [0, 1] range.
*/ */
confidenceMasks?: MPMask[]; readonly confidenceMasks?: MPMask[],
/** /**
* A category mask represented as a `Uint8ClampedArray` or * A category mask represented as a `Uint8ClampedArray` or
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which * `WebGLTexture`-backed `MPImage` where each pixel represents the class
* the pixel in the original image was predicted to belong to. * which the pixel in the original image was predicted to belong to.
*/ */
categoryMask?: MPMask; readonly categoryMask?: MPMask) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {
this.confidenceMasks?.forEach(m => {
m.close();
});
this.categoryMask?.close();
}
} }

View File

@ -263,7 +263,7 @@ describe('ImageSegmenter', () => {
}); });
}); });
it('invokes listener once masks are available', async () => { it('invokes listener after masks are available', async () => {
const categoryMask = new Uint8Array([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false; let listenerCalled = false;
@ -282,7 +282,7 @@ describe('ImageSegmenter', () => {
{data: confidenceMask, width: 1, height: 1}, {data: confidenceMask, width: 1, height: 1},
], ],
1337); 1337);
expect(listenerCalled).toBeTrue(); expect(listenerCalled).toBeFalse();
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
@ -307,6 +307,6 @@ describe('ImageSegmenter', () => {
const result = imageSegmenter.segment({} as HTMLImageElement); const result = imageSegmenter.segment({} as HTMLImageElement);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
result.confidenceMasks![0].close(); result.close();
}); });
}); });

View File

@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {RegionOfInterest} from '../../../../tasks/web/vision/core/types'; import {RegionOfInterest} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {Color as ColorProto} from '../../../../util/color_pb'; import {Color as ColorProto} from '../../../../util/color_pb';
@ -83,7 +84,8 @@ export type InteractiveSegmenterCallback =
* - batch is always 1 * - batch is always 1
*/ */
export class InteractiveSegmenter extends VisionTaskRunner { export class InteractiveSegmenter extends VisionTaskRunner {
private result: InteractiveSegmenterResult = {}; private categoryMask?: MPMask;
private confidenceMasks?: MPMask[];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private userCallback?: InteractiveSegmenterCallback; private userCallback?: InteractiveSegmenterCallback;
@ -276,28 +278,24 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.reset(); this.reset();
this.processRenderData(roi, this.getSynctheticTimestamp()); this.processRenderData(roi, this.getSynctheticTimestamp());
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
return this.processResults();
if (!this.userCallback) {
return this.result;
}
} }
private reset(): void { private reset(): void {
this.result = {}; this.confidenceMasks = undefined;
} this.categoryMask = undefined;
/** Invokes the user callback once all data has been received. */
private maybeInvokeCallback(): void {
if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) {
return;
}
if (this.outputCategoryMask && !('categoryMask' in this.result)) {
return;
} }
private processResults(): InteractiveSegmenterResult|void {
try {
const result = new InteractiveSegmenterResult(
this.confidenceMasks, this.categoryMask);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(this.result); this.userCallback(result);
} else {
return result;
}
} finally {
// Free the image memory, now that we've kept all streams alive long // Free the image memory, now that we've kept all streams alive long
// enough to be returned in our callbacks. // enough to be returned in our callbacks.
this.freeKeepaliveStreams(); this.freeKeepaliveStreams();
@ -333,17 +331,15 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
this.result.confidenceMasks = masks.map( this.confidenceMasks = masks.map(
wasmImage => this.convertToMPMask( wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback)); wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => { CONFIDENCE_MASKS_STREAM, timestamp => {
this.result.confidenceMasks = undefined; this.confidenceMasks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }
@ -354,16 +350,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => { CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = this.convertToMPMask( this.categoryMask = this.convertToMPMask(
mask, /* shouldCopyData= */ !this.userCallback); mask, /* shouldCopyData= */ !this.userCallback);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => { CATEGORY_MASK_STREAM, timestamp => {
this.result.categoryMask = undefined; this.categoryMask = undefined;
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }

View File

@ -17,18 +17,26 @@
import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
/** The output result of InteractiveSegmenter. */ /** The output result of InteractiveSegmenter. */
export declare interface InteractiveSegmenterResult { export class InteractiveSegmenterResult {
constructor(
/** /**
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed * Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
* `MPImage`s where, for each mask, each pixel represents the prediction * `MPImage`s where, for each mask, each pixel represents the prediction
* confidence, usually in the [0, 1] range. * confidence, usually in the [0, 1] range.
*/ */
confidenceMasks?: MPMask[]; readonly confidenceMasks?: MPMask[],
/** /**
* A category mask represented as a `Uint8ClampedArray` or * A category mask represented as a `Uint8ClampedArray` or
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which * `WebGLTexture`-backed `MPImage` where each pixel represents the class
* the pixel in the original image was predicted to belong to. * which the pixel in the original image was predicted to belong to.
*/ */
categoryMask?: MPMask; readonly categoryMask?: MPMask) {}
/** Frees the resources held by the category and confidence masks. */
close(): void {
this.confidenceMasks?.forEach(m => {
m.close();
});
this.categoryMask?.close();
}
} }

View File

@ -277,7 +277,7 @@ describe('InteractiveSegmenter', () => {
}); });
}); });
it('invokes listener once masks are avaiblae', async () => { it('invokes listener after masks are avaiblae', async () => {
const categoryMask = new Uint8Array([1]); const categoryMask = new Uint8Array([1]);
const confidenceMask = new Float32Array([0.0]); const confidenceMask = new Float32Array([0.0]);
let listenerCalled = false; let listenerCalled = false;
@ -296,7 +296,7 @@ describe('InteractiveSegmenter', () => {
{data: confidenceMask, width: 1, height: 1}, {data: confidenceMask, width: 1, height: 1},
], ],
1337); 1337);
expect(listenerCalled).toBeTrue(); expect(listenerCalled).toBeFalse();
}); });
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
@ -322,6 +322,6 @@ describe('InteractiveSegmenter', () => {
const result = const result =
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT); interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT);
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask); expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
result.confidenceMasks![0].close(); result.close();
}); });
}); });

View File

@ -21,9 +21,11 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
import {PoseDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb'; import {PoseDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb';
import {PoseLandmarkerGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options_pb'; import {PoseLandmarkerGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options_pb';
import {PoseLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb'; import {PoseLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result'; import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {Connection} from '../../../../tasks/web/vision/core/types'; import {Connection} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
@ -61,7 +63,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
/** Performs pose landmarks detection on images. */ /** Performs pose landmarks detection on images. */
export class PoseLandmarker extends VisionTaskRunner { export class PoseLandmarker extends VisionTaskRunner {
private result: Partial<PoseLandmarkerResult> = {}; private landmarks: NormalizedLandmark[][] = [];
private worldLandmarks: Landmark[][] = [];
private segmentationMasks?: MPMask[];
private outputSegmentationMasks = false; private outputSegmentationMasks = false;
private userCallback?: PoseLandmarkerCallback; private userCallback?: PoseLandmarkerCallback;
private readonly options: PoseLandmarkerGraphOptions; private readonly options: PoseLandmarkerGraphOptions;
@ -268,10 +272,7 @@ export class PoseLandmarker extends VisionTaskRunner {
this.resetResults(); this.resetResults();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
return this.processResults();
if (!this.userCallback) {
return this.result as PoseLandmarkerResult;
}
} }
/** /**
@ -352,31 +353,25 @@ export class PoseLandmarker extends VisionTaskRunner {
this.resetResults(); this.resetResults();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp); this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
return this.processResults();
if (!this.userCallback) {
return this.result as PoseLandmarkerResult;
}
} }
private resetResults(): void { private resetResults(): void {
this.result = {}; this.landmarks = [];
} this.worldLandmarks = [];
this.segmentationMasks = undefined;
/** Invokes the user callback once all data has been received. */
private maybeInvokeCallback(): void {
if (!('landmarks' in this.result)) {
return;
}
if (!('worldLandmarks' in this.result)) {
return;
}
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
return;
} }
private processResults(): PoseLandmarkerResult|void {
try {
const result = new PoseLandmarkerResult(
this.landmarks, this.worldLandmarks, this.segmentationMasks);
if (this.userCallback) { if (this.userCallback) {
this.userCallback(this.result as Required<PoseLandmarkerResult>); this.userCallback(result);
} else {
return result;
}
} finally {
// Free the image memory, now that we've finished our callback. // Free the image memory, now that we've finished our callback.
this.freeKeepaliveStreams(); this.freeKeepaliveStreams();
} }
@ -396,11 +391,11 @@ export class PoseLandmarker extends VisionTaskRunner {
* Converts raw data into a landmark, and adds it to our landmarks list. * Converts raw data into a landmark, and adds it to our landmarks list.
*/ */
private addJsLandmarks(data: Uint8Array[]): void { private addJsLandmarks(data: Uint8Array[]): void {
this.result.landmarks = []; this.landmarks = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const poseLandmarksProto = const poseLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto); NormalizedLandmarkList.deserializeBinary(binaryProto);
this.result.landmarks.push(convertToLandmarks(poseLandmarksProto)); this.landmarks.push(convertToLandmarks(poseLandmarksProto));
} }
} }
@ -409,11 +404,11 @@ export class PoseLandmarker extends VisionTaskRunner {
* worldLandmarks list. * worldLandmarks list.
*/ */
private adddJsWorldLandmarks(data: Uint8Array[]): void { private adddJsWorldLandmarks(data: Uint8Array[]): void {
this.result.worldLandmarks = []; this.worldLandmarks = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const poseWorldLandmarksProto = const poseWorldLandmarksProto =
LandmarkList.deserializeBinary(binaryProto); LandmarkList.deserializeBinary(binaryProto);
this.result.worldLandmarks.push( this.worldLandmarks.push(
convertToWorldLandmarks(poseWorldLandmarksProto)); convertToWorldLandmarks(poseWorldLandmarksProto));
} }
} }
@ -448,26 +443,22 @@ export class PoseLandmarker extends VisionTaskRunner {
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => { NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto); this.addJsLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
NORM_LANDMARKS_STREAM, timestamp => { NORM_LANDMARKS_STREAM, timestamp => {
this.result.landmarks = []; this.landmarks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => { WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.adddJsWorldLandmarks(binaryProto); this.adddJsWorldLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
WORLD_LANDMARKS_STREAM, timestamp => { WORLD_LANDMARKS_STREAM, timestamp => {
this.result.worldLandmarks = []; this.worldLandmarks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
if (this.outputSegmentationMasks) { if (this.outputSegmentationMasks) {
@ -477,17 +468,15 @@ export class PoseLandmarker extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
SEGMENTATION_MASK_STREAM, (masks, timestamp) => { SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
this.result.segmentationMasks = masks.map( this.segmentationMasks = masks.map(
wasmImage => this.convertToMPMask( wasmImage => this.convertToMPMask(
wasmImage, /* shouldCopyData= */ !this.userCallback)); wasmImage, /* shouldCopyData= */ !this.userCallback));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(
SEGMENTATION_MASK_STREAM, timestamp => { SEGMENTATION_MASK_STREAM, timestamp => {
this.result.segmentationMasks = []; this.segmentationMasks = [];
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
}); });
} }

View File

@ -24,13 +24,18 @@ export {Category, Landmark, NormalizedLandmark};
* Represents the pose landmarks deection results generated by `PoseLandmarker`. * Represents the pose landmarks deection results generated by `PoseLandmarker`.
* Each vector element represents a single pose detected in the image. * Each vector element represents a single pose detected in the image.
*/ */
export declare interface PoseLandmarkerResult { export class PoseLandmarkerResult {
/** Pose landmarks of detected poses. */ constructor(/** Pose landmarks of detected poses. */
landmarks: NormalizedLandmark[][]; readonly landmarks: NormalizedLandmark[][],
/** Pose landmarks in world coordinates of detected poses. */ /** Pose landmarks in world coordinates of detected poses. */
worldLandmarks: Landmark[][]; readonly worldLandmarks: Landmark[][],
/** Segmentation mask for the detected pose. */ /** Segmentation mask for the detected pose. */
segmentationMasks?: MPMask[]; readonly segmentationMasks?: MPMask[]) {}
/** Frees the resources held by the segmentation masks. */
close(): void {
this.segmentationMasks?.forEach(m => {
m.close();
});
}
} }

View File

@ -287,7 +287,7 @@ describe('PoseLandmarker', () => {
}); });
}); });
it('invokes listener once masks are available', (done) => { it('invokes listener after masks are available', (done) => {
const landmarksProto = [createLandmarks().serializeBinary()]; const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const masks = [ const masks = [
@ -309,13 +309,12 @@ describe('PoseLandmarker', () => {
expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse();
expect(listenerCalled).toBeFalse(); expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337); poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
expect(listenerCalled).toBeTrue();
done();
}); });
// Invoke the pose landmarker // Invoke the pose landmarker
poseLandmarker.detect({} as HTMLImageElement, () => { poseLandmarker.detect({} as HTMLImageElement, () => {
listenerCalled = true; listenerCalled = true;
done();
}); });
}); });
@ -336,5 +335,6 @@ describe('PoseLandmarker', () => {
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]); expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
result.close();
}); });
}); });