Internal change
PiperOrigin-RevId: 518747623
This commit is contained in:
parent
37111e8fa5
commit
1a7be8a4c1
|
@ -23,7 +23,6 @@ mediapipe_proto_library(
|
|||
srcs = ["tensors_to_segmentation_calculator.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/framework/formats:image_format_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto",
|
||||
"//mediapipe/util:label_map_proto",
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
// TODO: consolidate TensorToSegmentationCalculator.
|
||||
package mediapipe.tasks;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
|
||||
import "mediapipe/util/label_map.proto";
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ mediapipe_ts_library(
|
|||
deps = [
|
||||
":core",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb';
|
||||
import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
|
||||
import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb';
|
||||
|
@ -120,11 +121,13 @@ export abstract class TaskRunner {
|
|||
.then(buffer => {
|
||||
this.setExternalFile(new Uint8Array(buffer));
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
});
|
||||
} else {
|
||||
// Apply the setting synchronously.
|
||||
this.setExternalFile(baseOptions.modelAssetBuffer);
|
||||
this.refreshGraph();
|
||||
this.onGraphRefreshed();
|
||||
return Promise.resolve();
|
||||
}
|
||||
}
|
||||
|
@ -132,6 +135,24 @@ export abstract class TaskRunner {
|
|||
/** Appliest the current options to the MediaPipe graph. */
|
||||
protected abstract refreshGraph(): void;
|
||||
|
||||
/**
|
||||
* Callback that gets invoked once a new graph configuration has been
|
||||
* applied.
|
||||
*/
|
||||
protected onGraphRefreshed(): void {}
|
||||
|
||||
/** Returns the current CalculatorGraphConfig. */
|
||||
protected getCalculatorGraphConfig(): CalculatorGraphConfig {
|
||||
let config: CalculatorGraphConfig|undefined;
|
||||
this.graphRunner.getCalculatorGraphConfig(binaryData => {
|
||||
config = CalculatorGraphConfig.deserializeBinary(binaryData);
|
||||
});
|
||||
if (!config) {
|
||||
throw new Error('Failed to retrieve CalculatorGraphConfig');
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
|
||||
* over the video stream. Will replace the previously running MediaPipe graph,
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import 'jasmine';
|
||||
|
||||
import {CalculatorGraphConfig} from '../../../framework/calculator_pb';
|
||||
import {WasmModule} from '../../../web/graph_runner/graph_runner';
|
||||
import {CALCULATOR_GRAPH_CONFIG_LISTENER_NAME, SimpleListener, WasmModule} from '../../../web/graph_runner/graph_runner';
|
||||
import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service';
|
||||
|
||||
type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources;
|
||||
|
@ -36,8 +36,13 @@ export function createSpyWasmModule(): SpyWasmModule {
|
|||
'_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener',
|
||||
'_attachProtoVectorListener', '_free', '_waitUntilIdle',
|
||||
'_addStringToInputStream', '_registerModelResourcesGraphService',
|
||||
'_configureAudio', '_malloc', '_addProtoToInputStream'
|
||||
'_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig'
|
||||
]);
|
||||
spyWasmModule._getGraphConfig.and.callFake(() => {
|
||||
(spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as
|
||||
SimpleListener<Uint8Array>)(
|
||||
new CalculatorGraphConfig().serializeBinary(), 0);
|
||||
});
|
||||
spyWasmModule.HEAPU8 = jasmine.createSpyObj<Uint8Array>(['set']);
|
||||
return spyWasmModule;
|
||||
}
|
||||
|
|
|
@ -15,12 +15,14 @@ mediapipe_ts_library(
|
|||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/web/vision/core:types",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||
"//mediapipe/util:label_map_jspb_proto",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {TensorsToSegmentationCalculatorOptions} from '../../../../tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_pb';
|
||||
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
|
||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {LabelMapItem} from '../../../../util/label_map_pb';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
|
@ -37,6 +39,8 @@ const NORM_RECT_STREAM = 'norm_rect';
|
|||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
||||
const IMAGE_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
'mediapipe.tasks.TensorsToSegmentationCalculator';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
@ -44,6 +48,7 @@ const IMAGE_SEGMENTER_GRAPH =
|
|||
/** Performs image segmentation on images. */
|
||||
export class ImageSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private labels: string[] = [];
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
|
@ -146,6 +151,39 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
return super.applyOptions(options);
|
||||
}
|
||||
|
||||
protected override onGraphRefreshed(): void {
|
||||
this.populateLabels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Populate the labelMap in TensorsToSegmentationCalculator to labels field.
|
||||
* @throws Exception if there is an error during finding
|
||||
* TensorsToSegmentationCalculator.
|
||||
*/
|
||||
private populateLabels(): void {
|
||||
const graphConfig = this.getCalculatorGraphConfig();
|
||||
const tensorsToSegmentationCalculators = graphConfig.getNodeList().filter(
|
||||
(n: CalculatorGraphConfig.Node) =>
|
||||
n.getName().includes(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME));
|
||||
|
||||
this.labels = [];
|
||||
if (tensorsToSegmentationCalculators.length > 1) {
|
||||
throw new Error(`The graph has more than one ${
|
||||
TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}.`);
|
||||
} else if (tensorsToSegmentationCalculators.length === 1) {
|
||||
const labelItems =
|
||||
tensorsToSegmentationCalculators[0]
|
||||
.getOptions()
|
||||
?.getExtension(TensorsToSegmentationCalculatorOptions.ext)
|
||||
?.getLabelItemsMap() ??
|
||||
new Map<string, LabelMapItem>();
|
||||
labelItems.forEach((value, index) => {
|
||||
// tslint:disable-next-line:no-unnecessary-type-assertion
|
||||
this.labels[Number(index)] = value.getName()!;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
|
@ -191,6 +229,21 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the category label list of the ImageSegmenter can recognize. For
|
||||
* `CATEGORY_MASK` type, the index in the category mask corresponds to the
|
||||
* category in the label list. For `CONFIDENCE_MASK` type, the output mask
|
||||
* list at index corresponds to the category in the label list.
|
||||
*
|
||||
* If there is no labelmap provided in the model file, empty label array is
|
||||
* returned.
|
||||
*
|
||||
* @return The labels used by the current model.
|
||||
*/
|
||||
getLabels(): string[] {
|
||||
return this.labels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
|
|
|
@ -36,6 +36,17 @@ export type EmptyPacketListener = (timestamp: number) => void;
|
|||
export type VectorListener<T> = (data: T, done: boolean, timestamp: number) =>
|
||||
void;
|
||||
|
||||
/**
|
||||
* A listener that receives the CalculatorGraphConfig in binary encoding.
|
||||
*/
|
||||
export type CalculatorGraphConfigListener = (graphConfig: Uint8Array) => void;
|
||||
|
||||
/**
|
||||
* The name of the internal listener that we use to obtain the calculator graph
|
||||
* config. Intended for internal usage. Exported for testing only.
|
||||
*/
|
||||
export const CALCULATOR_GRAPH_CONFIG_LISTENER_NAME = '__graph_config__';
|
||||
|
||||
/**
|
||||
* Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
|
||||
* doesn't break our JS/C++ bridge.
|
||||
|
@ -124,6 +135,10 @@ export declare interface WasmModule {
|
|||
_configureAudio: (channels: number, samples: number, sampleRate: number,
|
||||
streamNamePtr: number, headerNamePtr: number) => void;
|
||||
|
||||
// Get the graph configuration and invoke the listener configured under
|
||||
// streamNamePtr
|
||||
_getGraphConfig: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
|
||||
|
||||
// TODO: Refactor to just use a few numbers (perhaps refactor away
|
||||
// from gl_graph_runner_internal.cc entirely to use something a little more
|
||||
// streamlined; new version is _processFrame above).
|
||||
|
@ -437,6 +452,29 @@ export class GraphRunner {
|
|||
this.wasmModule._free(heapSpace);
|
||||
}
|
||||
|
||||
/**
|
||||
* Invokes the callback with the current calculator configuration (in binary
|
||||
* format).
|
||||
*
|
||||
* Consumers must deserialize the binary representation themselves as this
|
||||
* avoids addding a direct dependency on the Protobuf JSPB target in the graph
|
||||
* library.
|
||||
*/
|
||||
getCalculatorGraphConfig(
|
||||
callback: CalculatorGraphConfigListener, makeDeepCopy?: boolean): void {
|
||||
const listener = CALCULATOR_GRAPH_CONFIG_LISTENER_NAME;
|
||||
|
||||
// Create a short-lived listener to receive the binary encoded proto
|
||||
this.setListener(listener, (data: Uint8Array) => {
|
||||
callback(data);
|
||||
});
|
||||
this.wrapStringPtr(listener, (outputStreamNamePtr: number) => {
|
||||
this.wasmModule._getGraphConfig(outputStreamNamePtr, makeDeepCopy);
|
||||
});
|
||||
|
||||
delete this.wasmModule.simpleListeners![listener];
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures existence of the simple listeners table and registers the callback.
|
||||
* Intended for internal usage.
|
||||
|
|
Loading…
Reference in New Issue
Block a user