diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index b54e7352b..c621016dc 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -23,7 +23,6 @@ mediapipe_proto_library( srcs = ["tensors_to_segmentation_calculator.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index f267bf09b..dbaf34db0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -18,7 +18,7 @@ syntax = "proto2"; // TODO: consolidate TensorToSegmentationCalculator. package mediapipe.tasks; -import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index ec65548d4..a417d4d72 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -19,6 +19,7 @@ mediapipe_ts_library( deps = [ ":core", "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index a01bb1c92..68208c970 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -15,6 +15,7 @@ */ import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; @@ -120,11 +121,13 @@ export abstract class TaskRunner { .then(buffer => { this.setExternalFile(new Uint8Array(buffer)); this.refreshGraph(); + this.onGraphRefreshed(); }); } else { // Apply the setting synchronously. this.setExternalFile(baseOptions.modelAssetBuffer); this.refreshGraph(); + this.onGraphRefreshed(); return Promise.resolve(); } } @@ -132,6 +135,24 @@ export abstract class TaskRunner { /** Appliest the current options to the MediaPipe graph. */ protected abstract refreshGraph(): void; + /** + * Callback that gets invoked once a new graph configuration has been + * applied. + */ + protected onGraphRefreshed(): void {} + + /** Returns the current CalculatorGraphConfig. */ + protected getCalculatorGraphConfig(): CalculatorGraphConfig { + let config: CalculatorGraphConfig|undefined; + this.graphRunner.getCalculatorGraphConfig(binaryData => { + config = CalculatorGraphConfig.deserializeBinary(binaryData); + }); + if (!config) { + throw new Error('Failed to retrieve CalculatorGraphConfig'); + } + return config; + } + /** * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * over the video stream. Will replace the previously running MediaPipe graph, diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 62dd0463a..b0aa34095 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -16,7 +16,7 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; -import {WasmModule} from '../../../web/graph_runner/graph_runner'; +import {CALCULATOR_GRAPH_CONFIG_LISTENER_NAME, SimpleListener, WasmModule} from '../../../web/graph_runner/graph_runner'; import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service'; type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources; @@ -36,8 +36,13 @@ export function createSpyWasmModule(): SpyWasmModule { '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', '_attachProtoVectorListener', '_free', '_waitUntilIdle', '_addStringToInputStream', '_registerModelResourcesGraphService', - '_configureAudio', '_malloc', '_addProtoToInputStream' + '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig' ]); + spyWasmModule._getGraphConfig.and.callFake(() => { + (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as + SimpleListener)( + new CalculatorGraphConfig().serializeBinary(), 0); + }); spyWasmModule.HEAPU8 = jasmine.createSpyObj(['set']); return spyWasmModule; } diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD index 3ca2a64eb..a4b9008dd 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -15,12 +15,14 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_jspb_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:types", "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/util:label_map_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index f8ff0dcca..cb192b0ce 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -17,12 +17,14 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {TensorsToSegmentationCalculatorOptions} from '../../../../tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_pb'; import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {LabelMapItem} from '../../../../util/label_map_pb'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -37,6 +39,8 @@ const NORM_RECT_STREAM = 'norm_rect'; const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; const IMAGE_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; +const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = + 'mediapipe.tasks.TensorsToSegmentationCalculator'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern @@ -44,6 +48,7 @@ const IMAGE_SEGMENTER_GRAPH = /** Performs image segmentation on images. */ export class ImageSegmenter extends VisionTaskRunner { private userCallback: SegmentationMaskCallback = () => {}; + private labels: string[] = []; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -146,6 +151,39 @@ export class ImageSegmenter extends VisionTaskRunner { return super.applyOptions(options); } + protected override onGraphRefreshed(): void { + this.populateLabels(); + } + + /** + * Populate the labelMap in TensorsToSegmentationCalculator to labels field. + * @throws Exception if there is an error during finding + * TensorsToSegmentationCalculator. + */ + private populateLabels(): void { + const graphConfig = this.getCalculatorGraphConfig(); + const tensorsToSegmentationCalculators = graphConfig.getNodeList().filter( + (n: CalculatorGraphConfig.Node) => + n.getName().includes(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)); + + this.labels = []; + if (tensorsToSegmentationCalculators.length > 1) { + throw new Error(`The graph has more than one ${ + TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}.`); + } else if (tensorsToSegmentationCalculators.length === 1) { + const labelItems = + tensorsToSegmentationCalculators[0] + .getOptions() + ?.getExtension(TensorsToSegmentationCalculatorOptions.ext) + ?.getLabelItemsMap() ?? + new Map(); + labelItems.forEach((value, index) => { + // tslint:disable-next-line:no-unnecessary-type-assertion + this.labels[Number(index)] = value.getName()!; + }); + } + } + /** * Performs image segmentation on the provided single image and invokes the * callback with the response. The method returns synchronously once the @@ -191,6 +229,21 @@ export class ImageSegmenter extends VisionTaskRunner { this.userCallback = () => {}; } + /** + * Get the category label list of the ImageSegmenter can recognize. For + * `CATEGORY_MASK` type, the index in the category mask corresponds to the + * category in the label list. For `CONFIDENCE_MASK` type, the output mask + * list at index corresponds to the category in the label list. + * + * If there is no labelmap provided in the model file, empty label array is + * returned. + * + * @return The labels used by the current model. + */ + getLabels(): string[] { + return this.labels; + } + /** * Performs image segmentation on the provided video frame and invokes the * callback with the response. The method returns synchronously once the diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 4417f6a03..e2b1684a0 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -36,6 +36,17 @@ export type EmptyPacketListener = (timestamp: number) => void; export type VectorListener = (data: T, done: boolean, timestamp: number) => void; +/** + * A listener that receives the CalculatorGraphConfig in binary encoding. + */ +export type CalculatorGraphConfigListener = (graphConfig: Uint8Array) => void; + +/** + * The name of the internal listener that we use to obtain the calculator graph + * config. Intended for internal usage. Exported for testing only. + */ +export const CALCULATOR_GRAPH_CONFIG_LISTENER_NAME = '__graph_config__'; + /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -124,6 +135,10 @@ export declare interface WasmModule { _configureAudio: (channels: number, samples: number, sampleRate: number, streamNamePtr: number, headerNamePtr: number) => void; + // Get the graph configuration and invoke the listener configured under + // streamNamePtr + _getGraphConfig: (streamNamePtr: number, makeDeepCopy?: boolean) => void; + // TODO: Refactor to just use a few numbers (perhaps refactor away // from gl_graph_runner_internal.cc entirely to use something a little more // streamlined; new version is _processFrame above). @@ -437,6 +452,29 @@ export class GraphRunner { this.wasmModule._free(heapSpace); } + /** + * Invokes the callback with the current calculator configuration (in binary + * format). + * + * Consumers must deserialize the binary representation themselves as this + * avoids addding a direct dependency on the Protobuf JSPB target in the graph + * library. + */ + getCalculatorGraphConfig( + callback: CalculatorGraphConfigListener, makeDeepCopy?: boolean): void { + const listener = CALCULATOR_GRAPH_CONFIG_LISTENER_NAME; + + // Create a short-lived listener to receive the binary encoded proto + this.setListener(listener, (data: Uint8Array) => { + callback(data); + }); + this.wrapStringPtr(listener, (outputStreamNamePtr: number) => { + this.wasmModule._getGraphConfig(outputStreamNamePtr, makeDeepCopy); + }); + + delete this.wasmModule.simpleListeners![listener]; + } + /** * Ensures existence of the simple listeners table and registers the callback. * Intended for internal usage.