Update PoseLandmarker to return MPImage

PiperOrigin-RevId: 528022223
This commit is contained in:
Sebastian Schmidt 2023-04-28 17:11:03 -07:00 committed by Copybara-Service
parent dcef6df1cb
commit 874cc9dea3
5 changed files with 10 additions and 22 deletions

View File

@ -35,20 +35,6 @@ const COLOR_MAP: Array<[number, number, number, number]> = [
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead? [255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
]; ];
/** Helper function to draw a confidence mask */
export function drawConfidenceMask(
ctx: CanvasRenderingContext2D, image: Float32Array, width: number,
height: number): void {
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < image.length; i++) {
uint8ClampedArray[4 * i] = 128;
uint8ClampedArray[4 * i + 1] = 0;
uint8ClampedArray[4 * i + 2] = 0;
uint8ClampedArray[4 * i + 3] = image[i] * 255;
}
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
}
/** The color converter we use in our demos. */ /** The color converter we use in our demos. */
export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = { export const RENDER_UTIL_CONVERTER: MPImageChannelConverter = {
floatToRGBAConverter: v => [128, 0, 0, v * 255], floatToRGBAConverter: v => [128, 0, 0, v * 255],

View File

@ -45,6 +45,7 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_options",
], ],
) )
@ -62,6 +63,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/components/processors:landmark_result", "//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:image",
"//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/tasks/web/vision/core:vision_task_runner",
], ],
) )

View File

@ -417,7 +417,7 @@ export class PoseLandmarker extends VisionTaskRunner {
this.graphRunner.attachImageVectorListener( this.graphRunner.attachImageVectorListener(
SEGMENTATION_MASK_STREAM, (masks, timestamp) => { SEGMENTATION_MASK_STREAM, (masks, timestamp) => {
this.result.segmentationMasks = this.result.segmentationMasks =
masks.map(m => m.data) as Float32Array[] | WebGLBuffer[]; masks.map(wasmImage => this.convertToMPImage(wasmImage));
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(

View File

@ -16,6 +16,7 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {MPImage} from '../../../../tasks/web/vision/core/image';
export {Category, Landmark, NormalizedLandmark}; export {Category, Landmark, NormalizedLandmark};
@ -34,5 +35,5 @@ export declare interface PoseLandmarkerResult {
auxilaryLandmarks: NormalizedLandmark[]; auxilaryLandmarks: NormalizedLandmark[];
/** Segmentation mask for the detected pose. */ /** Segmentation mask for the detected pose. */
segmentationMasks?: Float32Array[]|WebGLTexture[]; segmentationMasks?: MPImage[];
} }

View File

@ -18,6 +18,7 @@ import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib'; import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {MPImage} from '../../../../tasks/web/vision/core/image';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {PoseLandmarker} from './pose_landmarker'; import {PoseLandmarker} from './pose_landmarker';
@ -221,12 +222,10 @@ describe('PoseLandmarker', () => {
.toHaveBeenCalledTimes(1); .toHaveBeenCalledTimes(1);
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result).toEqual({ expect(result.landmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]);
'landmarks': [{'x': 0, 'y': 0, 'z': 0}], expect(result.worldLandmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]);
'worldLandmarks': [{'x': 0, 'y': 0, 'z': 0}], expect(result.auxilaryLandmarks).toEqual([{'x': 0, 'y': 0, 'z': 0}]);
'auxilaryLandmarks': [{'x': 0, 'y': 0, 'z': 0}], expect(result.segmentationMasks![0]).toBeInstanceOf(MPImage);
'segmentationMasks': [new Float32Array([0, 1, 2, 3])],
});
done(); done();
}); });
}); });