Extend verifyGraph to be compatible with proto3.

PiperOrigin-RevId: 591047275
This commit is contained in:
MediaPipe Team 2023-12-14 14:08:56 -08:00 committed by Copybara-Service
parent df7feadaf7
commit 746d775933

View File

@ -71,14 +71,23 @@ export interface MediapipeTasksFake {
/** An map of field paths to values */ /** An map of field paths to values */
export type FieldPathToValue = [string[] | string, unknown]; export type FieldPathToValue = [string[] | string, unknown];
type JsonObject = Record<string, unknown>;
type Deserializer = (binaryProto: string | Uint8Array) => JsonObject;
/** /**
* Verifies that the graph has been initialized and that it contains the * Verifies that the graph has been initialized and that it contains the
* provided options. * provided options.
*
* @param deserializer - the function to convert a binary proto to a JsonObject.
* For example, the deserializer of HolisticLandmarkerOptions's binary proto is
* HolisticLandmarkerOptions.deserializeBinary(binaryProto).toObject().
*/ */
export function verifyGraph( export function verifyGraph(
tasksFake: MediapipeTasksFake, tasksFake: MediapipeTasksFake,
expectedCalculatorOptions?: FieldPathToValue, expectedCalculatorOptions?: FieldPathToValue,
expectedBaseOptions?: FieldPathToValue, expectedBaseOptions?: FieldPathToValue,
deserializer?: Deserializer,
): void { ): void {
expect(tasksFake.graph).toBeDefined(); expect(tasksFake.graph).toBeDefined();
// Our graphs should have at least one node in them for processing, and // Our graphs should have at least one node in them for processing, and
@ -89,22 +98,30 @@ export function verifyGraph(
expect(node).toEqual( expect(node).toEqual(
jasmine.objectContaining({calculator: tasksFake.calculatorName})); jasmine.objectContaining({calculator: tasksFake.calculatorName}));
let proto;
if (deserializer) {
const binaryProto =
tasksFake.graph!.getNodeList()[0].getNodeOptionsList()[0].getValue();
proto = deserializer(binaryProto);
} else {
proto = (node.options as {ext: unknown}).ext;
}
if (expectedBaseOptions) { if (expectedBaseOptions) {
const [fieldPath, value] = expectedBaseOptions; const [fieldPath, value] = expectedBaseOptions;
let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions; let baseOptions = (proto as {baseOptions: unknown}).baseOptions;
for (const fieldName of ( for (const fieldName of (
Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
proto = ((proto ?? {}) as Record<string, unknown>)[fieldName]; baseOptions = ((baseOptions ?? {}) as JsonObject)[fieldName];
} }
expect(proto).toEqual(value); expect(baseOptions).toEqual(value);
} }
if (expectedCalculatorOptions) { if (expectedCalculatorOptions) {
const [fieldPath, value] = expectedCalculatorOptions; const [fieldPath, value] = expectedCalculatorOptions;
let proto = (node.options as {ext: unknown}).ext;
for (const fieldName of ( for (const fieldName of (
Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
proto = ((proto ?? {}) as Record<string, unknown>)[fieldName]; proto = ((proto ?? {}) as JsonObject)[fieldName];
} }
expect(proto).toEqual(value); expect(proto).toEqual(value);
} }