From 5a5ff5393a7bfd9e76f7c3c867957eb18c48f80e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 22 Dec 2022 17:29:23 -0800 Subject: [PATCH 01/21] Internal change PiperOrigin-RevId: 497269082 --- mediapipe/framework/api2/builder.h | 2 +- mediapipe/framework/api2/packet.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 19273bf44..2a98c4166 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -398,7 +398,7 @@ template class Node; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Node()->Node; +explicit Node() -> Node; #endif // C++17 template <> diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 7933575d3..b1ebb0410 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -181,7 +181,7 @@ template class Packet; #if __cplusplus >= 201703L // Deduction guide to silence -Wctad-maybe-unsupported. -explicit Packet()->Packet; +explicit Packet() -> Packet; #endif // C++17 template <> From 175aff9be8ca719257e15355ecc1b682e7e4e299 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 27 Dec 2022 11:24:50 -0800 Subject: [PATCH 02/21] Update list of issue assignments PiperOrigin-RevId: 498003950 --- .github/bot_config.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 8ad724168..74a60e4b9 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,4 +15,5 @@ # A list of assignees assignees: - - sureshdagooglecom + - kuaashish + - ayushgdev From 7e36a5e2ae8c66ef9717d399fa4004f448dde13f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 28 Dec 2022 11:22:52 -0800 Subject: [PATCH 03/21] Set filecmp.cmp(shallow=False) in model_maker unit tests. 
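With the default shallow=True, filecmp.cmp() compares only the os.stat()
signatures of the two files (file type, size, and mtime) and never reads their
contents, so a regenerated metadata file with different bytes could still pass
the assertion if its size and timestamp happened to match. shallow=False forces
a byte-by-byte comparison. A minimal sketch of the difference, using
hypothetical file names:

    import filecmp
    import shutil

    # Two files with equal size but different contents.
    with open('a.json', 'w') as f:
        f.write('{"x": 1}')
    with open('b.json', 'w') as f:
        f.write('{"x": 2}')
    shutil.copystat('a.json', 'b.json')  # align mtime so the stat signatures match

    print(filecmp.cmp('a.json', 'b.json'))                 # True: stat-only check
    print(filecmp.cmp('a.json', 'b.json', shallow=False))  # False: bytes differ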
PiperOrigin-RevId: 498218578 --- .../python/text/text_classifier/text_classifier_test.py | 6 ++++-- .../python/vision/image_classifier/image_classifier_test.py | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 7a30d19fd..d2edb78bc 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -72,8 +72,10 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertTrue( - filecmp.cmp(output_metadata_file, - self._AVERAGE_WORD_EMBEDDING_JSON_FILE)) + filecmp.cmp( + output_metadata_file, + self._AVERAGE_WORD_EMBEDDING_JSON_FILE, + shallow=False)) def test_create_and_train_bert(self): train_data, validation_data = self._get_data() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 6ca21d334..14c67d831 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,7 +135,9 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) - self.assertTrue(filecmp.cmp(output_metadata_file, expected_metadata_file)) + self.assertTrue( + filecmp.cmp( + output_metadata_file, expected_metadata_file, shallow=False)) def test_continual_training_by_loading_checkpoint(self): mock_stdout = io.StringIO() From 9580f045710327b7a22d738b911af70121e2a79a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 28 Dec 2022 13:57:20 -0800 Subject: [PATCH 04/21] Apply most graph options synchronously PiperOrigin-RevId: 498244085 --- .../audio_classifier/audio_classifier.ts | 7 +- .../audio_classifier/audio_classifier_test.ts | 3 +- .../audio/audio_embedder/audio_embedder.ts | 7 +- .../audio_embedder/audio_embedder_test.ts | 3 +- .../tasks/web/components/processors/BUILD | 26 --- .../processors/base_options.test.ts | 127 --------------- .../web/components/processors/base_options.ts | 80 ---------- mediapipe/tasks/web/core/BUILD | 5 +- mediapipe/tasks/web/core/task_runner.ts | 75 ++++++++- mediapipe/tasks/web/core/task_runner_test.ts | 148 +++++++++++++++++- .../text/text_classifier/text_classifier.ts | 7 +- .../text_classifier/text_classifier_test.ts | 3 +- .../web/text/text_embedder/text_embedder.ts | 7 +- .../text/text_embedder/text_embedder_test.ts | 3 +- mediapipe/tasks/web/vision/core/BUILD | 1 + .../vision/core/vision_task_runner.test.ts | 32 ++-- .../web/vision/core/vision_task_runner.ts | 4 +- .../gesture_recognizer/gesture_recognizer.ts | 8 +- .../gesture_recognizer_test.ts | 3 +- .../vision/hand_landmarker/hand_landmarker.ts | 8 +- .../hand_landmarker/hand_landmarker_test.ts | 3 +- .../image_classifier/image_classifier.ts | 7 +- .../image_classifier/image_classifier_test.ts | 3 +- .../vision/image_embedder/image_embedder.ts | 7 +- .../image_embedder/image_embedder_test.ts | 3 +- .../vision/object_detector/object_detector.ts | 8 +- .../object_detector/object_detector_test.ts | 3 +- 27 files changed, 280 insertions(+), 311 deletions(-) delete mode 100644 
mediapipe/tasks/web/components/processors/base_options.test.ts delete mode 100644 mediapipe/tasks/web/components/processors/base_options.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 7bfca680a..51573f50a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -119,11 +119,10 @@ export class AudioClassifier extends AudioTaskRunner { * * @param options The options for the audio classifier. */ - override async setOptions(options: AudioClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioClassifier extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts index d5c0a9429..2089f184f 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_test.ts @@ -79,7 +79,8 @@ describe('AudioClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioClassifier = new AudioClassifierFake(); - await audioClassifier.setOptions({}); // Initialize graph + await audioClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 246cba883..6a4b8ce39 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -121,11 +121,10 @@ export class AudioEmbedder extends AudioTaskRunner { * * @param options The options for the audio embedder. */ - override async setOptions(options: AudioEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: AudioEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -171,7 +170,7 @@ export class AudioEmbedder extends AudioTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index 2f605ff98..dde61a6e9 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -70,7 +70,8 @@ describe('AudioEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); audioEmbedder = new AudioEmbedderFake(); - await audioEmbedder.setOptions({}); // Initialize graph + await audioEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', () => { diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 148a08238..cab24293d 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -103,29 +103,3 @@ jasmine_node_test( name = "embedder_options_test", deps = [":embedder_options_test_lib"], ) - -mediapipe_ts_library( - name = "base_options", - srcs = [ - "base_options.ts", - ], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", - "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", - "//mediapipe/tasks/web/core", - ], -) - -mediapipe_ts_library( - name = "base_options_test_lib", - testonly = True, - srcs = ["base_options.test.ts"], - deps = [":base_options"], -) - -jasmine_node_test( - name = "base_options_test", - deps = [":base_options_test_lib"], -) diff --git a/mediapipe/tasks/web/components/processors/base_options.test.ts b/mediapipe/tasks/web/components/processors/base_options.test.ts deleted file mode 100644 index 6d58be68f..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.test.ts +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import 'jasmine'; - -// Placeholder for internal dependency on encodeByteArray -// Placeholder for internal dependency on trusted resource URL builder - -import {convertBaseOptionsToProto} from './base_options'; - -describe('convertBaseOptionsToProto()', () => { - const mockBytes = new Uint8Array([0, 1, 2, 3]); - const mockBytesResult = { - modelAsset: { - fileContent: Buffer.from(mockBytes).toString('base64'), - fileName: undefined, - fileDescriptorMeta: undefined, - filePointerMeta: undefined, - }, - useStreamMode: false, - acceleration: { - xnnpack: undefined, - gpu: undefined, - tflite: {}, - }, - }; - - let fetchSpy: jasmine.Spy; - - beforeEach(() => { - fetchSpy = jasmine.createSpy().and.callFake(async url => { - expect(url).toEqual('foo'); - return { - arrayBuffer: () => mockBytes.buffer, - } as unknown as Response; - }); - global.fetch = fetchSpy; - }); - - it('verifies that at least one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({})) - .toBeRejectedWithError( - /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); - }); - - it('verifies that no more than one model asset option is provided', async () => { - await expectAsync(convertBaseOptionsToProto({ - modelAssetPath: `foo`, - modelAssetBuffer: new Uint8Array([]) - })) - .toBeRejectedWithError( - /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); - }); - - it('downloads model', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetPath: `foo`, - }); - - expect(fetchSpy).toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('does not download model when bytes are provided', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - }); - - expect(fetchSpy).not.toHaveBeenCalled(); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable CPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'CPU', - }); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); - - it('can enable GPU delegate', async () => { - const baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - expect(baseOptionsProto.toObject()).toEqual({ - ...mockBytesResult, - acceleration: { - xnnpack: undefined, - gpu: { - useAdvancedGpuApi: false, - api: 0, - allowPrecisionLoss: true, - cachedKernelPath: undefined, - serializedModelDir: undefined, - modelToken: undefined, - usage: 2, - }, - tflite: undefined, - }, - }); - }); - - it('can reset delegate', async () => { - let baseOptionsProto = await convertBaseOptionsToProto({ - modelAssetBuffer: new Uint8Array(mockBytes), - delegate: 'GPU', - }); - // Clear backend - baseOptionsProto = - await convertBaseOptionsToProto({delegate: undefined}, baseOptionsProto); - expect(baseOptionsProto.toObject()).toEqual(mockBytesResult); - }); -}); diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts deleted file mode 100644 index 97b62b784..000000000 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb'; -import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; - -// The OSS JS API does not support the builder pattern. -// tslint:disable:jspb-use-builder-pattern - -/** - * Converts a BaseOptions API object to its Protobuf representation. - * @throws If neither a model assset path or buffer is provided - */ -export async function convertBaseOptionsToProto( - updatedOptions: BaseOptions, - currentOptions?: BaseOptionsProto): Promise { - const result = - currentOptions ? currentOptions.clone() : new BaseOptionsProto(); - - await configureExternalFile(updatedOptions, result); - configureAcceleration(updatedOptions, result); - - return result; -} - -/** - * Configues the `externalFile` option and validates that a single model is - * provided. - */ -async function configureExternalFile( - options: BaseOptions, proto: BaseOptionsProto) { - const externalFile = proto.getModelAsset() || new ExternalFile(); - proto.setModelAsset(externalFile); - - if (options.modelAssetPath || options.modelAssetBuffer) { - if (options.modelAssetPath && options.modelAssetBuffer) { - throw new Error( - 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); - } - - let modelAssetBuffer = options.modelAssetBuffer; - if (!modelAssetBuffer) { - const response = await fetch(options.modelAssetPath!.toString()); - modelAssetBuffer = new Uint8Array(await response.arrayBuffer()); - } - externalFile.setFileContent(modelAssetBuffer); - } - - if (!externalFile.hasFileContent()) { - throw new Error( - 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); - } -} - -/** Configues the `acceleration` option. */ -function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) { - const acceleration = proto.getAcceleration() ?? 
new Acceleration(); - if (options.delegate === 'GPU') { - acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); - } else { - acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); - } - proto.setAcceleration(acceleration); -} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 1721661f5..c0d10d28b 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,8 +18,10 @@ mediapipe_ts_library( srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", @@ -53,6 +55,7 @@ mediapipe_ts_library( "task_runner_test.ts", ], deps = [ + ":core", ":task_runner", ":task_runner_test_utils", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 2011fadef..ffb538b52 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,9 +14,11 @@ * limitations under the License. */ +import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options'; -import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; +import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; +import {BaseOptions, TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options'; import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; @@ -91,14 +93,52 @@ export abstract class TaskRunner { this.graphRunner.registerModelResourcesGraphService(); } - /** Configures the shared options of a MediaPipe Task. */ - async setOptions(options: TaskRunnerOptions): Promise { - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); + /** Configures the task with custom options. */ + abstract setOptions(options: TaskRunnerOptions): Promise; + + /** + * Applies the current set of options, including any base options that have + * not been processed by the task implementation. The options are applied + * synchronously unless a `modelAssetPath` is provided. This ensures that + * for most use cases options are applied directly and immediately affect + * the next inference. 
+ */ + protected applyOptions(options: TaskRunnerOptions): Promise { + const baseOptions: BaseOptions = options.baseOptions || {}; + + // Validate that exactly one model is configured + if (options.baseOptions?.modelAssetBuffer && + options.baseOptions?.modelAssetPath) { + throw new Error( + 'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer'); + } else if (!(this.baseOptions.getModelAsset()?.hasFileContent() || + options.baseOptions?.modelAssetBuffer || + options.baseOptions?.modelAssetPath)) { + throw new Error( + 'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set'); + } + + this.setAcceleration(baseOptions); + if (baseOptions.modelAssetPath) { + // We don't use `await` here since we want to apply most settings + // synchronously. + return fetch(baseOptions.modelAssetPath.toString()) + .then(response => response.arrayBuffer()) + .then(buffer => { + this.setExternalFile(new Uint8Array(buffer)); + this.refreshGraph(); + }); + } else { + // Apply the setting synchronously. + this.setExternalFile(baseOptions.modelAssetBuffer); + this.refreshGraph(); + return Promise.resolve(); } } + /** Appliest the current options to the MediaPipe graph. */ + protected abstract refreshGraph(): void; + /** * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * over the video stream. Will replace the previously running MediaPipe graph, @@ -140,6 +180,27 @@ export abstract class TaskRunner { } this.processingErrors = []; } + + /** Configures the `externalFile` option */ + private setExternalFile(modelAssetBuffer?: Uint8Array): void { + const externalFile = this.baseOptions.getModelAsset() || new ExternalFile(); + if (modelAssetBuffer) { + externalFile.setFileContent(modelAssetBuffer); + } + this.baseOptions.setModelAsset(externalFile); + } + + /** Configures the `acceleration` option. */ + private setAcceleration(options: BaseOptions) { + const acceleration = + this.baseOptions.getAcceleration() ?? 
new Acceleration(); + if (options.delegate === 'GPU') { + acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu()); + } else { + acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite()); + } + this.baseOptions.setAcceleration(acceleration); + } } diff --git a/mediapipe/tasks/web/core/task_runner_test.ts b/mediapipe/tasks/web/core/task_runner_test.ts index c9aad9d25..a55ac04d7 100644 --- a/mediapipe/tasks/web/core/task_runner_test.ts +++ b/mediapipe/tasks/web/core/task_runner_test.ts @@ -15,18 +15,22 @@ */ import 'jasmine'; +// Placeholder for internal dependency on encodeByteArray import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {TaskRunner} from '../../../tasks/web/core/task_runner'; import {createSpyWasmModule, SpyWasmModule} from '../../../tasks/web/core/task_runner_test_utils'; import {ErrorListener} from '../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource URL builder import {GraphRunnerImageLib} from './task_runner'; +import {TaskRunnerOptions} from './task_runner_options.d'; class TaskRunnerFake extends TaskRunner { - protected baseOptions = new BaseOptionsProto(); private errorListener: ErrorListener|undefined; private errors: string[] = []; + baseOptions = new BaseOptionsProto(); + static createFake(): TaskRunnerFake { const wasmModule = createSpyWasmModule(); return new TaskRunnerFake(wasmModule); @@ -61,10 +65,16 @@ class TaskRunnerFake extends TaskRunner { super.finishProcessing(); } + override refreshGraph(): void {} + override setGraph(graphData: Uint8Array, isBinary: boolean): void { super.setGraph(graphData, isBinary); } + setOptions(options: TaskRunnerOptions): Promise { + return this.applyOptions(options); + } + private throwErrors(): void { expect(this.errorListener).toBeDefined(); for (const error of this.errors) { @@ -75,8 +85,38 @@ class TaskRunnerFake extends TaskRunner { } describe('TaskRunner', () => { + const mockBytes = new Uint8Array([0, 1, 2, 3]); + const mockBytesResult = { + modelAsset: { + fileContent: Buffer.from(mockBytes).toString('base64'), + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined, + }, + useStreamMode: false, + acceleration: { + xnnpack: undefined, + gpu: undefined, + tflite: {}, + }, + }; + + let fetchSpy: jasmine.Spy; + let taskRunner: TaskRunnerFake; + + beforeEach(() => { + fetchSpy = jasmine.createSpy().and.callFake(async url => { + expect(url).toEqual('foo'); + return { + arrayBuffer: () => mockBytes.buffer, + } as unknown as Response; + }); + global.fetch = fetchSpy; + + taskRunner = TaskRunnerFake.createFake(); + }); + it('handles errors during graph update', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error'); expect(() => { @@ -85,7 +125,6 @@ describe('TaskRunner', () => { }); it('handles errors during graph execution', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); taskRunner.enqueueError('Test error'); @@ -96,7 +135,6 @@ describe('TaskRunner', () => { }); it('can handle multiple errors', () => { - const taskRunner = TaskRunnerFake.createFake(); taskRunner.enqueueError('Test error 1'); taskRunner.enqueueError('Test error 2'); @@ -104,4 +142,106 @@ describe('TaskRunner', () => { taskRunner.setGraph(new Uint8Array(0), /* isBinary= */ true); }).toThrowError(/Test error 1, Test error 2/); }); + + it('verifies that at least one model asset option is provided', () 
=> { + expect(() => { + taskRunner.setOptions({}); + }) + .toThrowError( + /Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set/); + }); + + it('verifies that no more than one model asset option is provided', () => { + expect(() => { + taskRunner.setOptions({ + baseOptions: { + modelAssetPath: `foo`, + modelAssetBuffer: new Uint8Array([]) + } + }); + }) + .toThrowError( + /Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer/); + }); + + it('doesn\'t require model once it is configured', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + expect(() => { + taskRunner.setOptions({}); + }).not.toThrowError(); + }); + + it('downloads model', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetPath: `foo`}}); + + expect(fetchSpy).toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('does not download model when bytes are provided', async () => { + await taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + expect(fetchSpy).not.toHaveBeenCalled(); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('changes model synchronously when bytes are provided', () => { + const resolvedPromise = taskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array(mockBytes)}}); + + // Check that the change has been applied even though we do not await the + // above Promise + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + return resolvedPromise; + }); + + it('can enable CPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'CPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); + + it('can enable GPU delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + expect(taskRunner.baseOptions.toObject()).toEqual({ + ...mockBytesResult, + acceleration: { + xnnpack: undefined, + gpu: { + useAdvancedGpuApi: false, + api: 0, + allowPrecisionLoss: true, + cachedKernelPath: undefined, + serializedModelDir: undefined, + modelToken: undefined, + usage: 2, + }, + tflite: undefined, + }, + }); + }); + + it('can reset delegate', async () => { + await taskRunner.setOptions({ + baseOptions: { + modelAssetBuffer: new Uint8Array(mockBytes), + delegate: 'GPU', + } + }); + // Clear backend + await taskRunner.setOptions({baseOptions: {delegate: undefined}}); + expect(taskRunner.baseOptions.toObject()).toEqual(mockBytesResult); + }); }); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 62708700a..981438625 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -109,11 +109,10 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
*/ - override async setOptions(options: TextClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -141,7 +140,7 @@ export class TextClassifier extends TaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts index 841bf8c48..5578362cb 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts @@ -56,7 +56,8 @@ describe('TextClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textClassifier = new TextClassifierFake(); - await textClassifier.setOptions({}); // Initialize graph + await textClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 611233e02..7aa0aa6b9 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -113,11 +113,10 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. */ - override async setOptions(options: TextEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: TextEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } protected override get baseOptions(): BaseOptionsProto { @@ -157,7 +156,7 @@ export class TextEmbedder extends TaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts index 04a9b371a..2804e4deb 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts @@ -56,7 +56,8 @@ describe('TextEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); textEmbedder = new TextEmbedderFake(); - await textEmbedder.setOptions({}); // Initialize graph + await textEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e4ea3036f..03958a819 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -29,6 +29,7 @@ mediapipe_ts_library( testonly = True, srcs = ["vision_task_runner.test.ts"], deps = [ + ":vision_task_options", ":vision_task_runner", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/web/core:task_runner_test_utils", diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts index 6cc9ea328..d77cc4fed 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.test.ts @@ -20,6 +20,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {VisionTaskOptions} from './vision_task_options'; import {VisionTaskRunner} from './vision_task_runner'; class VisionTaskRunnerFake extends VisionTaskRunner { @@ -31,6 +32,12 @@ class VisionTaskRunnerFake extends VisionTaskRunner { protected override process(): void {} + protected override refreshGraph(): void {} + + override setOptions(options: VisionTaskOptions): Promise { + return this.applyOptions(options); + } + override processImageData(image: ImageSource): void { super.processImageData(image); } @@ -41,32 +48,24 @@ class VisionTaskRunnerFake extends VisionTaskRunner { } describe('VisionTaskRunner', () => { - const streamMode = { - modelAsset: undefined, - useStreamMode: true, - acceleration: undefined, - }; - - const imageMode = { - modelAsset: undefined, - useStreamMode: false, - acceleration: undefined, - }; - let visionTaskRunner: VisionTaskRunnerFake; - beforeEach(() => { + beforeEach(async () => { visionTaskRunner = new VisionTaskRunnerFake(); + await visionTaskRunner.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('can enable image mode', async () => { await visionTaskRunner.setOptions({runningMode: 'image'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('can enable video mode', async () => { await visionTaskRunner.setOptions({runningMode: 'video'}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: true})); }); 
it('can clear running mode', async () => { @@ -74,7 +73,8 @@ describe('VisionTaskRunner', () => { // Clear running mode await visionTaskRunner.setOptions({runningMode: undefined}); - expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode); + expect(visionTaskRunner.baseOptions.toObject()) + .toEqual(jasmine.objectContaining({useStreamMode: false})); }); it('cannot process images with video mode', async () => { diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 3432b521b..952990326 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -22,13 +22,13 @@ import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ export abstract class VisionTaskRunner extends TaskRunner { /** Configures the shared options of a vision task. */ - override async setOptions(options: VisionTaskOptions): Promise { - await super.setOptions(options); + override applyOptions(options: VisionTaskOptions): Promise { if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; this.baseOptions.setUseStreamMode(useStreamMode); } + return super.applyOptions(options); } /** Sends an image packet to the graph and awaits results. */ diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index b6b795076..cfeb179f5 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -169,9 +169,7 @@ export class GestureRecognizer extends * * @param options The options for the gesture recognizer. */ - override async setOptions(options: GestureRecognizerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: GestureRecognizerOptions): Promise { if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( options.numHands ?? DEFAULT_NUM_HANDS); @@ -221,7 +219,7 @@ export class GestureRecognizer extends ?.clearClassifierOptions(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -342,7 +340,7 @@ export class GestureRecognizer extends } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index c0f0d1554..ff6bba613 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -109,7 +109,8 @@ describe('GestureRecognizer', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); gestureRecognizer = new GestureRecognizerFake(); - await gestureRecognizer.setOptions({}); // Initialize graph + await gestureRecognizer.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 2a0e8286c..24cf9a402 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -150,9 +150,7 @@ export class HandLandmarker extends VisionTaskRunner { * * @param options The options for the hand landmarker. */ - override async setOptions(options: HandLandmarkerOptions): Promise { - await super.setOptions(options); - + override setOptions(options: HandLandmarkerOptions): Promise { // Configure hand detector options. if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -173,7 +171,7 @@ export class HandLandmarker extends VisionTaskRunner { options.minHandPresenceConfidence ?? DEFAULT_SCORE_THRESHOLD); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -291,7 +289,7 @@ export class HandLandmarker extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index fc26680e0..76e77b4bf 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -98,7 +98,8 @@ describe('HandLandmarker', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); handLandmarker = new HandLandmarkerFake(); - await handLandmarker.setOptions({}); // Initialize graph + await handLandmarker.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 36e7311fb..9298a860c 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -118,11 +118,10 @@ export class ImageClassifier extends VisionTaskRunner { * * @param options The options for the image classifier. 
*/ - override async setOptions(options: ImageClassifierOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageClassifierOptions): Promise { this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -163,7 +162,7 @@ export class ImageClassifier extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts index 2041a0cef..da4a01d02 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_test.ts @@ -61,7 +61,8 @@ describe('ImageClassifier', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageClassifier = new ImageClassifierFake(); - await imageClassifier.setOptions({}); // Initialize graph + await imageClassifier.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 0c45ba5e7..cf0bd8c5d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -120,11 +120,10 @@ export class ImageEmbedder extends VisionTaskRunner { * * @param options The options for the image embedder. */ - override async setOptions(options: ImageEmbedderOptions): Promise { - await super.setOptions(options); + override setOptions(options: ImageEmbedderOptions): Promise { this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -186,7 +185,7 @@ export class ImageEmbedder extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. 
*/ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(EMBEDDINGS_STREAM); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts index cafe0f3d8..b63bb374c 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts @@ -57,7 +57,8 @@ describe('ImageEmbedder', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); imageEmbedder = new ImageEmbedderFake(); - await imageEmbedder.setOptions({}); // Initialize graph + await imageEmbedder.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index fbfaced12..e4c51de08 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -117,9 +117,7 @@ export class ObjectDetector extends VisionTaskRunner { * * @param options The options for the object detector. */ - override async setOptions(options: ObjectDetectorOptions): Promise { - await super.setOptions(options); - + override setOptions(options: ObjectDetectorOptions): Promise { // Note that we have to support both JSPB and ProtobufJS, hence we // have to expliclity clear the values instead of setting them to // `undefined`. @@ -153,7 +151,7 @@ export class ObjectDetector extends VisionTaskRunner { this.options.clearCategoryDenylistList(); } - this.refreshGraph(); + return this.applyOptions(options); } /** @@ -226,7 +224,7 @@ export class ObjectDetector extends VisionTaskRunner { } /** Updates the MediaPipe graph configuration. */ - private refreshGraph(): void { + protected override refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); graphConfig.addOutputStream(DETECTIONS_STREAM); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts index fff1a1c48..43b7035d5 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_test.ts @@ -61,7 +61,8 @@ describe('ObjectDetector', () => { beforeEach(async () => { addJasmineCustomFloatEqualityTester(); objectDetector = new ObjectDetectorFake(); - await objectDetector.setOptions({}); // Initialize graph + await objectDetector.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); }); it('initializes graph', async () => { From 1924f1cdff94af953c2cd9b01a13d623ea13e7a7 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 28 Dec 2022 14:27:42 -0800 Subject: [PATCH 05/21] Tensor: Fix use_ahwb_ flag and tests on local device involved. 
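Previously TrackAhwbUsage() recomputed use_ahwb_ from the usage-track set on
every call, so a tensor that had already committed to AHardwareBuffer storage
could have the flag silently reset to false by a later call with an untracked
key. The flag is now sticky for the lifetime of the tensor. A condensed view
of the change (surrounding Tensor internals elided):

    // Before: an untracked key drops an already-enabled flag.
    use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);

    // After: once AHWB use has been established, keep it.
    use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_);

The affected GPU tests now opt into AHWB storage explicitly by requesting an
AHardwareBuffer view on the tensor first, instead of flipping the process-wide
preference via Tensor::SetPreferredStorageType().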
PiperOrigin-RevId: 498249332 --- mediapipe/framework/formats/tensor_ahwb.cc | 3 +- .../framework/formats/tensor_ahwb_gpu_test.cc | 16 ++++++-- .../framework/formats/tensor_ahwb_test.cc | 39 ++++--------------- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 466811be7..74b2dca93 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -458,7 +458,8 @@ void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const { ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim); } } - use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_); + // Keep flag value if it was set previously. + use_ahwb_ = use_ahwb_ || ahwb_usage_track_.contains(ahwb_tracking_key_); } #else // MEDIAPIPE_TENSOR_USE_AHWB diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index a6ca00949..e2ad869f9 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -92,9 +92,14 @@ class TensorAhwbGpuTest : public mediapipe::GpuTestBase { }; TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -114,9 +119,14 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { } TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { - Tensor::SetPreferredStorageType(Tensor::StorageType::kAhwb); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat16, Tensor::Shape({num_elements})}; + { + // Request Ahwb first to get Ahwb storage allocated internally. + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); @@ -139,7 +149,6 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { // Request the CPU view to get the memory to be allocated. // Request Ahwb view then to transform the storage into Ahwb. - Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; { @@ -168,7 +177,6 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { // Request the GPU view to get the ssbo allocated internally. // Request Ahwb view then to transform the storage into Ahwb. 
- Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault); constexpr size_t num_elements = 20; Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; RunInGlContext([&tensor] { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index 7ab5a4925..f0baa6303 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -1,34 +1,28 @@ #include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/gpu/gpu_test_base.h" #include "testing/base/public/gmock.h" #include "testing/base/public/gunit.h" -#ifdef MEDIAPIPE_TENSOR_USE_AHWB -#if !MEDIAPIPE_DISABLE_GPU - namespace mediapipe { -class TensorAhwbTest : public mediapipe::GpuTestBase { - public: -}; - -TEST_F(TensorAhwbTest, TestCpuThenAHWB) { +TEST(TensorAhwbTest, TestCpuThenAHWB) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); { auto ptr = tensor.GetCpuWriteView().buffer(); EXPECT_NE(ptr, nullptr); } { - auto ahwb = tensor.GetAHardwareBufferReadView().handle(); - EXPECT_NE(ahwb, nullptr); + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); } } -TEST_F(TensorAhwbTest, TestAHWBThenCpu) { +TEST(TensorAhwbTest, TestAHWBThenCpu) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); { - auto ahwb = tensor.GetAHardwareBufferWriteView().handle(); - EXPECT_NE(ahwb, nullptr); + auto view = tensor.GetAHardwareBufferWriteView(); + EXPECT_NE(view.handle(), nullptr); + view.SetWritingFinishedFD(-1, [](bool) { return true; }); } { auto ptr = tensor.GetCpuReadView().buffer(); @@ -36,21 +30,4 @@ TEST_F(TensorAhwbTest, TestAHWBThenCpu) { } } -TEST_F(TensorAhwbTest, TestCpuThenGl) { - RunInGlContext([] { - Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1}); - { - auto ptr = tensor.GetCpuWriteView().buffer(); - EXPECT_NE(ptr, nullptr); - } - { - auto ssbo = tensor.GetOpenGlBufferReadView().name(); - EXPECT_GT(ssbo, 0); - } - }); -} - } // namespace mediapipe - -#endif // !MEDIAPIPE_DISABLE_GPU -#endif // MEDIAPIPE_TENSOR_USE_AHWB From 2d9a969d10bdcac98e0e86f617817e08cf656331 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 28 Dec 2022 16:07:09 -0800 Subject: [PATCH 06/21] Tensor1: memorize size_alignment when tracking the ahwb usage. When CPU/GPU buffer allocated and the tracker selects Ahwb storage to be used then the properly recorded alignment must be used. 
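ahwb_usage_track_ changes from a flat_hash_set into a flat_hash_map keyed by
the tracking key, with the requested size_alignment as the value. When
AllocateAHardwareBuffer() later runs for a tensor whose view was recorded
under the same key, it looks up and reuses that alignment, so a lazily created
AHardwareBuffer gets the same padded size as one requested explicitly through
GetAHardwareBufferWriteView(alignment). The padding rounds the byte size up to
the next multiple of the alignment; a sketch of the arithmetic (hypothetical
helper, not the actual implementation):

    // Round `size` up to the next multiple of `alignment` (alignment > 0).
    // Matches the test expectations below: 20 bytes (5 floats) at alignment
    // 16 -> 32, and 36 bytes (9 floats) at alignment 16 -> 48.
    static size_t AlignedSize(size_t size, size_t alignment) {
      return ((size + alignment - 1) / alignment) * alignment;
    }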
PiperOrigin-RevId: 498264759 --- mediapipe/framework/formats/BUILD | 2 +- mediapipe/framework/formats/tensor.h | 7 +- mediapipe/framework/formats/tensor_ahwb.cc | 7 +- .../framework/formats/tensor_ahwb_test.cc | 67 +++++++++++++++++++ 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f5a043f10..cce7e5bd0 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -455,7 +455,7 @@ cc_library( ], }), deps = [ - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 8a6f02e9d..0f19bb5ee 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -24,7 +24,7 @@ #include #include -#include "absl/container/flat_hash_set.h" +#include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/formats/tensor_internal.h" #include "mediapipe/framework/port.h" @@ -434,8 +434,9 @@ class Tensor { mutable bool use_ahwb_ = false; mutable uint64_t ahwb_tracking_key_ = 0; // TODO: Tracks all unique tensors. Can grow to a large number. LRU - // can be more predicted. - static inline absl::flat_hash_set ahwb_usage_track_; + // (Least Recently Used) can be more predicted. + // The value contains the size alignment parameter. + static inline absl::flat_hash_map ahwb_usage_track_; // Expects the target SSBO to be already bound. bool AllocateAhwbMapToSsbo() const; bool InsertAhwbToSsboFence() const; diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index 74b2dca93..525f05f31 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -266,7 +266,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView( bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { // Mark current tracking key as Ahwb-use. - ahwb_usage_track_.insert(ahwb_tracking_key_); + if (auto it = ahwb_usage_track_.find(ahwb_tracking_key_); + it != ahwb_usage_track_.end()) { + size_alignment = it->second; + } else if (ahwb_tracking_key_ != 0) { + ahwb_usage_track_.insert({ahwb_tracking_key_, size_alignment}); + } use_ahwb_ = true; if (__builtin_available(android 26, *)) { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index f0baa6303..3da6ca8d3 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -30,4 +30,71 @@ TEST(TensorAhwbTest, TestAHWBThenCpu) { } } +TEST(TensorAhwbTest, TestAhwbAlignment) { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5}); + { + auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 5 = 20, the closest aligned to 16 size is 32. + EXPECT_EQ(desc.width, 32); + } + view.SetWritingFinishedFD(-1, [](bool) { return true; }); + } +} + +// Tensor::GetCpuView uses source location mechanism that gives source file name +// and line from where the method is called. The function is intended just to +// have two calls providing the same source file name and line. 
+auto GetCpuView(const Tensor &tensor) { return tensor.GetCpuWriteView(); } + +// The test checks the tracking mechanism: when a tensor's Cpu view is retrieved +// for the first time then the source location is attached to the tensor. If the +// Ahwb view is requested then from the tensor then the previously recorded Cpu +// view request source location is marked for using Ahwb storage. +// When a Cpu view with the same source location (but for the newly allocated +// tensor) is requested and the location is marked to use Ahwb storage then the +// Ahwb storage is allocated for the CpuView. +TEST(TensorAhwbTest, TestTrackingAhwb) { + // Create first tensor and request Cpu and then Ahwb view to mark the source + // location for Ahwb storage. + { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9}); + { + auto view = GetCpuView(tensor); + EXPECT_NE(view.buffer(), nullptr); + } + { + // Align size of the Ahwb by multiple of 16. + auto view = tensor.GetAHardwareBufferWriteView(16); + EXPECT_NE(view.handle(), nullptr); + view.SetReadingFinishedFunc([](bool) { return true; }); + } + } + { + Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{9}); + { + // The second tensor uses the same Cpu view source location so Ahwb + // storage is allocated internally. + auto view = GetCpuView(tensor); + EXPECT_NE(view.buffer(), nullptr); + } + { + // Check the Ahwb size to be aligned to multiple of 16. The alignment is + // stored by previous requesting of the Ahwb view. + auto view = tensor.GetAHardwareBufferReadView(); + EXPECT_NE(view.handle(), nullptr); + if (__builtin_available(android 26, *)) { + AHardwareBuffer_Desc desc; + AHardwareBuffer_describe(view.handle(), &desc); + // sizeof(float) * 9 = 36. The closest aligned size is 48. + EXPECT_EQ(desc.width, 48); + } + view.SetReadingFinishedFunc([](bool) { return true; }); + } + } +} + } // namespace mediapipe From aaa16eca1fedf9450689be422ea2dc01c7d74c93 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 29 Dec 2022 08:33:58 -0800 Subject: [PATCH 07/21] Sets the graph service packets before initializing (and validating the graph) in the objc graph wrapper. PiperOrigin-RevId: 498393761 --- mediapipe/objc/MPPGraph.mm | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 1bd177e80..3123eb863 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -230,16 +230,17 @@ if ([wrapper.delegate } - (absl::Status)performStart { - absl::Status status = _graph->Initialize(_config); - if (!status.ok()) { - return status; - } + absl::Status status; for (const auto& service_packet : _servicePackets) { status = _graph->SetServicePacket(*service_packet.first, service_packet.second); if (!status.ok()) { return status; } } + status = _graph->Initialize(_config); + if (!status.ok()) { + return status; + } status = _graph->StartRun(_inputSidePackets, _streamHeaders); if (!status.ok()) { return status; From 60c6b155f626f40e2971cda10aa4c3565897874a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 29 Dec 2022 10:16:10 -0800 Subject: [PATCH 08/21] Save an integer id in graph profiler objects to distinguish between different profiler instances during benchmarking. 
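Each GraphProfiler now draws its id from a process-wide atomic counter in
Initialize(), so ids are unique within the process and are never reused, even
after a profiler instance is destroyed. Because the counter is a
std::atomic_int, the pre-increment is a single atomic read-modify-write, and
two graphs initializing concurrently still receive distinct ids. The pattern,
reduced to its core (a sketch of the relevant members only):

    class GraphProfiler {
     public:
      void Initialize(const ValidatedGraphConfig& validated_graph) {
        // ...existing initialization...
        graph_id_ = ++next_instance_id_;  // first id handed out is 1
      }
      uint64_t GetGraphId() { return graph_id_; }

     private:
      static inline std::atomic_int next_instance_id_ = 0;
      uint64_t graph_id_;
    };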
PiperOrigin-RevId: 498409363 --- .../framework/profiler/graph_profiler.cc | 1 + mediapipe/framework/profiler/graph_profiler.h | 9 +++++++ .../framework/profiler/graph_profiler_test.cc | 26 +++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index f14acfc78..6aead5250 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -194,6 +194,7 @@ void GraphProfiler::Initialize( "Calculator \"$0\" has already been added.", node_name); } profile_builder_ = std::make_unique(this); + graph_id_ = ++next_instance_id_; is_initialized_ = true; } diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 23caed4ec..6358cb057 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -237,6 +237,9 @@ class GraphProfiler : public std::enable_shared_from_this { return validated_graph_; } + // Gets a numerical identifier for this GraphProfiler object. + uint64_t GetGraphId() { return graph_id_; } + private: // This can be used to add packet info for the input streams to the graph. // It treats the stream defined by |stream_name| as a stream produced by a @@ -357,6 +360,12 @@ class GraphProfiler : public std::enable_shared_from_this { class GraphProfileBuilder; std::unique_ptr profile_builder_; + // The globally incrementing identifier for all graphs in a process. + static inline std::atomic_int next_instance_id_ = 0; + + // A unique identifier for this object. Only unique within a process. + uint64_t graph_id_; + // For testing. friend GraphProfilerTestPeer; }; diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index 81ba90cda..75d1c7ebd 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -442,6 +442,32 @@ TEST_F(GraphProfilerTestPeer, InitializeMultipleTimes) { "Cannot initialize .* multiple times."); } +// Tests that graph identifiers are not reused, even after destruction. +TEST_F(GraphProfilerTestPeer, InitializeMultipleProfilers) { + auto raw_graph_config = R"( + profiler_config { + enable_profiler: true + } + input_stream: "input_stream" + node { + calculator: "DummyTestCalculator" + input_stream: "input_stream" + })"; + const int n_iterations = 100; + absl::flat_hash_set seen_ids; + for (int i = 0; i < n_iterations; ++i) { + std::shared_ptr profiler = + std::make_shared(); + auto graph_config = CreateGraphConfig(raw_graph_config); + mediapipe::ValidatedGraphConfig validated_graph; + QCHECK_OK(validated_graph.Initialize(graph_config)); + profiler->Initialize(validated_graph); + + int id = profiler->GetGraphId(); + ASSERT_THAT(seen_ids, testing::Not(testing::Contains(id))); + seen_ids.insert(id); + } +} // Tests that Pause(), Resume(), and Reset() works. 
TEST_F(GraphProfilerTestPeer, PauseResumeReset) { InitializeProfilerWithGraphConfig(R"( From 9252a025e5604cb61b11cbf23943dc7fb9e6f679 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 30 Dec 2022 04:56:57 -0800 Subject: [PATCH 09/21] Use custom gesture options in GestureRecognizer PiperOrigin-RevId: 498567432 --- .../tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 01f444742..91a5ec213 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -151,11 +151,11 @@ ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto custom_gestures_classifier_options_proto = std::make_unique( components::processors::ConvertClassifierOptionsToProto( - &(options->canned_gestures_classifier_options))); + &(options->custom_gestures_classifier_options))); hand_gesture_recognizer_graph_options ->mutable_custom_gesture_classifier_graph_options() ->mutable_classifier_options() - ->Swap(canned_gestures_classifier_options_proto.get()); + ->Swap(custom_gestures_classifier_options_proto.get()); return options_proto; } From 2f4bb5d545fbd6b6389248b7123635dcdfff02b7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 3 Jan 2023 09:34:21 -0800 Subject: [PATCH 10/21] Use utility framebuffer in ViewDoneWritingSimulatorWorkaround This code needs an FBO to bind the texture. Fixes invalid results when running under the simulator. PiperOrigin-RevId: 499241867 --- .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 75 +++++++++++-------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index 014cc1c69..7cac32b7f 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -74,42 +74,51 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer, const GlTextureView& view) { CHECK(pixel_buffer); - CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferLockBaseAddress failed: " << err; - OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); - size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); - uint8_t* pixel_ptr = - static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); - if (pixel_format == kCVPixelFormatType_32BGRA) { - // TODO: restore previous framebuffer? Move this to helper so we - // can use BindFramebuffer?
- glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); + auto ctx = GlContext::GetCurrent().get(); + if (!ctx) ctx = view.gl_context(); + ctx->Run([pixel_buffer, &view, ctx] { + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << err; + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); + uint8_t* pixel_ptr = + static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); + if (pixel_format == kCVPixelFormatType_32BGRA) { + glBindFramebuffer(GL_FRAMEBUFFER, kUtilityFramebuffer.Get(*ctx)); + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), view.name(), 0); - size_t contiguous_bytes_per_row = view.width() * 4; - if (bytes_per_row == contiguous_bytes_per_row) { - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - pixel_ptr); - } else { - std::vector contiguous_buffer(contiguous_bytes_per_row * - view.height()); - uint8_t* temp_ptr = contiguous_buffer.data(); - glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE, - temp_ptr); - for (int i = 0; i < view.height(); ++i) { - memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); - temp_ptr += contiguous_bytes_per_row; - pixel_ptr += bytes_per_row; + size_t contiguous_bytes_per_row = view.width() * 4; + if (bytes_per_row == contiguous_bytes_per_row) { + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, pixel_ptr); + } else { + // TODO: use GL_PACK settings for row length. We can expect + // GLES 3.0 on iOS now. + std::vector contiguous_buffer(contiguous_bytes_per_row * + view.height()); + uint8_t* temp_ptr = contiguous_buffer.data(); + glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, + GL_UNSIGNED_BYTE, temp_ptr); + for (int i = 0; i < view.height(); ++i) { + memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); + temp_ptr += contiguous_bytes_per_row; + pixel_ptr += bytes_per_row; + } } + // TODO: restore previous framebuffer? + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + view.target(), 0, 0); + glBindFramebuffer(GL_FRAMEBUFFER, 0); + } else { + LOG(ERROR) << "unsupported pixel format: " << pixel_format; } - } else { - LOG(ERROR) << "unsupported pixel format: " << pixel_format; - } - err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - CHECK(err == kCVReturnSuccess) - << "CVPixelBufferUnlockBaseAddress failed: " << err; + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << err; + }); } #endif // TARGET_IPHONE_SIMULATOR From f53c0eaceeae9b7cb622764d78054f8e44222ba3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 Jan 2023 09:38:02 -0800 Subject: [PATCH 11/21] Extend tag conversion behavior to also convert `:` (in addition to the current `/`, `-`, and `.`) to `_`. 
PiperOrigin-RevId: 499243005 --- .../tensorflow_session_from_saved_model_calculator.cc | 7 +++---- .../tensorflow_session_from_saved_model_calculator.proto | 4 ++-- .../tensorflow_session_from_saved_model_generator.cc | 7 +++---- .../tensorflow_session_from_saved_model_generator.proto | 4 ++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 922eb9d50..18bddbbe3 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -55,7 +55,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s. This enables the standard +// uppercase and replace /, -, ., and :'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. @@ -67,9 +67,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index 927d3b51f..515b46fa9 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelCalculatorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase as well as switch + // /, -, ., and :'s to _'s, which enables common signatures to be used as Tags. optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index d5236f1cc..ee69ec56a 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -61,7 +61,7 @@ absl::Status GetLatestDirectory(std::string* path) { } // If options.convert_signature_to_tags() is set, will convert letters to -// uppercase and replace /'s and -'s with _'s.
This enables the standard +// uppercase and replace /, -, ., and :'s with _'s. This enables the standard // SavedModel classification, regression, and prediction signatures to be used // as uppercase INPUTS and OUTPUTS tags for streams and supports other common // patterns. @@ -73,9 +73,8 @@ const std::string MaybeConvertSignatureToTag( output.resize(name.length()); std::transform(name.begin(), name.end(), output.begin(), [](unsigned char c) { return std::toupper(c); }); - output = absl::StrReplaceAll(output, {{"/", "_"}}); - output = absl::StrReplaceAll(output, {{"-", "_"}}); - output = absl::StrReplaceAll(output, {{".", "_"}}); + output = absl::StrReplaceAll( + output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}}); LOG(INFO) << "Renamed TAG from: " << name << " to " << output; return output; } else { diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index d24a1cd73..d45fcb662 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -33,8 +33,8 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // The name of the generic signature to load into the mapping from tags to // tensor names. optional string signature_name = 2 [default = "serving_default"]; - // Whether to convert the signature keys to uppercase as well as switch /'s - // and -'s to _'s, which enables common signatures to be used as Tags. + // Whether to convert the signature keys to uppercase, as well as switch /'s, + // -'s, .'s, and :'s to _'s, enabling common signatures to be used as Tags. optional bool convert_signature_to_tags = 3 [default = true]; // If true, saved_model_path can have multiple exported models in // subdirectories saved_model_path/%08d and the alphabetically last (i.e., From 987f4dc1ed89801e54c408abd670f63ce0c77007 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 3 Jan 2023 10:52:41 -0800 Subject: [PATCH 12/21] Make addJasmineCustomFloatEqualityTester configurable PiperOrigin-RevId: 499263931 --- mediapipe/tasks/web/core/task_runner_test_utils.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 2a1161a55..838b3f585 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -44,10 +44,10 @@ export function createSpyWasmModule(): SpyWasmModule { * Sets up our equality testing to use a custom float equality checking function * to avoid incorrect test results due to minor floating point inaccuracies.
*/ -export function addJasmineCustomFloatEqualityTester() { +export function addJasmineCustomFloatEqualityTester(tolerance = 5e-8) { jasmine.addCustomEqualityTester((a, b) => { // Custom float equality if (a === +a && b === +b && (a !== (a | 0) || b !== (b | 0))) { - return Math.abs(a - b) < 5e-8; + return Math.abs(a - b) < tolerance; } return; }); From 68f247a5c7a2f081e6f0ff8b25b9187de5646e2b Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 3 Jan 2023 12:03:57 -0800 Subject: [PATCH 13/21] Internal change PiperOrigin-RevId: 499282085 --- .../web/vision/hand_landmarker/hand_landmarker_result.d.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 89f867d69..8a6d9bfa6 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts @@ -17,6 +17,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +export {Landmark, NormalizedLandmark, Category}; + /** * Represents the hand landmarks detection results generated by `HandLandmarker`. */ From 75b87e0e321090bf73653d83ebfa69cf6f73621f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 3 Jan 2023 12:09:59 -0800 Subject: [PATCH 14/21] Internal change PiperOrigin-RevId: 499283559 --- .../gesture_recognizer/gesture_recognizer.ts | 35 ++++++++++++++----- .../gesture_recognizer_result.d.ts | 8 ++++- .../gesture_recognizer_test.ts | 23 +++++++++++- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index cfeb179f5..c77f2c67a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -263,12 +263,22 @@ export class GestureRecognizer extends NORM_RECT_STREAM, timestamp); this.finishProcessing(); - return { - gestures: this.gestures, - landmarks: this.landmarks, - worldLandmarks: this.worldLandmarks, - handednesses: this.handednesses - }; + if (this.gestures.length === 0) { + // If no gestures are detected in the image, just return an empty list + return { + gestures: [], + landmarks: [], + worldLandmarks: [], + handednesses: [], + }; + } else { + return { + gestures: this.gestures, + landmarks: this.landmarks, + worldLandmarks: this.worldLandmarks, + handednesses: this.handednesses + }; + } } /** Sets the default values for the graph. */ @@ -283,15 +293,19 @@ } /** Converts the proto data to a Category[][] structure. */ - private toJsCategories(data: Uint8Array[]): Category[][] { + private toJsCategories(data: Uint8Array[], populateIndex = true): + Category[][] { const result: Category[][] = []; for (const binaryProto of data) { const inputList = ClassificationList.deserializeBinary(binaryProto); const outputList: Category[] = []; for (const classification of inputList.getClassificationList()) { + const index = populateIndex && classification.hasIndex() ? + classification.getIndex()! : + DEFAULT_CATEGORY_INDEX; outputList.push({ score: classification.getScore() ?? 0, - index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, + index, categoryName: classification.getLabel() ??
'', displayName: classification.getDisplayName() ?? '', }); @@ -375,7 +389,10 @@ export class GestureRecognizer extends }); this.graphRunner.attachProtoVectorListener( HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); + // Gesture index is not used, because the final gesture result comes + // from multiple classifiers. + this.gestures.push( + ...this.toJsCategories(binaryProto, /* populateIndex= */ false)); }); this.graphRunner.attachProtoVectorListener( HANDEDNESS_STREAM, binaryProto => { diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts index e570270b2..323290008 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts @@ -17,6 +17,8 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +export {Category, Landmark, NormalizedLandmark}; + /** * Represents the gesture recognition results generated by `GestureRecognizer`. */ @@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult { /** Handedness of detected hands. */ handednesses: Category[][]; - /** Recognized hand gestures of detected hands */ + /** + * Recognized hand gestures of detected hands. Note that the index of the + * gesture is always -1, because the raw indices from multiple gesture + * classifiers cannot be consolidated into a meaningful index. + */ gestures: Category[][]; } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts index ff6bba613..ee51fd32a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_test.ts @@ -272,7 +272,7 @@ describe('GestureRecognizer', () => { expect(gestures).toEqual({ 'gestures': [[{ 'score': 0.2, - 'index': 2, + 'index': -1, 'categoryName': 'gesture_label', 'displayName': 'gesture_display_name' }]], @@ -305,4 +305,25 @@ describe('GestureRecognizer', () => { // gestures. expect(gestures2).toEqual(gestures1); }); + + it('returns empty results when no gestures are detected', async () => { + // Pass the test data to our listener + gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(gestureRecognizer); + gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks()); + gestureRecognizer.listeners.get('world_hand_landmarks')!
+ (createWorldLandmarks()); + gestureRecognizer.listeners.get('handedness')!(createHandednesses()); + gestureRecognizer.listeners.get('hand_gestures')!([]); + }); + + // Invoke the gesture recognizer + const gestures = gestureRecognizer.recognize({} as HTMLImageElement); + expect(gestures).toEqual({ + 'gestures': [], + 'landmarks': [], + 'worldLandmarks': [], + 'handednesses': [] + }); + }); }); From e7dc989f715382c10ac6d714f4f4be5d330f903d Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 Jan 2023 14:12:34 -0800 Subject: [PATCH 15/21] Internal Change PiperOrigin-RevId: 499313491 --- mediapipe/examples/desktop/autoflip/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 562f11c49..0e28746dc 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -30,6 +30,10 @@ proto_library( java_lite_proto_library( name = "autoflip_messages_java_proto_lite", + visibility = [ + "//java/com/google/android/apps/photos:__subpackages__", + "//javatests/com/google/android/apps/photos:__subpackages__", + ], deps = [ ":autoflip_messages_proto", ], From add5600d0d4e9f0213ebf58088301dc7e743194a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 3 Jan 2023 17:18:59 -0800 Subject: [PATCH 16/21] Internal change PiperOrigin-RevId: 499351795 --- .../python/text/text_classifier/text_classifier_test.py | 1 + .../python/vision/image_classifier/image_classifier_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index d2edb78bc..eb4443b44 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -71,6 +71,7 @@ class TextClassifierTest(tf.test.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( filecmp.cmp( output_metadata_file, diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index 14c67d831..afda8643b 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -135,6 +135,7 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): self.assertTrue(os.path.exists(output_metadata_file)) self.assertGreater(os.path.getsize(output_metadata_file), 0) + filecmp.clear_cache() self.assertTrue( filecmp.cmp( output_metadata_file, expected_metadata_file, shallow=False)) From a4ea606eac3adf3ca5e149e9e6ff6573168971a6 Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 4 Jan 2023 08:21:55 -0800 Subject: [PATCH 17/21] Internal change. 
PiperOrigin-RevId: 499490514 --- .../framework/formats/tensor_ahwb_gpu_test.cc | 28 +++++++++---------- .../framework/formats/tensor_ahwb_test.cc | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index e2ad869f9..45d341e20 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -83,8 +83,8 @@ void FillGpuBuffer(GLuint name, std::size_t size, TFLITE_GPU_CALL_GL(glBindBufferBase, GL_SHADER_STORAGE_BUFFER, 0, name)); MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glUseProgram, to_buffer_program)); MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDispatchCompute, size / 2, 1, 1)); - MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); - MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glBindBuffer, GL_SHADER_STORAGE_BUFFER, 0)); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteProgram, to_buffer_program)); } class TensorAhwbGpuTest : public mediapipe::GpuTestBase { @@ -97,18 +97,18 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) { { // Request Ahwb first to get Ahwb storage allocated internally. auto view = tensor.GetAHardwareBufferWriteView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetWritingFinishedFD(-1, [](bool) { return true; }); } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); - EXPECT_GT(ssbo_name, 0); + ASSERT_GT(ssbo_name, 0); FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), tensor.element_type()); }); auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { @@ -124,18 +124,18 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) { { // Request Ahwb first to get Ahwb storage allocated internally. 
auto view = tensor.GetAHardwareBufferWriteView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetReadingFinishedFunc([](bool) { return true; }); } RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); - EXPECT_GT(ssbo_name, 0); + ASSERT_GT(ssbo_name, 0); FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), tensor.element_type()); }); auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { @@ -153,18 +153,18 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) { Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})}; { auto ptr = tensor.GetCpuWriteView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); for (int i = 0; i < num_elements; i++) { ptr[i] = static_cast(i) / 10.0f; } } { auto view = tensor.GetAHardwareBufferReadView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { @@ -182,17 +182,17 @@ TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) { RunInGlContext([&tensor] { auto ssbo_view = tensor.GetOpenGlBufferWriteView(); auto ssbo_name = ssbo_view.name(); - EXPECT_GT(ssbo_name, 0); + ASSERT_GT(ssbo_name, 0); FillGpuBuffer(ssbo_name, tensor.shape().num_elements(), tensor.element_type()); }); { auto view = tensor.GetAHardwareBufferReadView(); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); view.SetReadingFinishedFunc([](bool) { return true; }); } auto ptr = tensor.GetCpuReadView().buffer(); - EXPECT_NE(ptr, nullptr); + ASSERT_NE(ptr, nullptr); std::vector reference; reference.resize(num_elements); for (int i = 0; i < num_elements; i++) { diff --git a/mediapipe/framework/formats/tensor_ahwb_test.cc b/mediapipe/framework/formats/tensor_ahwb_test.cc index 3da6ca8d3..69e49dd58 100644 --- a/mediapipe/framework/formats/tensor_ahwb_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_test.cc @@ -34,7 +34,7 @@ TEST(TensorAhwbTest, TestAhwbAlignment) { Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{5}); { auto view = tensor.GetAHardwareBufferWriteView(16); - EXPECT_NE(view.handle(), nullptr); + ASSERT_NE(view.handle(), nullptr); if (__builtin_available(android 26, *)) { AHardwareBuffer_Desc desc; AHardwareBuffer_describe(view.handle(), &desc); From 9a70af146432dcbbbc961f9c1a5af4a039d0909a Mon Sep 17 00:00:00 2001 From: Nikolay Chirkov Date: Wed, 4 Jan 2023 08:52:03 -0800 Subject: [PATCH 18/21] Internal change. 
PiperOrigin-RevId: 499496793 --- mediapipe/framework/formats/tensor_ahwb_gpu_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc index 45d341e20..ff78d1f88 100644 --- a/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc +++ b/mediapipe/framework/formats/tensor_ahwb_gpu_test.cc @@ -68,8 +68,9 @@ void FillGpuBuffer(GLuint name, std::size_t size, MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderiv, shader, GL_INFO_LOG_LENGTH, &max_length)); std::vector error_log(max_length); - glGetShaderInfoLog(shader, max_length, &max_length, error_log.data()); - glDeleteShader(shader); + MP_ASSERT_OK(TFLITE_GPU_CALL_GL(glGetShaderInfoLog, shader, max_length, + &max_length, error_log.data())); + MP_EXPECT_OK(TFLITE_GPU_CALL_GL(glDeleteShader, shader)); FAIL() << error_log.data(); return; } From e3131d7d7856771def3c1c141720ca311ed0f3d9 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 4 Jan 2023 10:31:04 -0800 Subject: [PATCH 19/21] Internal change PiperOrigin-RevId: 499521620 --- mediapipe/model_maker/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/setup.py b/mediapipe/model_maker/setup.py index ea193db94..7114e2080 100644 --- a/mediapipe/model_maker/setup.py +++ b/mediapipe/model_maker/setup.py @@ -132,9 +132,9 @@ setuptools.setup( 'Operating System :: MacOS :: MacOS X', 'Operating System :: Microsoft :: Windows', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', From 24cc0672c47b0b2fac28bbc8434e93a9fccb47ad Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 4 Jan 2023 10:57:33 -0800 Subject: [PATCH 20/21] Internal change PiperOrigin-RevId: 499529022 --- mediapipe/examples/desktop/autoflip/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/examples/desktop/autoflip/BUILD b/mediapipe/examples/desktop/autoflip/BUILD index 0e28746dc..340205caa 100644 --- a/mediapipe/examples/desktop/autoflip/BUILD +++ b/mediapipe/examples/desktop/autoflip/BUILD @@ -18,6 +18,8 @@ licenses(["notice"]) package(default_visibility = [ "//mediapipe/examples:__subpackages__", + "//photos/editing/mobile/mediapipe/calculators:__subpackages__", + "//photos/editing/mobile/mediapipe/proto:__subpackages__", ]) proto_library( @@ -45,6 +47,8 @@ mediapipe_cc_proto_library( cc_deps = ["//mediapipe/framework:calculator_cc_proto"], visibility = [ "//mediapipe/examples:__subpackages__", + "//photos/editing/mobile/mediapipe/calculators:__pkg__", + "//photos/editing/mobile/mediapipe/calculators:__subpackages__", ], deps = [":autoflip_messages_proto"], ) From 43bf02443c1b8b7f237c9f7ef408da5cb56619b8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 4 Jan 2023 17:31:48 -0800 Subject: [PATCH 21/21] Option to remove overlapping values computed for different timestamps. 
PiperOrigin-RevId: 499635143 --- .../tensor_to_vector_int_calculator.cc | 20 +++++++ ...sor_to_vector_int_calculator_options.proto | 4 ++ .../tensor_to_vector_int_calculator_test.cc | 53 ++++++++++++++++++- 3 files changed, 76 insertions(+), 1 deletion(-) diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc index 2f4ff28cf..f92ddf08d 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator.cc @@ -37,8 +37,10 @@ class TensorToVectorIntCalculator : public CalculatorBase { private: void TokenizeVector(std::vector* vector) const; + void RemoveOverlapVector(std::vector* vector); TensorToVectorIntCalculatorOptions options_; + int32_t overlapping_values_; }; REGISTER_CALCULATOR(TensorToVectorIntCalculator); @@ -66,6 +68,7 @@ absl::Status TensorToVectorIntCalculator::GetContract(CalculatorContract* cc) { absl::Status TensorToVectorIntCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); + overlapping_values_ = 0; // Inform mediapipe that this calculator produces an output at time t for // each input received at time t (i.e. this calculator does not buffer @@ -106,6 +109,7 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(&instance_output); + RemoveOverlapVector(&instance_output); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } else { @@ -128,12 +132,28 @@ absl::Status TensorToVectorIntCalculator::Process(CalculatorContext* cc) { } } TokenizeVector(output.get()); + RemoveOverlapVector(output.get()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } return absl::OkStatus(); } +void TensorToVectorIntCalculator::RemoveOverlapVector( + std::vector* vector) { + if (options_.overlap() <= 0) { + return; + } + if (overlapping_values_ > 0) { + if (vector->size() < overlapping_values_) { + vector->clear(); + } else { + vector->erase(vector->begin(), vector->begin() + overlapping_values_); + } + } + overlapping_values_ = options_.overlap(); +} + void TensorToVectorIntCalculator::TokenizeVector( std::vector* vector) const { if (!options_.tensor_is_token()) { diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto index 9da3298b9..76b9be952 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_options.proto @@ -36,4 +36,8 @@ message TensorToVectorIntCalculatorOptions { optional bool tensor_is_token = 3 [default = false]; // Threshold for the token generation. optional float token_threshold = 4 [default = 0.5]; + + // Values which overlap between temporally consecutive vectors. They are + // removed from the output to reduce redundancy.
+ optional int32 overlap = 5 [default = 0]; } diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc index 60c0d47ec..406c2c1a7 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_int_calculator_test.cc @@ -28,7 +28,8 @@ namespace tf = ::tensorflow; class TensorToVectorIntCalculatorTest : public ::testing::Test { protected: void SetUpRunner(const bool tensor_is_2d, const bool flatten_nd, - const bool tensor_is_token = false) { + const bool tensor_is_token = false, + const int32_t overlap = 0) { CalculatorGraphConfig::Node config; config.set_calculator("TensorToVectorIntCalculator"); config.add_input_stream("input_tensor"); @@ -38,6 +39,7 @@ class TensorToVectorIntCalculatorTest : public ::testing::Test { options->set_tensor_is_2d(tensor_is_2d); options->set_flatten_nd(flatten_nd); options->set_tensor_is_token(tensor_is_token); + options->set_overlap(overlap); runner_ = absl::make_unique(config); } @@ -188,5 +190,54 @@ TEST_F(TensorToVectorIntCalculatorTest, FlattenShouldTakeAllDimensions) { } } +TEST_F(TensorToVectorIntCalculatorTest, Overlap) { + SetUpRunner(false, false, false, 2); + for (int time = 0; time < 3; ++time) { + const tf::TensorShape tensor_shape(std::vector{5}); + auto tensor = absl::make_unique(tf::DT_INT64, tensor_shape); + auto tensor_vec = tensor->vec(); + for (int i = 0; i < 5; ++i) { + // 2^i gives each element a distinct value that is easy to verify. + tensor_vec(i) = static_cast(time + (1 << i)); + } + + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(tensor.release()).At(Timestamp(time))); + } + + ASSERT_TRUE(runner_->Run().ok()); + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(3, output_packets.size()); + + { + // The first vector is emitted in full. + int time = 0; + EXPECT_EQ(time, output_packets[time].Timestamp().Value()); + const std::vector& output_vector = + output_packets[time].Get>(); + + EXPECT_EQ(5, output_vector.size()); + for (int i = 0; i < 5; ++i) { + const int64 expected = static_cast(time + (1 << i)); + EXPECT_EQ(expected, output_vector[i]); + } + } + + // All following vectors have the overlap removed + for (int time = 1; time < 3; ++time) { + EXPECT_EQ(time, output_packets[time].Timestamp().Value()); + const std::vector& output_vector = + output_packets[time].Get>(); + + EXPECT_EQ(3, output_vector.size()); + for (int i = 0; i < 3; ++i) { + const int64 expected = static_cast(time + (1 << (i + 2))); + EXPECT_EQ(expected, output_vector[i]); + } + } +} + } // namespace } // namespace mediapipe
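
For reference, the overlap option added in PATCH 21 behaves exactly as the test above verifies: the first output vector is emitted whole, and every later vector drops its leading `overlap` values, since those duplicate the tail of the previous timestamp's output. Below is a minimal standalone C++ sketch of that trimming semantics; the OverlapTrimmer helper is hypothetical and not part of the patch (the calculator keeps the equivalent state in its overlapping_values_ member).

#include <cstdint>
#include <vector>

// Mirrors TensorToVectorIntCalculator::RemoveOverlapVector: the first call
// passes the vector through untouched; each later call erases the first
// `overlap` values, which were already emitted at the previous timestamp.
class OverlapTrimmer {
 public:
  explicit OverlapTrimmer(int32_t overlap) : overlap_(overlap) {}

  void Trim(std::vector<int64_t>* vector) {
    if (overlap_ <= 0) return;  // Overlap removal disabled.
    if (overlapping_values_ > 0) {
      if (vector->size() < static_cast<size_t>(overlapping_values_)) {
        vector->clear();  // Shorter than the overlap: nothing new to emit.
      } else {
        vector->erase(vector->begin(), vector->begin() + overlapping_values_);
      }
    }
    overlapping_values_ = overlap_;
  }

 private:
  int32_t overlap_;
  int32_t overlapping_values_ = 0;  // Zero until the first vector is seen.
};

// Example: with overlap = 2, {1, 2, 4, 8, 16} at t=0 passes through whole,
// while {2, 3, 5, 9, 17} at t=1 is trimmed to {5, 9, 17}, matching the
// expectations in the Overlap test above.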