Internal change

PiperOrigin-RevId: 518747623
This commit is contained in:
Sebastian Schmidt 2023-03-22 20:41:08 -07:00 committed by Copybara-Service
parent 37111e8fa5
commit 1a7be8a4c1
8 changed files with 123 additions and 4 deletions

View File

@ -23,7 +23,6 @@ mediapipe_proto_library(
srcs = ["tensors_to_segmentation_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/framework/formats:image_format_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto",
"//mediapipe/util:label_map_proto",

View File

@ -18,7 +18,7 @@ syntax = "proto2";
// TODO: consolidate TensorToSegmentationCalculator.
package mediapipe.tasks;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
import "mediapipe/util/label_map.proto";

View File

@ -19,6 +19,7 @@ mediapipe_ts_library(
deps = [
":core",
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",

View File

@ -15,6 +15,7 @@
*/
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
@ -120,11 +121,13 @@ export abstract class TaskRunner {
.then(buffer => {
this.setExternalFile(new Uint8Array(buffer));
this.refreshGraph();
this.onGraphRefreshed();
});
} else {
// Apply the setting synchronously.
this.setExternalFile(baseOptions.modelAssetBuffer);
this.refreshGraph();
this.onGraphRefreshed();
return Promise.resolve();
}
}
@ -132,6 +135,24 @@ export abstract class TaskRunner {
/** Appliest the current options to the MediaPipe graph. */
protected abstract refreshGraph(): void;
/**
* Callback that gets invoked once a new graph configuration has been
* applied.
*/
protected onGraphRefreshed(): void {}
/** Returns the current CalculatorGraphConfig. */
protected getCalculatorGraphConfig(): CalculatorGraphConfig {
let config: CalculatorGraphConfig|undefined;
this.graphRunner.getCalculatorGraphConfig(binaryData => {
config = CalculatorGraphConfig.deserializeBinary(binaryData);
});
if (!config) {
throw new Error('Failed to retrieve CalculatorGraphConfig');
}
return config;
}
/**
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
* over the video stream. Will replace the previously running MediaPipe graph,

View File

@ -16,7 +16,7 @@
import 'jasmine';
import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
import {WasmModule} from '../../../web/graph_runner/graph_runner';
import {CALCULATOR_GRAPH_CONFIG_LISTENER_NAME, SimpleListener, WasmModule} from '../../../web/graph_runner/graph_runner';
import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service';
type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources;
@ -36,8 +36,13 @@ export function createSpyWasmModule(): SpyWasmModule {
'_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
'_addStringToInputStream', '_registerModelResourcesGraphService',
'_configureAudio', '_malloc', '_addProtoToInputStream'
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig'
]);
spyWasmModule._getGraphConfig.and.callFake(() => {
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
SimpleListener<Uint8Array>)(
new CalculatorGraphConfig().serializeBinary(), 0);
});
spyWasmModule.HEAPU8 = jasmine.createSpyObj<Uint8Array>(['set']);
return spyWasmModule;
}

View File

@ -15,12 +15,14 @@ mediapipe_ts_library(
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:types",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/util:label_map_jspb_proto",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)

View File

@ -17,12 +17,14 @@
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {TensorsToSegmentationCalculatorOptions} from '../../../../tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_pb';
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {LabelMapItem} from '../../../../util/label_map_pb';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -37,6 +39,8 @@ const NORM_RECT_STREAM = 'norm_rect';
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
'mediapipe.tasks.TensorsToSegmentationCalculator';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
@ -44,6 +48,7 @@ const IMAGE_SEGMENTER_GRAPH =
/** Performs image segmentation on images. */
export class ImageSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {};
private labels: string[] = [];
private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto;
@ -146,6 +151,39 @@ export class ImageSegmenter extends VisionTaskRunner {
return super.applyOptions(options);
}
protected override onGraphRefreshed(): void {
this.populateLabels();
}
/**
* Populate the labelMap in TensorsToSegmentationCalculator to labels field.
* @throws Exception if there is an error during finding
* TensorsToSegmentationCalculator.
*/
private populateLabels(): void {
const graphConfig = this.getCalculatorGraphConfig();
const tensorsToSegmentationCalculators = graphConfig.getNodeList().filter(
(n: CalculatorGraphConfig.Node) =>
n.getName().includes(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME));
this.labels = [];
if (tensorsToSegmentationCalculators.length > 1) {
throw new Error(`The graph has more than one ${
TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}.`);
} else if (tensorsToSegmentationCalculators.length === 1) {
const labelItems =
tensorsToSegmentationCalculators[0]
.getOptions()
?.getExtension(TensorsToSegmentationCalculatorOptions.ext)
?.getLabelItemsMap() ??
new Map<string, LabelMapItem>();
labelItems.forEach((value, index) => {
// tslint:disable-next-line:no-unnecessary-type-assertion
this.labels[Number(index)] = value.getName()!;
});
}
}
/**
* Performs image segmentation on the provided single image and invokes the
* callback with the response. The method returns synchronously once the
@ -191,6 +229,21 @@ export class ImageSegmenter extends VisionTaskRunner {
this.userCallback = () => {};
}
/**
* Get the category label list of the ImageSegmenter can recognize. For
* `CATEGORY_MASK` type, the index in the category mask corresponds to the
* category in the label list. For `CONFIDENCE_MASK` type, the output mask
* list at index corresponds to the category in the label list.
*
* If there is no labelmap provided in the model file, empty label array is
* returned.
*
* @return The labels used by the current model.
*/
getLabels(): string[] {
return this.labels;
}
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the

View File

@ -36,6 +36,17 @@ export type EmptyPacketListener = (timestamp: number) => void;
export type VectorListener<T> = (data: T, done: boolean, timestamp: number) =>
void;
/**
* A listener that receives the CalculatorGraphConfig in binary encoding.
*/
export type CalculatorGraphConfigListener = (graphConfig: Uint8Array) => void;
/**
* The name of the internal listener that we use to obtain the calculator graph
* config. Intended for internal usage. Exported for testing only.
*/
export const CALCULATOR_GRAPH_CONFIG_LISTENER_NAME = '__graph_config__';
/**
* Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
* doesn't break our JS/C++ bridge.
@ -124,6 +135,10 @@ export declare interface WasmModule {
_configureAudio: (channels: number, samples: number, sampleRate: number,
streamNamePtr: number, headerNamePtr: number) => void;
// Get the graph configuration and invoke the listener configured under
// streamNamePtr
_getGraphConfig: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
// TODO: Refactor to just use a few numbers (perhaps refactor away
// from gl_graph_runner_internal.cc entirely to use something a little more
// streamlined; new version is _processFrame above).
@ -437,6 +452,29 @@ export class GraphRunner {
this.wasmModule._free(heapSpace);
}
/**
* Invokes the callback with the current calculator configuration (in binary
* format).
*
* Consumers must deserialize the binary representation themselves as this
* avoids addding a direct dependency on the Protobuf JSPB target in the graph
* library.
*/
getCalculatorGraphConfig(
callback: CalculatorGraphConfigListener, makeDeepCopy?: boolean): void {
const listener = CALCULATOR_GRAPH_CONFIG_LISTENER_NAME;
// Create a short-lived listener to receive the binary encoded proto
this.setListener(listener, (data: Uint8Array) => {
callback(data);
});
this.wrapStringPtr(listener, (outputStreamNamePtr: number) => {
this.wasmModule._getGraphConfig(outputStreamNamePtr, makeDeepCopy);
});
delete this.wasmModule.simpleListeners![listener];
}
/**
* Ensures existence of the simple listeners table and registers the callback.
* Intended for internal usage.