Add .close() method to ImageSegmenterResult/InteractiveSegmenterResult/PoseLandmarkerResult
PiperOrigin-RevId: 530973944
This commit is contained in:
parent
a7ede9235c
commit
1666f3ed80
|
@ -22,6 +22,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
|
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {LabelMapItem} from '../../../../util/label_map_pb';
|
import {LabelMapItem} from '../../../../util/label_map_pb';
|
||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
@ -58,7 +59,8 @@ export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void;
|
||||||
|
|
||||||
/** Performs image segmentation on images. */
|
/** Performs image segmentation on images. */
|
||||||
export class ImageSegmenter extends VisionTaskRunner {
|
export class ImageSegmenter extends VisionTaskRunner {
|
||||||
private result: ImageSegmenterResult = {};
|
private categoryMask?: MPMask;
|
||||||
|
private confidenceMasks?: MPMask[];
|
||||||
private labels: string[] = [];
|
private labels: string[] = [];
|
||||||
private userCallback?: ImageSegmenterCallback;
|
private userCallback?: ImageSegmenterCallback;
|
||||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
|
@ -265,10 +267,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.reset();
|
this.reset();
|
||||||
this.processImageData(image, imageProcessingOptions);
|
this.processImageData(image, imageProcessingOptions);
|
||||||
|
return this.processResults();
|
||||||
if (!this.userCallback) {
|
|
||||||
return this.result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -347,10 +346,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.reset();
|
this.reset();
|
||||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||||
|
return this.processResults();
|
||||||
if (!this.userCallback) {
|
|
||||||
return this.result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -369,21 +365,20 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
}
|
}
|
||||||
|
|
||||||
private reset(): void {
|
private reset(): void {
|
||||||
this.result = {};
|
this.categoryMask = undefined;
|
||||||
}
|
this.confidenceMasks = undefined;
|
||||||
|
|
||||||
/** Invokes the user callback once all data has been received. */
|
|
||||||
private maybeInvokeCallback(): void {
|
|
||||||
if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (this.outputCategoryMask && !('categoryMask' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private processResults(): ImageSegmenterResult|void {
|
||||||
|
try {
|
||||||
|
const result =
|
||||||
|
new ImageSegmenterResult(this.confidenceMasks, this.categoryMask);
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(this.result);
|
this.userCallback(result);
|
||||||
|
} else {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
// Free the image memory, now that we've kept all streams alive long
|
// Free the image memory, now that we've kept all streams alive long
|
||||||
// enough to be returned in our callbacks.
|
// enough to be returned in our callbacks.
|
||||||
this.freeKeepaliveStreams();
|
this.freeKeepaliveStreams();
|
||||||
|
@ -417,17 +412,15 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||||
this.result.confidenceMasks = masks.map(
|
this.confidenceMasks = masks.map(
|
||||||
wasmImage => this.convertToMPMask(
|
wasmImage => this.convertToMPMask(
|
||||||
wasmImage, /* shouldCopyData= */ !this.userCallback));
|
wasmImage, /* shouldCopyData= */ !this.userCallback));
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
CONFIDENCE_MASKS_STREAM, timestamp => {
|
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||||
this.result.confidenceMasks = undefined;
|
this.confidenceMasks = [];
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -438,16 +431,14 @@ export class ImageSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageListener(
|
this.graphRunner.attachImageListener(
|
||||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
this.result.categoryMask = this.convertToMPMask(
|
this.categoryMask = this.convertToMPMask(
|
||||||
mask, /* shouldCopyData= */ !this.userCallback);
|
mask, /* shouldCopyData= */ !this.userCallback);
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
CATEGORY_MASK_STREAM, timestamp => {
|
CATEGORY_MASK_STREAM, timestamp => {
|
||||||
this.result.categoryMask = undefined;
|
this.categoryMask = undefined;
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,18 +17,26 @@
|
||||||
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||||
|
|
||||||
/** The output result of ImageSegmenter. */
|
/** The output result of ImageSegmenter. */
|
||||||
export declare interface ImageSegmenterResult {
|
export class ImageSegmenterResult {
|
||||||
|
constructor(
|
||||||
/**
|
/**
|
||||||
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
|
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
|
||||||
* `MPImage`s where, for each mask, each pixel represents the prediction
|
* `MPImage`s where, for each mask, each pixel represents the prediction
|
||||||
* confidence, usually in the [0, 1] range.
|
* confidence, usually in the [0, 1] range.
|
||||||
*/
|
*/
|
||||||
confidenceMasks?: MPMask[];
|
readonly confidenceMasks?: MPMask[],
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A category mask represented as a `Uint8ClampedArray` or
|
* A category mask represented as a `Uint8ClampedArray` or
|
||||||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which
|
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||||
* the pixel in the original image was predicted to belong to.
|
* which the pixel in the original image was predicted to belong to.
|
||||||
*/
|
*/
|
||||||
categoryMask?: MPMask;
|
readonly categoryMask?: MPMask) {}
|
||||||
|
|
||||||
|
/** Frees the resources held by the category and confidence masks. */
|
||||||
|
close(): void {
|
||||||
|
this.confidenceMasks?.forEach(m => {
|
||||||
|
m.close();
|
||||||
|
});
|
||||||
|
this.categoryMask?.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -263,7 +263,7 @@ describe('ImageSegmenter', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('invokes listener once masks are available', async () => {
|
it('invokes listener after masks are available', async () => {
|
||||||
const categoryMask = new Uint8Array([1]);
|
const categoryMask = new Uint8Array([1]);
|
||||||
const confidenceMask = new Float32Array([0.0]);
|
const confidenceMask = new Float32Array([0.0]);
|
||||||
let listenerCalled = false;
|
let listenerCalled = false;
|
||||||
|
@ -282,7 +282,7 @@ describe('ImageSegmenter', () => {
|
||||||
{data: confidenceMask, width: 1, height: 1},
|
{data: confidenceMask, width: 1, height: 1},
|
||||||
],
|
],
|
||||||
1337);
|
1337);
|
||||||
expect(listenerCalled).toBeTrue();
|
expect(listenerCalled).toBeFalse();
|
||||||
});
|
});
|
||||||
|
|
||||||
return new Promise<void>(resolve => {
|
return new Promise<void>(resolve => {
|
||||||
|
@ -307,6 +307,6 @@ describe('ImageSegmenter', () => {
|
||||||
|
|
||||||
const result = imageSegmenter.segment({} as HTMLImageElement);
|
const result = imageSegmenter.segment({} as HTMLImageElement);
|
||||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||||
result.confidenceMasks![0].close();
|
result.close();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
|
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||||
import {RegionOfInterest} from '../../../../tasks/web/vision/core/types';
|
import {RegionOfInterest} from '../../../../tasks/web/vision/core/types';
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {Color as ColorProto} from '../../../../util/color_pb';
|
import {Color as ColorProto} from '../../../../util/color_pb';
|
||||||
|
@ -83,7 +84,8 @@ export type InteractiveSegmenterCallback =
|
||||||
* - batch is always 1
|
* - batch is always 1
|
||||||
*/
|
*/
|
||||||
export class InteractiveSegmenter extends VisionTaskRunner {
|
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
private result: InteractiveSegmenterResult = {};
|
private categoryMask?: MPMask;
|
||||||
|
private confidenceMasks?: MPMask[];
|
||||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||||
private userCallback?: InteractiveSegmenterCallback;
|
private userCallback?: InteractiveSegmenterCallback;
|
||||||
|
@ -276,28 +278,24 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
this.reset();
|
this.reset();
|
||||||
this.processRenderData(roi, this.getSynctheticTimestamp());
|
this.processRenderData(roi, this.getSynctheticTimestamp());
|
||||||
this.processImageData(image, imageProcessingOptions);
|
this.processImageData(image, imageProcessingOptions);
|
||||||
|
return this.processResults();
|
||||||
if (!this.userCallback) {
|
|
||||||
return this.result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private reset(): void {
|
private reset(): void {
|
||||||
this.result = {};
|
this.confidenceMasks = undefined;
|
||||||
}
|
this.categoryMask = undefined;
|
||||||
|
|
||||||
/** Invokes the user callback once all data has been received. */
|
|
||||||
private maybeInvokeCallback(): void {
|
|
||||||
if (this.outputConfidenceMasks && !('confidenceMasks' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (this.outputCategoryMask && !('categoryMask' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private processResults(): InteractiveSegmenterResult|void {
|
||||||
|
try {
|
||||||
|
const result = new InteractiveSegmenterResult(
|
||||||
|
this.confidenceMasks, this.categoryMask);
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(this.result);
|
this.userCallback(result);
|
||||||
|
} else {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
// Free the image memory, now that we've kept all streams alive long
|
// Free the image memory, now that we've kept all streams alive long
|
||||||
// enough to be returned in our callbacks.
|
// enough to be returned in our callbacks.
|
||||||
this.freeKeepaliveStreams();
|
this.freeKeepaliveStreams();
|
||||||
|
@ -333,17 +331,15 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||||
this.result.confidenceMasks = masks.map(
|
this.confidenceMasks = masks.map(
|
||||||
wasmImage => this.convertToMPMask(
|
wasmImage => this.convertToMPMask(
|
||||||
wasmImage, /* shouldCopyData= */ !this.userCallback));
|
wasmImage, /* shouldCopyData= */ !this.userCallback));
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
CONFIDENCE_MASKS_STREAM, timestamp => {
|
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||||
this.result.confidenceMasks = undefined;
|
this.confidenceMasks = [];
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -354,16 +350,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageListener(
|
this.graphRunner.attachImageListener(
|
||||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||||
this.result.categoryMask = this.convertToMPMask(
|
this.categoryMask = this.convertToMPMask(
|
||||||
mask, /* shouldCopyData= */ !this.userCallback);
|
mask, /* shouldCopyData= */ !this.userCallback);
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
CATEGORY_MASK_STREAM, timestamp => {
|
CATEGORY_MASK_STREAM, timestamp => {
|
||||||
this.result.categoryMask = undefined;
|
this.categoryMask = undefined;
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,18 +17,26 @@
|
||||||
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||||
|
|
||||||
/** The output result of InteractiveSegmenter. */
|
/** The output result of InteractiveSegmenter. */
|
||||||
export declare interface InteractiveSegmenterResult {
|
export class InteractiveSegmenterResult {
|
||||||
|
constructor(
|
||||||
/**
|
/**
|
||||||
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
|
* Multiple masks represented as `Float32Array` or `WebGLTexture`-backed
|
||||||
* `MPImage`s where, for each mask, each pixel represents the prediction
|
* `MPImage`s where, for each mask, each pixel represents the prediction
|
||||||
* confidence, usually in the [0, 1] range.
|
* confidence, usually in the [0, 1] range.
|
||||||
*/
|
*/
|
||||||
confidenceMasks?: MPMask[];
|
readonly confidenceMasks?: MPMask[],
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A category mask represented as a `Uint8ClampedArray` or
|
* A category mask represented as a `Uint8ClampedArray` or
|
||||||
* `WebGLTexture`-backed `MPImage` where each pixel represents the class which
|
* `WebGLTexture`-backed `MPImage` where each pixel represents the class
|
||||||
* the pixel in the original image was predicted to belong to.
|
* which the pixel in the original image was predicted to belong to.
|
||||||
*/
|
*/
|
||||||
categoryMask?: MPMask;
|
readonly categoryMask?: MPMask) {}
|
||||||
|
|
||||||
|
/** Frees the resources held by the category and confidence masks. */
|
||||||
|
close(): void {
|
||||||
|
this.confidenceMasks?.forEach(m => {
|
||||||
|
m.close();
|
||||||
|
});
|
||||||
|
this.categoryMask?.close();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -277,7 +277,7 @@ describe('InteractiveSegmenter', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('invokes listener once masks are avaiblae', async () => {
|
it('invokes listener after masks are avaiblae', async () => {
|
||||||
const categoryMask = new Uint8Array([1]);
|
const categoryMask = new Uint8Array([1]);
|
||||||
const confidenceMask = new Float32Array([0.0]);
|
const confidenceMask = new Float32Array([0.0]);
|
||||||
let listenerCalled = false;
|
let listenerCalled = false;
|
||||||
|
@ -296,7 +296,7 @@ describe('InteractiveSegmenter', () => {
|
||||||
{data: confidenceMask, width: 1, height: 1},
|
{data: confidenceMask, width: 1, height: 1},
|
||||||
],
|
],
|
||||||
1337);
|
1337);
|
||||||
expect(listenerCalled).toBeTrue();
|
expect(listenerCalled).toBeFalse();
|
||||||
});
|
});
|
||||||
|
|
||||||
return new Promise<void>(resolve => {
|
return new Promise<void>(resolve => {
|
||||||
|
@ -322,6 +322,6 @@ describe('InteractiveSegmenter', () => {
|
||||||
const result =
|
const result =
|
||||||
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT);
|
interactiveSegmenter.segment({} as HTMLImageElement, KEYPOINT);
|
||||||
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
expect(result.confidenceMasks![0]).toBeInstanceOf(MPMask);
|
||||||
result.confidenceMasks![0].close();
|
result.close();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -21,9 +21,11 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b
|
||||||
import {PoseDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb';
|
import {PoseDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb';
|
||||||
import {PoseLandmarkerGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options_pb';
|
import {PoseLandmarkerGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarker_graph_options_pb';
|
||||||
import {PoseLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb';
|
import {PoseLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb';
|
||||||
|
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
|
||||||
import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
|
import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
|
import {MPMask} from '../../../../tasks/web/vision/core/mask';
|
||||||
import {Connection} from '../../../../tasks/web/vision/core/types';
|
import {Connection} from '../../../../tasks/web/vision/core/types';
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
|
@ -61,7 +63,9 @@ export type PoseLandmarkerCallback = (result: PoseLandmarkerResult) => void;
|
||||||
|
|
||||||
/** Performs pose landmarks detection on images. */
|
/** Performs pose landmarks detection on images. */
|
||||||
export class PoseLandmarker extends VisionTaskRunner {
|
export class PoseLandmarker extends VisionTaskRunner {
|
||||||
private result: Partial<PoseLandmarkerResult> = {};
|
private landmarks: NormalizedLandmark[][] = [];
|
||||||
|
private worldLandmarks: Landmark[][] = [];
|
||||||
|
private segmentationMasks?: MPMask[];
|
||||||
private outputSegmentationMasks = false;
|
private outputSegmentationMasks = false;
|
||||||
private userCallback?: PoseLandmarkerCallback;
|
private userCallback?: PoseLandmarkerCallback;
|
||||||
private readonly options: PoseLandmarkerGraphOptions;
|
private readonly options: PoseLandmarkerGraphOptions;
|
||||||
|
@ -268,10 +272,7 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
|
|
||||||
this.resetResults();
|
this.resetResults();
|
||||||
this.processImageData(image, imageProcessingOptions);
|
this.processImageData(image, imageProcessingOptions);
|
||||||
|
return this.processResults();
|
||||||
if (!this.userCallback) {
|
|
||||||
return this.result as PoseLandmarkerResult;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -352,31 +353,25 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
|
|
||||||
this.resetResults();
|
this.resetResults();
|
||||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||||
|
return this.processResults();
|
||||||
if (!this.userCallback) {
|
|
||||||
return this.result as PoseLandmarkerResult;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private resetResults(): void {
|
private resetResults(): void {
|
||||||
this.result = {};
|
this.landmarks = [];
|
||||||
}
|
this.worldLandmarks = [];
|
||||||
|
this.segmentationMasks = undefined;
|
||||||
/** Invokes the user callback once all data has been received. */
|
|
||||||
private maybeInvokeCallback(): void {
|
|
||||||
if (!('landmarks' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!('worldLandmarks' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private processResults(): PoseLandmarkerResult|void {
|
||||||
|
try {
|
||||||
|
const result = new PoseLandmarkerResult(
|
||||||
|
this.landmarks, this.worldLandmarks, this.segmentationMasks);
|
||||||
if (this.userCallback) {
|
if (this.userCallback) {
|
||||||
this.userCallback(this.result as Required<PoseLandmarkerResult>);
|
this.userCallback(result);
|
||||||
|
} else {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
// Free the image memory, now that we've finished our callback.
|
// Free the image memory, now that we've finished our callback.
|
||||||
this.freeKeepaliveStreams();
|
this.freeKeepaliveStreams();
|
||||||
}
|
}
|
||||||
|
@ -396,11 +391,11 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
* Converts raw data into a landmark, and adds it to our landmarks list.
|
* Converts raw data into a landmark, and adds it to our landmarks list.
|
||||||
*/
|
*/
|
||||||
private addJsLandmarks(data: Uint8Array[]): void {
|
private addJsLandmarks(data: Uint8Array[]): void {
|
||||||
this.result.landmarks = [];
|
this.landmarks = [];
|
||||||
for (const binaryProto of data) {
|
for (const binaryProto of data) {
|
||||||
const poseLandmarksProto =
|
const poseLandmarksProto =
|
||||||
NormalizedLandmarkList.deserializeBinary(binaryProto);
|
NormalizedLandmarkList.deserializeBinary(binaryProto);
|
||||||
this.result.landmarks.push(convertToLandmarks(poseLandmarksProto));
|
this.landmarks.push(convertToLandmarks(poseLandmarksProto));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -409,11 +404,11 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
* worldLandmarks list.
|
* worldLandmarks list.
|
||||||
*/
|
*/
|
||||||
private adddJsWorldLandmarks(data: Uint8Array[]): void {
|
private adddJsWorldLandmarks(data: Uint8Array[]): void {
|
||||||
this.result.worldLandmarks = [];
|
this.worldLandmarks = [];
|
||||||
for (const binaryProto of data) {
|
for (const binaryProto of data) {
|
||||||
const poseWorldLandmarksProto =
|
const poseWorldLandmarksProto =
|
||||||
LandmarkList.deserializeBinary(binaryProto);
|
LandmarkList.deserializeBinary(binaryProto);
|
||||||
this.result.worldLandmarks.push(
|
this.worldLandmarks.push(
|
||||||
convertToWorldLandmarks(poseWorldLandmarksProto));
|
convertToWorldLandmarks(poseWorldLandmarksProto));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -448,26 +443,22 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
NORM_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||||
this.addJsLandmarks(binaryProto);
|
this.addJsLandmarks(binaryProto);
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
NORM_LANDMARKS_STREAM, timestamp => {
|
NORM_LANDMARKS_STREAM, timestamp => {
|
||||||
this.result.landmarks = [];
|
this.landmarks = [];
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
this.graphRunner.attachProtoVectorListener(
|
this.graphRunner.attachProtoVectorListener(
|
||||||
WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||||
this.adddJsWorldLandmarks(binaryProto);
|
this.adddJsWorldLandmarks(binaryProto);
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
WORLD_LANDMARKS_STREAM, timestamp => {
|
WORLD_LANDMARKS_STREAM, timestamp => {
|
||||||
this.result.worldLandmarks = [];
|
this.worldLandmarks = [];
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if (this.outputSegmentationMasks) {
|
if (this.outputSegmentationMasks) {
|
||||||
|
@ -477,17 +468,15 @@ export class PoseLandmarker extends VisionTaskRunner {
|
||||||
|
|
||||||
this.graphRunner.attachImageVectorListener(
|
this.graphRunner.attachImageVectorListener(
|
||||||
SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
|
SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
|
||||||
this.result.segmentationMasks = masks.map(
|
this.segmentationMasks = masks.map(
|
||||||
wasmImage => this.convertToMPMask(
|
wasmImage => this.convertToMPMask(
|
||||||
wasmImage, /* shouldCopyData= */ !this.userCallback));
|
wasmImage, /* shouldCopyData= */ !this.userCallback));
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
this.graphRunner.attachEmptyPacketListener(
|
this.graphRunner.attachEmptyPacketListener(
|
||||||
SEGMENTATION_MASK_STREAM, timestamp => {
|
SEGMENTATION_MASK_STREAM, timestamp => {
|
||||||
this.result.segmentationMasks = [];
|
this.segmentationMasks = [];
|
||||||
this.setLatestOutputTimestamp(timestamp);
|
this.setLatestOutputTimestamp(timestamp);
|
||||||
this.maybeInvokeCallback();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,13 +24,18 @@ export {Category, Landmark, NormalizedLandmark};
|
||||||
* Represents the pose landmarks deection results generated by `PoseLandmarker`.
|
* Represents the pose landmarks deection results generated by `PoseLandmarker`.
|
||||||
* Each vector element represents a single pose detected in the image.
|
* Each vector element represents a single pose detected in the image.
|
||||||
*/
|
*/
|
||||||
export declare interface PoseLandmarkerResult {
|
export class PoseLandmarkerResult {
|
||||||
/** Pose landmarks of detected poses. */
|
constructor(/** Pose landmarks of detected poses. */
|
||||||
landmarks: NormalizedLandmark[][];
|
readonly landmarks: NormalizedLandmark[][],
|
||||||
|
|
||||||
/** Pose landmarks in world coordinates of detected poses. */
|
/** Pose landmarks in world coordinates of detected poses. */
|
||||||
worldLandmarks: Landmark[][];
|
readonly worldLandmarks: Landmark[][],
|
||||||
|
|
||||||
/** Segmentation mask for the detected pose. */
|
/** Segmentation mask for the detected pose. */
|
||||||
segmentationMasks?: MPMask[];
|
readonly segmentationMasks?: MPMask[]) {}
|
||||||
|
|
||||||
|
/** Frees the resources held by the segmentation masks. */
|
||||||
|
close(): void {
|
||||||
|
this.segmentationMasks?.forEach(m => {
|
||||||
|
m.close();
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -287,7 +287,7 @@ describe('PoseLandmarker', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('invokes listener once masks are available', (done) => {
|
it('invokes listener after masks are available', (done) => {
|
||||||
const landmarksProto = [createLandmarks().serializeBinary()];
|
const landmarksProto = [createLandmarks().serializeBinary()];
|
||||||
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
|
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
|
||||||
const masks = [
|
const masks = [
|
||||||
|
@ -309,13 +309,12 @@ describe('PoseLandmarker', () => {
|
||||||
expect(listenerCalled).toBeFalse();
|
expect(listenerCalled).toBeFalse();
|
||||||
expect(listenerCalled).toBeFalse();
|
expect(listenerCalled).toBeFalse();
|
||||||
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
|
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
|
||||||
expect(listenerCalled).toBeTrue();
|
|
||||||
done();
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Invoke the pose landmarker
|
// Invoke the pose landmarker
|
||||||
poseLandmarker.detect({} as HTMLImageElement, () => {
|
poseLandmarker.detect({} as HTMLImageElement, () => {
|
||||||
listenerCalled = true;
|
listenerCalled = true;
|
||||||
|
done();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -336,5 +335,6 @@ describe('PoseLandmarker', () => {
|
||||||
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||||
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||||
|
result.close();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue
Block a user