From 746d775933a351bf925a206e1fe21ddfc6c4c5fd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 14 Dec 2023 14:08:56 -0800 Subject: [PATCH] Extend verifyGraph to be compatible with proto3. PiperOrigin-RevId: 591047275 --- .../tasks/web/core/task_runner_test_utils.ts | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 777cb8704..69d00b944 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -71,14 +71,23 @@ export interface MediapipeTasksFake { /** An map of field paths to values */ export type FieldPathToValue = [string[] | string, unknown]; +type JsonObject = Record; + +type Deserializer = (binaryProto: string | Uint8Array) => JsonObject; + /** * Verifies that the graph has been initialized and that it contains the * provided options. + * + * @param deserializer - the function to convert a binary proto to a JsonObject. + * For example, the deserializer of HolisticLandmarkerOptions's binary proto is + * HolisticLandmarkerOptions.deserializeBinary(binaryProto).toObject(). */ export function verifyGraph( tasksFake: MediapipeTasksFake, expectedCalculatorOptions?: FieldPathToValue, expectedBaseOptions?: FieldPathToValue, + deserializer?: Deserializer, ): void { expect(tasksFake.graph).toBeDefined(); // Our graphs should have at least one node in them for processing, and @@ -89,22 +98,30 @@ export function verifyGraph( expect(node).toEqual( jasmine.objectContaining({calculator: tasksFake.calculatorName})); + let proto; + if (deserializer) { + const binaryProto = + tasksFake.graph!.getNodeList()[0].getNodeOptionsList()[0].getValue(); + proto = deserializer(binaryProto); + } else { + proto = (node.options as {ext: unknown}).ext; + } + if (expectedBaseOptions) { const [fieldPath, value] = expectedBaseOptions; - let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions; + let baseOptions = (proto as {baseOptions: unknown}).baseOptions; for (const fieldName of ( Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { - proto = ((proto ?? {}) as Record)[fieldName]; + baseOptions = ((baseOptions ?? {}) as JsonObject)[fieldName]; } - expect(proto).toEqual(value); + expect(baseOptions).toEqual(value); } if (expectedCalculatorOptions) { const [fieldPath, value] = expectedCalculatorOptions; - let proto = (node.options as {ext: unknown}).ext; for (const fieldName of ( Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { - proto = ((proto ?? {}) as Record)[fieldName]; + proto = ((proto ?? {}) as JsonObject)[fieldName]; } expect(proto).toEqual(value); }