Open Source the remaining MediaPipe Tasks tests for Web

PiperOrigin-RevId: 493769657
This commit is contained in:
Sebastian Schmidt 2022-12-07 19:17:14 -08:00 committed by Copybara-Service
parent 24c8fa97e9
commit 9ae2e43b70
35 changed files with 2141 additions and 0 deletions

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.text.text_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.text.text_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.hand_detector.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.hand_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.hand_landmarker.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_embedder.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_segmenter.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";

View File

@ -18,6 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.object_detector.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto";

View File

@ -2,6 +2,7 @@
#
# This task takes audio data and outputs the classification result.
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -44,3 +45,23 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/core:classifier_options",
],
)
# Compiles the AudioClassifier unit-test sources against the task library and
# the jspb protos the test constructs results from.
mediapipe_ts_library(
    name = "audio_classifier_test_lib",
    testonly = True,
    srcs = [
        "audio_classifier_test.ts",
    ],
    deps = [
        ":audio_classifier",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:classification_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

# Runs the AudioClassifier tests under Jasmine on Node.
jasmine_node_test(
    name = "audio_classifier_test",
    deps = [":audio_classifier_test_lib"],
)

View File

@ -0,0 +1,208 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {AudioClassifier} from './audio_classifier';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Fake AudioClassifier that replaces the WASM graph runner with Jasmine spies
 * so tests can inspect the generated graph and inject canned results.
 */
class AudioClassifierFake extends AudioClassifier implements
    MediapipeTasksFake {
  /** Sample rate seen on the most recent 'sample_rate' stream write. */
  lastSampleRate: number|undefined;
  calculatorName =
      'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  graph: CalculatorGraphConfig|undefined;

  // Listener captured from attachProtoVectorListener(); invoked on
  // finishProcessing() with the serialized canned results below.
  private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined;
  private resultProtoVector: ClassificationResult[] = [];

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    // Capture the output listener so canned results can be delivered later.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('timestamped_classifications');
              this.protoVectorListener = listener;
            });
    // Record the sample rate that accompanies the audio data.
    spyOn(this.graphRunner, 'addDoubleToStream')
        .and.callFake((sampleRate, streamName, timestamp) => {
          if (streamName === 'sample_rate') {
            this.lastSampleRate = sampleRate;
          }
        });
    spyOn(this.graphRunner, 'addAudioToStreamWithShape')
        .and.callFake(
            (audioData, numChannels, numSamples, streamName, timestamp) => {
              expect(numChannels).toBe(1);
            });
    // Deliver the canned results (serialized) to the captured listener.
    spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => {
      if (!this.protoVectorListener) return;
      this.protoVectorListener(this.resultProtoVector.map(
          classificationResult => classificationResult.serializeBinary()));
    });
    // Keep the deserialized graph around for verifyGraph().
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
  }

  /** Sets the Protobufs that will be sent to the API. */
  setResults(results: ClassificationResult[]): void {
    this.resultProtoVector = results;
  }
}
/** Unit tests for the AudioClassifier Web API, driven by a spy-backed fake. */
describe('AudioClassifier', () => {
  let audioClassifier: AudioClassifierFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    audioClassifier = new AudioClassifierFake();
    await audioClassifier.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(audioClassifier);
    verifyListenersRegistered(audioClassifier);
  });

  it('reloads graph when settings are changed', async () => {
    await audioClassifier.setOptions({maxResults: 1});
    verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]);
    verifyListenersRegistered(audioClassifier);

    await audioClassifier.setOptions({maxResults: 5});
    verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]);
    verifyListenersRegistered(audioClassifier);
  });

  it('merges options', async () => {
    await audioClassifier.setOptions({maxResults: 1});
    await audioClassifier.setOptions({displayNamesLocale: 'en'});
    verifyGraph(audioClassifier, [
      'classifierOptions', {
        maxResults: 1,
        displayNamesLocale: 'en',
        scoreThreshold: undefined,
        categoryAllowlistList: [],
        categoryDenylistList: []
      }
    ]);
  });

  it('uses a sample rate of 48000 by default', async () => {
    audioClassifier.classify(new Float32Array([]));
    expect(audioClassifier.lastSampleRate).toEqual(48000);
  });

  it('uses default sample rate if none provided', async () => {
    audioClassifier.setDefaultSampleRate(16000);
    audioClassifier.classify(new Float32Array([]));
    expect(audioClassifier.lastSampleRate).toEqual(16000);
  });

  it('uses custom sample rate if provided', async () => {
    audioClassifier.setDefaultSampleRate(16000);
    audioClassifier.classify(new Float32Array([]), 44100);
    expect(audioClassifier.lastSampleRate).toEqual(44100);
  });

  it('transforms results', async () => {
    // Build two timestamped ClassificationResult protos: one fully populated
    // and one relying on proto defaults.
    const resultProtoVector: ClassificationResult[] = [];

    let classificationResult = new ClassificationResult();
    classificationResult.setTimestampMs(0);
    let classifications = new Classifications();
    classifications.setHeadIndex(1);
    classifications.setHeadName('headName');
    let classificationList = new ClassificationList();
    let classification = new Classification();
    classification.setIndex(1);
    classification.setScore(0.2);
    classification.setDisplayName('displayName');
    classification.setLabel('categoryName');
    classificationList.addClassification(classification);
    classifications.setClassificationList(classificationList);
    classificationResult.addClassifications(classifications);
    resultProtoVector.push(classificationResult);

    classificationResult = new ClassificationResult();
    classificationResult.setTimestampMs(1);
    classifications = new Classifications();
    classificationList = new ClassificationList();
    classification = new Classification();
    classification.setIndex(2);
    classification.setScore(0.3);
    classificationList.addClassification(classification);
    classifications.setClassificationList(classificationList);
    classificationResult.addClassifications(classifications);
    resultProtoVector.push(classificationResult);

    // Invoke the audio classifier
    audioClassifier.setResults(resultProtoVector);
    const results = audioClassifier.classify(new Float32Array([]));
    expect(results.length).toEqual(2);
    expect(results[0]).toEqual({
      classifications: [{
        categories: [{
          index: 1,
          score: 0.2,
          displayName: 'displayName',
          categoryName: 'categoryName'
        }],
        headIndex: 1,
        headName: 'headName'
      }],
      timestampMs: 0
    });
    expect(results[1]).toEqual({
      classifications: [{
        categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}],
        headIndex: 0,
        headName: ''
      }],
      timestampMs: 1
    });
  });

  it('clears results between invocations', async () => {
    const classificationResult = new ClassificationResult();
    const classifications = new Classifications();
    const classificationList = new ClassificationList();
    const classification = new Classification();
    classificationList.addClassification(classification);
    classifications.setClassificationList(classificationList);
    classificationResult.addClassifications(classifications);

    audioClassifier.setResults([classificationResult]);

    // Invoke the audio classifier twice
    const classifications1 = audioClassifier.classify(new Float32Array([]));
    const classifications2 = audioClassifier.classify(new Float32Array([]));

    // Verify that classifications2 is not a concatenation of all previously
    // returned classifications.
    expect(classifications1).toEqual(classifications2);
  });
});

View File

@ -3,6 +3,7 @@
# This task takes audio input and performs embedding.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -43,3 +44,23 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/core:embedder_options",
],
)
# Compiles the AudioEmbedder unit-test sources against the task library and
# the jspb protos the test constructs results from.
mediapipe_ts_library(
    name = "audio_embedder_test_lib",
    testonly = True,
    srcs = [
        "audio_embedder_test.ts",
    ],
    deps = [
        ":audio_embedder",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

# Runs the AudioEmbedder tests under Jasmine on Node.
jasmine_node_test(
    name = "audio_embedder_test",
    deps = [":audio_embedder_test_lib"],
)

View File

@ -0,0 +1,185 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Fake AudioEmbedder that replaces the WASM graph runner with Jasmine spies
 * so tests can inspect the generated graph and inject canned results.
 */
class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake {
  /** Sample rate seen on the most recent addDoubleToStream() call. */
  lastSampleRate: number|undefined;
  calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph';
  graph: CalculatorGraphConfig|undefined;
  attachListenerSpies: jasmine.Spy[] = [];
  fakeWasmModule: SpyWasmModule;
  /** Listener captured for the unary 'embeddings_out' stream. */
  protoListener: ((binaryProto: Uint8Array) => void)|undefined;
  /** Listener captured for the 'timestamped_embeddings_out' stream. */
  protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Capture both output listeners so tests can deliver results directly.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('embeddings_out');
              this.protoListener = listener;
            });
    this.attachListenerSpies[1] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('timestamped_embeddings_out');
              this.protoVectorListener = listener;
            });
    // Keep the deserialized graph around for verifyGraph().
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => {
      this.lastSampleRate = sampleRate;
    });
    spyOn(this.graphRunner, 'addAudioToStreamWithShape');
  }
}
/** Unit tests for the AudioEmbedder Web API, driven by a spy-backed fake. */
describe('AudioEmbedder', () => {
  let audioEmbedder: AudioEmbedderFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    audioEmbedder = new AudioEmbedderFake();
    await audioEmbedder.setOptions({});  // Initialize graph
  });

  it('initializes graph', () => {
    verifyGraph(audioEmbedder);
    verifyListenersRegistered(audioEmbedder);
  });

  it('reloads graph when settings are changed', async () => {
    await audioEmbedder.setOptions({quantize: true});
    verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]);
    verifyListenersRegistered(audioEmbedder);

    await audioEmbedder.setOptions({quantize: undefined});
    verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]);
    verifyListenersRegistered(audioEmbedder);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await audioEmbedder.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        audioEmbedder,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */[
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('combines options', async () => {
    await audioEmbedder.setOptions({quantize: true});
    await audioEmbedder.setOptions({l2Normalize: true});
    verifyGraph(
        audioEmbedder,
        ['embedderOptions', {'quantize': true, 'l2Normalize': true}]);
  });

  it('uses a sample rate of 48000 by default', async () => {
    audioEmbedder.embed(new Float32Array([]));
    expect(audioEmbedder.lastSampleRate).toEqual(48000);
  });

  it('uses default sample rate if none provided', async () => {
    audioEmbedder.setDefaultSampleRate(16000);
    audioEmbedder.embed(new Float32Array([]));
    expect(audioEmbedder.lastSampleRate).toEqual(16000);
  });

  it('uses custom sample rate if provided', async () => {
    audioEmbedder.setDefaultSampleRate(16000);
    audioEmbedder.embed(new Float32Array([]), 44100);
    expect(audioEmbedder.lastSampleRate).toEqual(44100);
  });

  describe('transforms results', () => {
    // Shared result proto holding a single float embedding.
    const embedding = new Embedding();
    embedding.setHeadIndex(1);
    embedding.setHeadName('headName');
    const floatEmbedding = new FloatEmbedding();
    floatEmbedding.setValuesList([0.1, 0.9]);
    embedding.setFloatEmbedding(floatEmbedding);
    const resultProto = new EmbeddingResultProto();
    resultProto.addEmbeddings(embedding);

    /** Verifies that the result holds exactly the shared embedding above. */
    function validateEmbeddingResult(
        expectedEmbeddingResult: AudioEmbedderResult[]) {
      expect(expectedEmbeddingResult.length).toEqual(1);

      const [embeddingResult] = expectedEmbeddingResult;
      expect(embeddingResult.embeddings.length).toEqual(1);
      expect(embeddingResult.embeddings[0])
          .toEqual(
              {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'});
    }

    it('from embeddings stream', async () => {
      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
        verifyListenersRegistered(audioEmbedder);
        // Pass the test data to our listener
        audioEmbedder.protoListener!(resultProto.serializeBinary());
      });

      // Invoke the audio embedder
      const embeddingResults = audioEmbedder.embed(new Float32Array([]));
      validateEmbeddingResult(embeddingResults);
    });

    it('from timestamped embeddings stream', async () => {
      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
        verifyListenersRegistered(audioEmbedder);
        // Pass the test data to our listener
        audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]);
      });

      // Invoke the audio embedder
      const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42);
      validateEmbeddingResult(embeddingResults);
    });
  });
});

View File

@ -4,6 +4,7 @@
# BERT-based text classification).
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -45,3 +46,24 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/core:classifier_options",
],
)
# Compiles the TextClassifier unit-test sources against the task library and
# the jspb protos the test constructs results from.
mediapipe_ts_library(
    name = "text_classifier_test_lib",
    testonly = True,
    srcs = [
        "text_classifier_test.ts",
    ],
    deps = [
        ":text_classifier",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:classification_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

# Runs the TextClassifier tests under Jasmine on Node.
jasmine_node_test(
    name = "text_classifier_test",
    deps = [":text_classifier_test_lib"],
)

View File

@ -0,0 +1,152 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {TextClassifier} from './text_classifier';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Fake TextClassifier that replaces the WASM graph runner with Jasmine spies
 * so tests can inspect the generated graph and inject canned results.
 */
class TextClassifierFake extends TextClassifier implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  /** Listener captured for the 'classifications_out' stream. */
  protoListener: ((binaryProto: Uint8Array) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Capture the output listener so tests can deliver results directly.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('classifications_out');
              this.protoListener = listener;
            });
    // Keep the deserialized graph around for verifyGraph().
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
  }
}
/** Unit tests for the TextClassifier Web API, driven by a spy-backed fake. */
describe('TextClassifier', () => {
  let textClassifier: TextClassifierFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    textClassifier = new TextClassifierFake();
    await textClassifier.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(textClassifier);
    verifyListenersRegistered(textClassifier);
  });

  it('reloads graph when settings are changed', async () => {
    await textClassifier.setOptions({maxResults: 1});
    verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]);
    verifyListenersRegistered(textClassifier);

    await textClassifier.setOptions({maxResults: 5});
    verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]);
    verifyListenersRegistered(textClassifier);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await textClassifier.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        textClassifier,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */
        [
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('merges options', async () => {
    await textClassifier.setOptions({maxResults: 1});
    await textClassifier.setOptions({displayNamesLocale: 'en'});
    verifyGraph(textClassifier, [
      'classifierOptions', {
        maxResults: 1,
        displayNamesLocale: 'en',
        scoreThreshold: undefined,
        categoryAllowlistList: [],
        categoryDenylistList: []
      }
    ]);
  });

  it('transforms results', async () => {
    // Build a fully populated ClassificationResult proto to feed the fake.
    const classificationResult = new ClassificationResult();
    const classifications = new Classifications();
    classifications.setHeadIndex(1);
    classifications.setHeadName('headName');
    const classificationList = new ClassificationList();
    const classification = new Classification();
    classification.setIndex(1);
    classification.setScore(0.2);
    classification.setDisplayName('displayName');
    classification.setLabel('categoryName');
    classificationList.addClassification(classification);
    classifications.setClassificationList(classificationList);
    classificationResult.addClassifications(classifications);

    // Pass the test data to our listener
    textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(textClassifier);
      textClassifier.protoListener!(classificationResult.serializeBinary());
    });

    // Invoke the text classifier
    const result = textClassifier.classify('foo');

    expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(result).toEqual({
      classifications: [{
        categories: [{
          index: 1,
          score: 0.2,
          displayName: 'displayName',
          categoryName: 'categoryName'
        }],
        headIndex: 1,
        headName: 'headName'
      }]
    });
  });
});

View File

@ -4,6 +4,7 @@
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -44,3 +45,23 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/core:embedder_options",
],
)
# Compiles the TextEmbedder unit-test sources against the task library and
# the jspb protos the test constructs results from.
mediapipe_ts_library(
    name = "text_embedder_test_lib",
    testonly = True,
    srcs = [
        "text_embedder_test.ts",
    ],
    deps = [
        ":text_embedder",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

# Runs the TextEmbedder tests under Jasmine on Node.
jasmine_node_test(
    name = "text_embedder_test",
    deps = [":text_embedder_test_lib"],
)

View File

@ -0,0 +1,165 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {TextEmbedder} from './text_embedder';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Fake TextEmbedder that replaces the WASM graph runner with Jasmine spies
 * so tests can inspect the generated graph and inject canned results.
 */
class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph';
  graph: CalculatorGraphConfig|undefined;
  attachListenerSpies: jasmine.Spy[] = [];
  fakeWasmModule: SpyWasmModule;
  // Listener captured for the 'embeddings_out' stream. NOTE(review): the
  // parameter name is plural but the listener receives a single serialized
  // proto per call.
  protoListener: ((binaryProtos: Uint8Array) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Capture the output listener so tests can deliver results directly.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('embeddings_out');
              this.protoListener = listener;
            });
    // Keep the deserialized graph around for verifyGraph().
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
  }
}
/** Unit tests for the TextEmbedder Web API, driven by a spy-backed fake. */
describe('TextEmbedder', () => {
  let textEmbedder: TextEmbedderFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    textEmbedder = new TextEmbedderFake();
    await textEmbedder.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(textEmbedder);
    verifyListenersRegistered(textEmbedder);
  });

  it('reloads graph when settings are changed', async () => {
    await textEmbedder.setOptions({quantize: true});
    verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]);
    verifyListenersRegistered(textEmbedder);

    await textEmbedder.setOptions({quantize: undefined});
    verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]);
    verifyListenersRegistered(textEmbedder);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await textEmbedder.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        textEmbedder,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */[
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('combines options', async () => {
    await textEmbedder.setOptions({quantize: true});
    await textEmbedder.setOptions({l2Normalize: true});
    verifyGraph(
        textEmbedder,
        ['embedderOptions', {'quantize': true, 'l2Normalize': true}]);
  });

  it('transforms results', async () => {
    // Build a result proto with a single float embedding.
    const embedding = new Embedding();
    embedding.setHeadIndex(1);
    embedding.setHeadName('headName');
    const floatEmbedding = new FloatEmbedding();
    floatEmbedding.setValuesList([0.1, 0.9]);
    embedding.setFloatEmbedding(floatEmbedding);
    const resultProto = new EmbeddingResult();
    resultProto.addEmbeddings(embedding);

    // Pass the test data to our listener
    textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(textEmbedder);
      textEmbedder.protoListener!(resultProto.serializeBinary());
    });

    // Invoke the text embedder
    const embeddingResult = textEmbedder.embed('foo');

    expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(embeddingResult.embeddings.length).toEqual(1);
    expect(embeddingResult.embeddings[0])
        .toEqual(
            {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'});
  });

  it('transforms custom quantized values', async () => {
    // Build a result proto with a single quantized embedding.
    const embedding = new Embedding();
    embedding.setHeadIndex(1);
    embedding.setHeadName('headName');
    const quantizedEmbedding = new QuantizedEmbedding();
    const quantizedValues = new Uint8Array([1, 2, 3]);
    quantizedEmbedding.setValues(quantizedValues);
    embedding.setQuantizedEmbedding(quantizedEmbedding);
    const resultProto = new EmbeddingResult();
    resultProto.addEmbeddings(embedding);

    // Pass the test data to our listener
    textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(textEmbedder);
      textEmbedder.protoListener!(resultProto.serializeBinary());
    });

    // Invoke the text embedder
    const embeddingResult = textEmbedder.embed('foo');

    expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(embeddingResult.embeddings.length).toEqual(1);
    expect(embeddingResult.embeddings[0]).toEqual({
      quantizedEmbedding: new Uint8Array([1, 2, 3]),
      headIndex: 1,
      headName: 'headName'
    });
  });
});

View File

@ -1,5 +1,6 @@
# This package contains options shared by all MediaPipe Vision Tasks for Web.
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -22,3 +23,20 @@ mediapipe_ts_library(
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)
mediapipe_ts_library(
name = "vision_task_runner_test_lib",
testonly = True,
srcs = ["vision_task_runner.test.ts"],
deps = [
":vision_task_runner",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)
jasmine_node_test(
name = "vision_task_runner_test",
deps = [":vision_task_runner_test_lib"],
)

View File

@ -0,0 +1,99 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageSource} from '../../../../web/graph_runner/graph_runner';
import {VisionTaskRunner} from './vision_task_runner';
/**
 * Minimal concrete VisionTaskRunner used to exercise the shared image/video
 * running-mode handling. `process()` is a no-op; the protected process*
 * helpers are re-exposed so the specs below can call them directly.
 */
class VisionTaskRunnerFake extends VisionTaskRunner<void> {
  // Captured so tests can assert on the options proto contents.
  baseOptions = new BaseOptionsProto();

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
  }

  // No-op: these tests only verify mode validation, not processing.
  protected override process(): void {}

  // Widen visibility of the protected helpers for the specs.
  override processImageData(image: ImageSource): void {
    super.processImageData(image);
  }

  override processVideoData(imageFrame: ImageSource, timestamp: number): void {
    super.processVideoData(imageFrame, timestamp);
  }
}
describe('VisionTaskRunner', () => {
  // Expected BaseOptions proto contents once video (stream) mode is set.
  const streamMode = {
    modelAsset: undefined,
    useStreamMode: true,
    acceleration: undefined,
  };

  // Expected BaseOptions proto contents for (default) image mode.
  const imageMode = {
    modelAsset: undefined,
    useStreamMode: false,
    acceleration: undefined,
  };

  let visionTaskRunner: VisionTaskRunnerFake;

  beforeEach(() => {
    visionTaskRunner = new VisionTaskRunnerFake();
  });

  it('can enable image mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'image'});
    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
  });

  it('can enable video mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'video'});
    expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
  });

  it('can clear running mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'video'});

    // Clear running mode; this should fall back to image mode.
    await visionTaskRunner.setOptions({runningMode: undefined});
    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
  });

  it('cannot process images with video mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'video'});
    expect(() => {
      visionTaskRunner.processImageData({} as HTMLImageElement);
    }).toThrowError(/Task is not initialized with image mode./);
  });

  it('cannot process video with image mode', async () => {
    // Use default for `useStreamMode`
    expect(() => {
      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
    }).toThrowError(/Task is not initialized with video mode./);

    // Explicitly set to image mode
    await visionTaskRunner.setOptions({runningMode: 'image'});
    expect(() => {
      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
    }).toThrowError(/Task is not initialized with video mode./);
  });
});

View File

@ -4,6 +4,7 @@
# the detection results for one or more gesture categories, using Gesture Recognizer.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -52,3 +53,27 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "gesture_recognizer_test_lib",
testonly = True,
srcs = [
"gesture_recognizer_test.ts",
],
deps = [
":gesture_recognizer",
":gesture_recognizer_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/tasks/web/core:task_runner_test_utils",
],
)
jasmine_node_test(
name = "gesture_recognizer_test",
tags = ["nomsan"],
deps = [":gesture_recognizer_test_lib"],
)

View File

@ -0,0 +1,307 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
function createHandednesses(): Uint8Array[] {
const handsProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
classification.setIndex(1);
classification.setLabel('handedness_label');
classification.setDisplayName('handedness_display_name');
handsProto.addClassification(classification);
return [handsProto.serializeBinary()];
}
function createGestures(): Uint8Array[] {
const gesturesProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.2);
classification.setIndex(2);
classification.setLabel('gesture_label');
classification.setDisplayName('gesture_display_name');
gesturesProto.addClassification(classification);
return [gesturesProto.serializeBinary()];
}
function createLandmarks(): Uint8Array[] {
  // A single normalized landmark at (0.3, 0.4, 0.5), serialized.
  const landmarkList = new NormalizedLandmarkList();
  const point = new NormalizedLandmark();
  point.setX(0.3);
  point.setY(0.4);
  point.setZ(0.5);
  landmarkList.addLandmark(point);
  return [landmarkList.serializeBinary()];
}
function createWorldLandmarks(): Uint8Array[] {
const handLandmarksProto = new LandmarkList();
const landmark = new Landmark();
landmark.setX(21);
landmark.setY(22);
landmark.setZ(23);
handLandmarksProto.addLandmark(landmark);
return [handLandmarksProto.serializeBinary()];
}
/**
 * GestureRecognizer subclass that spies on the GraphRunner so specs can
 * inspect the generated graph and feed fake output to the registered
 * stream listeners.
 */
class GestureRecognizerFake extends GestureRecognizer implements
    MediapipeTasksFake {
  calculatorName =
      'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  // Last graph passed to setGraph(), deserialized for verifyGraph().
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  // Captured proto-vector listeners, keyed by output stream name.
  listeners = new Map<string, ProtoListener>();

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Record each registered listener so tests can invoke it manually.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toMatch(
                  /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/);
              this.listeners.set(stream, listener);
            });
    // Keep the deserialized graph around for later inspection.
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
    spyOn(this.graphRunner, 'addProtoToStream');
  }

  getGraphRunner(): GraphRunnerImageLib {
    return this.graphRunner;
  }
}
describe('GestureRecognizer', () => {
let gestureRecognizer: GestureRecognizerFake;
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
gestureRecognizer = new GestureRecognizerFake();
await gestureRecognizer.setOptions({}); // Initialize graph
});
it('initializes graph', async () => {
  // Default options should produce a valid graph with listeners attached.
  verifyGraph(gestureRecognizer);
  verifyListenersRegistered(gestureRecognizer);
});

it('reloads graph when settings are changed', async () => {
  // Each setOptions() call rebuilds the graph and re-registers listeners.
  await gestureRecognizer.setOptions({numHands: 1});
  verifyGraph(gestureRecognizer, [
    ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
  ]);
  verifyListenersRegistered(gestureRecognizer);

  await gestureRecognizer.setOptions({numHands: 5});
  verifyGraph(gestureRecognizer, [
    ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5
  ]);
  verifyListenersRegistered(gestureRecognizer);
});
it('merges options', async () => {
  // A later setOptions() call must not discard previously set options.
  await gestureRecognizer.setOptions({numHands: 1});
  await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5});
  verifyGraph(gestureRecognizer, [
    ['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
  ]);
  verifyGraph(gestureRecognizer, [
    [
      'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
      'minDetectionConfidence'
    ],
    0.5
  ]);
});
describe('setOptions() ', () => {
  // One table entry per user-facing option: the option name, where it lands
  // in the graph-options proto, a custom value to set, and the default
  // value restored when the option is cleared.
  interface TestCase {
    optionPath: [keyof GestureRecognizerOptions, ...string[]];
    fieldPath: string[];
    customValue: unknown;
    defaultValue: unknown;
  }

  const testCases: TestCase[] = [
    {
      optionPath: ['numHands'],
      fieldPath: [
        'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'
      ],
      customValue: 5,
      defaultValue: 1
    },
    {
      optionPath: ['minHandDetectionConfidence'],
      fieldPath: [
        'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
        'minDetectionConfidence'
      ],
      customValue: 0.1,
      defaultValue: 0.5
    },
    {
      optionPath: ['minHandPresenceConfidence'],
      fieldPath: [
        'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions',
        'minDetectionConfidence'
      ],
      customValue: 0.2,
      defaultValue: 0.5
    },
    {
      optionPath: ['minTrackingConfidence'],
      fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'],
      customValue: 0.3,
      defaultValue: 0.5
    },
    {
      optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'],
      fieldPath: [
        'handGestureRecognizerGraphOptions',
        'cannedGestureClassifierGraphOptions', 'classifierOptions',
        'scoreThreshold'
      ],
      customValue: 0.4,
      defaultValue: undefined
    },
    {
      optionPath: ['customGesturesClassifierOptions', 'scoreThreshold'],
      fieldPath: [
        'handGestureRecognizerGraphOptions',
        'customGestureClassifierGraphOptions', 'classifierOptions',
        'scoreThreshold'
      ],
      customValue: 0.5,
      defaultValue: undefined,
    },
  ];

  /** Creates an options object that can be passed to setOptions() */
  function createOptions(
      path: string[], value: unknown): GestureRecognizerOptions {
    const options: Record<string, unknown> = {};
    let currentLevel = options;
    // Build the nested object structure for all but the last path element...
    for (const element of path.slice(0, -1)) {
      currentLevel[element] = {};
      currentLevel = currentLevel[element] as Record<string, unknown>;
    }
    // ...then assign the value at the innermost level.
    currentLevel[path[path.length - 1]] = value;
    return options;
  }

  for (const testCase of testCases) {
    it(`uses default value for ${testCase.optionPath[0]}`, async () => {
      verifyGraph(
          gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
    });

    it(`can set ${testCase.optionPath[0]}`, async () => {
      await gestureRecognizer.setOptions(
          createOptions(testCase.optionPath, testCase.customValue));
      verifyGraph(
          gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
    });

    it(`can clear ${testCase.optionPath[0]}`, async () => {
      await gestureRecognizer.setOptions(
          createOptions(testCase.optionPath, testCase.customValue));
      verifyGraph(
          gestureRecognizer, [testCase.fieldPath, testCase.customValue]);

      // Clearing with `undefined` must restore the documented default.
      await gestureRecognizer.setOptions(
          createOptions(testCase.optionPath, undefined));
      verifyGraph(
          gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
    });
  }
});
it('transforms results', async () => {
  // Pass the test data to our listener
  gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
    verifyListenersRegistered(gestureRecognizer);
    gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
    gestureRecognizer.listeners.get('world_hand_landmarks')!
        (createWorldLandmarks());
    gestureRecognizer.listeners.get('handedness')!(createHandednesses());
    gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
  });

  // Invoke the gesture recognizer
  const gestures = gestureRecognizer.recognize({} as HTMLImageElement);

  // Exactly one proto and one image should have been sent to the graph.
  expect(gestureRecognizer.getGraphRunner().addProtoToStream)
      .toHaveBeenCalledTimes(1);
  expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream)
      .toHaveBeenCalledTimes(1);
  expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();

  // All four proto streams must be converted to the plain JS result shape.
  expect(gestures).toEqual({
    'gestures': [[{
      'score': 0.2,
      'index': 2,
      'categoryName': 'gesture_label',
      'displayName': 'gesture_display_name'
    }]],
    'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
    'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
    'handednesses': [[{
      'score': 0.1,
      'index': 1,
      'categoryName': 'handedness_label',
      'displayName': 'handedness_display_name'
    }]]
  });
});
// Fix: test description had a typo ('invoations' -> 'invocations').
it('clears results between invocations', async () => {
  // Pass the test data to our listener
  gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
    gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
    gestureRecognizer.listeners.get('world_hand_landmarks')!
        (createWorldLandmarks());
    gestureRecognizer.listeners.get('handedness')!(createHandednesses());
    gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
  });

  // Invoke the gesture recognizer twice
  const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement);
  const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement);

  // Verify that gestures2 is not a concatenation of all previously returned
  // gestures.
  expect(gestures2).toEqual(gestures1);
});
});

View File

@ -4,6 +4,7 @@
# the detection results for one or more hand categories, using Hand Landmarker.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -47,3 +48,27 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "hand_landmarker_test_lib",
testonly = True,
srcs = [
"hand_landmarker_test.ts",
],
deps = [
":hand_landmarker",
":hand_landmarker_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner",
"//mediapipe/tasks/web/core:task_runner_test_utils",
],
)
jasmine_node_test(
name = "hand_landmarker_test",
tags = ["nomsan"],
deps = [":hand_landmarker_test_lib"],
)

View File

@ -0,0 +1,251 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {HandLandmarker} from './hand_landmarker';
import {HandLandmarkerOptions} from './hand_landmarker_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
function createHandednesses(): Uint8Array[] {
const handsProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
classification.setIndex(1);
classification.setLabel('handedness_label');
classification.setDisplayName('handedness_display_name');
handsProto.addClassification(classification);
return [handsProto.serializeBinary()];
}
function createLandmarks(): Uint8Array[] {
  // A single normalized landmark at (0.3, 0.4, 0.5), serialized.
  const normalizedList = new NormalizedLandmarkList();
  const point = new NormalizedLandmark();
  point.setX(0.3);
  point.setY(0.4);
  point.setZ(0.5);
  normalizedList.addLandmark(point);
  return [normalizedList.serializeBinary()];
}
function createWorldLandmarks(): Uint8Array[] {
const handLandmarksProto = new LandmarkList();
const landmark = new Landmark();
landmark.setX(21);
landmark.setY(22);
landmark.setZ(23);
handLandmarksProto.addLandmark(landmark);
return [handLandmarksProto.serializeBinary()];
}
/**
 * HandLandmarker subclass that spies on the GraphRunner so specs can
 * inspect the generated graph and feed fake output to the registered
 * stream listeners.
 */
class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  // Last graph passed to setGraph(), deserialized for verifyGraph().
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  // Captured proto-vector listeners, keyed by output stream name.
  listeners = new Map<string, ProtoListener>();

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Record each registered listener so tests can invoke it manually.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              // NOTE(review): 'hand_hands' looks like a copy/paste remnant of
              // the gesture recognizer's 'hand_gestures' alternative — confirm
              // against the streams HandLandmarker actually attaches to.
              expect(stream).toMatch(
                  /(hand_landmarks|world_hand_landmarks|handedness|hand_hands)/);
              this.listeners.set(stream, listener);
            });
    // Keep the deserialized graph around for later inspection.
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
    spyOn(this.graphRunner, 'addProtoToStream');
  }

  getGraphRunner(): GraphRunnerImageLib {
    return this.graphRunner;
  }
}
describe('HandLandmarker', () => {
let handLandmarker: HandLandmarkerFake;
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
handLandmarker = new HandLandmarkerFake();
await handLandmarker.setOptions({}); // Initialize graph
});
it('initializes graph', async () => {
  // Default options should produce a valid graph with listeners attached.
  verifyGraph(handLandmarker);
  verifyListenersRegistered(handLandmarker);
});

it('reloads graph when settings are changed', async () => {
  verifyListenersRegistered(handLandmarker);

  // Each setOptions() call rebuilds the graph and re-registers listeners.
  await handLandmarker.setOptions({numHands: 1});
  verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]);
  verifyListenersRegistered(handLandmarker);

  await handLandmarker.setOptions({numHands: 5});
  verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]);
  verifyListenersRegistered(handLandmarker);
});
it('merges options', async () => {
  // A later setOptions() call must not discard previously set options.
  await handLandmarker.setOptions({numHands: 1});
  await handLandmarker.setOptions({minHandDetectionConfidence: 0.5});
  verifyGraph(handLandmarker, [
    'handDetectorGraphOptions',
    {numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5}
  ]);
});
describe('setOptions() ', () => {
  // One table entry per user-facing option: the option name, where it lands
  // in the graph-options proto, a custom value to set, and the default
  // value restored when the option is cleared.
  interface TestCase {
    optionPath: [keyof HandLandmarkerOptions, ...string[]];
    fieldPath: string[];
    customValue: unknown;
    defaultValue: unknown;
  }

  const testCases: TestCase[] = [
    {
      optionPath: ['numHands'],
      fieldPath: ['handDetectorGraphOptions', 'numHands'],
      customValue: 5,
      defaultValue: 1
    },
    {
      optionPath: ['minHandDetectionConfidence'],
      fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'],
      customValue: 0.1,
      defaultValue: 0.5
    },
    {
      optionPath: ['minHandPresenceConfidence'],
      fieldPath:
          ['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
      customValue: 0.2,
      defaultValue: 0.5
    },
    {
      optionPath: ['minTrackingConfidence'],
      fieldPath: ['minTrackingConfidence'],
      customValue: 0.3,
      defaultValue: 0.5
    },
  ];

  /** Creates an options object that can be passed to setOptions() */
  function createOptions(
      path: string[], value: unknown): HandLandmarkerOptions {
    const options: Record<string, unknown> = {};
    let currentLevel = options;
    // Build the nested object structure for all but the last path element...
    for (const element of path.slice(0, -1)) {
      currentLevel[element] = {};
      currentLevel = currentLevel[element] as Record<string, unknown>;
    }
    // ...then assign the value at the innermost level.
    currentLevel[path[path.length - 1]] = value;
    return options;
  }

  for (const testCase of testCases) {
    it(`uses default value for ${testCase.optionPath[0]}`, async () => {
      verifyGraph(
          handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
    });

    it(`can set ${testCase.optionPath[0]}`, async () => {
      await handLandmarker.setOptions(
          createOptions(testCase.optionPath, testCase.customValue));
      verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
    });

    it(`can clear ${testCase.optionPath[0]}`, async () => {
      await handLandmarker.setOptions(
          createOptions(testCase.optionPath, testCase.customValue));
      verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);

      // Clearing with `undefined` must restore the documented default.
      await handLandmarker.setOptions(
          createOptions(testCase.optionPath, undefined));
      verifyGraph(
          handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
    });
  }
});
it('transforms results', async () => {
  // Pass the test data to our listener
  handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
    verifyListenersRegistered(handLandmarker);
    handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
    handLandmarker.listeners.get('world_hand_landmarks')!
        (createWorldLandmarks());
    handLandmarker.listeners.get('handedness')!(createHandednesses());
  });

  // Invoke the hand landmarker
  const landmarks = handLandmarker.detect({} as HTMLImageElement);

  // Exactly one proto and one image should have been sent to the graph.
  expect(handLandmarker.getGraphRunner().addProtoToStream)
      .toHaveBeenCalledTimes(1);
  expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream)
      .toHaveBeenCalledTimes(1);
  expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();

  // Proto output must be converted to the plain JS result shape.
  expect(landmarks).toEqual({
    'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
    'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
    'handednesses': [[{
      'score': 0.1,
      'index': 1,
      'categoryName': 'handedness_label',
      'displayName': 'handedness_display_name'
    }]]
  });
});
// Fix: test description had a typo ('invoations' -> 'invocations') and the
// explanatory comment referenced non-existent 'hands2' instead of landmarks2.
it('clears results between invocations', async () => {
  // Pass the test data to our listener
  handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
    handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
    handLandmarker.listeners.get('world_hand_landmarks')!
        (createWorldLandmarks());
    handLandmarker.listeners.get('handedness')!(createHandednesses());
  });

  // Invoke the hand landmarker twice
  const landmarks1 = handLandmarker.detect({} as HTMLImageElement);
  const landmarks2 = handLandmarker.detect({} as HTMLImageElement);

  // Verify that landmarks2 is not a concatenation of all previously returned
  // landmarks.
  expect(landmarks1).toEqual(landmarks2);
});
});

View File

@ -3,6 +3,7 @@
# This task takes video or image frames and outputs the classification result.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -44,3 +45,26 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "image_classifier_test_lib",
testonly = True,
srcs = [
"image_classifier_test.ts",
],
deps = [
":image_classifier",
":image_classifier_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
],
)
jasmine_node_test(
name = "image_classifier_test",
tags = ["nomsan"],
deps = [":image_classifier_test_lib"],
)

View File

@ -0,0 +1,150 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageClassifier} from './image_classifier';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * ImageClassifier subclass that spies on the GraphRunner so specs can
 * inspect the generated graph and feed fake listener output.
 */
class ImageClassifierFake extends ImageClassifier implements
    MediapipeTasksFake {
  calculatorName =
      'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  // Last graph passed to setGraph(), deserialized for verifyGraph().
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  // Captured 'classifications' stream listener, invoked manually by tests.
  protoListener: ((binaryProto: Uint8Array) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('classifications');
              this.protoListener = listener;
            });
    // Keep the deserialized graph around for later inspection.
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
  }
}
describe('ImageClassifier', () => {
let imageClassifier: ImageClassifierFake;
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
imageClassifier = new ImageClassifierFake();
await imageClassifier.setOptions({}); // Initialize graph
});
it('initializes graph', async () => {
  // Default options should produce a valid graph with listeners attached.
  verifyGraph(imageClassifier);
  verifyListenersRegistered(imageClassifier);
});

it('reloads graph when settings are changed', async () => {
  // Each setOptions() call rebuilds the graph and re-registers listeners.
  await imageClassifier.setOptions({maxResults: 1});
  verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
  verifyListenersRegistered(imageClassifier);

  await imageClassifier.setOptions({maxResults: 5});
  verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]);
  verifyListenersRegistered(imageClassifier);
});
it('can use custom models', async () => {
  // Five arbitrary bytes stand in for a model file.
  const modelBytes = new Uint8Array([0, 1, 2, 3, 4]);
  const modelBytesBase64 = Buffer.from(modelBytes).toString('base64');

  await imageClassifier.setOptions(
      {baseOptions: {modelAssetBuffer: modelBytes}});

  // The raw buffer should surface as base64 file content in the base
  // options; the file-backed fields stay unset.
  verifyGraph(
      imageClassifier,
      /* expectedCalculatorOptions= */ undefined,
      /* expectedBaseOptions= */[
        'modelAsset', {
          fileContent: modelBytesBase64,
          fileName: undefined,
          fileDescriptorMeta: undefined,
          filePointerMeta: undefined
        }
      ]);
});
it('merges options', async () => {
  // A later setOptions() call must not discard previously set options.
  await imageClassifier.setOptions({maxResults: 1});
  await imageClassifier.setOptions({displayNamesLocale: 'en'});
  verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
  verifyGraph(
      imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']);
});
it('transforms results', async () => {
const classificationResult = new ClassificationResult();
const classifcations = new Classifications();
classifcations.setHeadIndex(1);
classifcations.setHeadName('headName');
const classificationList = new ClassificationList();
const clasification = new Classification();
clasification.setIndex(1);
clasification.setScore(0.2);
clasification.setDisplayName('displayName');
clasification.setLabel('categoryName');
classificationList.addClassification(clasification);
classifcations.setClassificationList(classificationList);
classificationResult.addClassifications(classifcations);
// Pass the test data to our listener
imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageClassifier);
imageClassifier.protoListener!(classificationResult.serializeBinary());
});
// Invoke the image classifier
const result = imageClassifier.classify({} as HTMLImageElement);
expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result).toEqual({
classifications: [{
categories: [{
index: 1,
score: 0.2,
displayName: 'displayName',
categoryName: 'categoryName'
}],
headIndex: 1,
headName: 'headName'
}]
});
});
});

View File

@ -3,6 +3,7 @@
# This task performs embedding extraction on images.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@ -45,3 +46,23 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "image_embedder_test_lib",
testonly = True,
srcs = [
"image_embedder_test.ts",
],
deps = [
":image_embedder",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
],
)
jasmine_node_test(
name = "image_embedder_test",
deps = [":image_embedder_test_lib"],
)

View File

@ -0,0 +1,158 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageEmbedder} from './image_embedder';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Test double for ImageEmbedder. It intercepts the graph-runner plumbing so
 * tests can inspect the generated CalculatorGraphConfig and inject fake
 * embedding results through the captured proto listener.
 */
class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph';
  // Deserialized copy of the last graph passed to setGraph().
  graph: CalculatorGraphConfig|undefined;
  attachListenerSpies: jasmine.Spy[] = [];
  fakeWasmModule: SpyWasmModule;
  // Listener captured from attachProtoListener for 'embeddings_out'; tests
  // call it to feed serialized EmbeddingResult protos into the embedder.
  protoListener: ((binaryProtos: Uint8Array) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Capture the result listener so tests can push fake results.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('embeddings_out');
              this.protoListener = listener;
            });
    // Record the binary graph so tests can verify its configuration.
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
  }
}
describe('ImageEmbedder', () => {
  let imageEmbedder: ImageEmbedderFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    imageEmbedder = new ImageEmbedderFake();
    await imageEmbedder.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(imageEmbedder);
    verifyListenersRegistered(imageEmbedder);
  });

  it('reloads graph when settings are changed', async () => {
    // Each setOptions() call must rebuild the graph and re-register the
    // output listeners.
    verifyListenersRegistered(imageEmbedder);

    await imageEmbedder.setOptions({quantize: true});
    verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]);
    verifyListenersRegistered(imageEmbedder);

    // Clearing the option must also propagate into the graph config.
    await imageEmbedder.setOptions({quantize: undefined});
    verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]);
    verifyListenersRegistered(imageEmbedder);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    // The model buffer is expected to be embedded base64-encoded in the
    // graph's base options.
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await imageEmbedder.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        imageEmbedder,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */[
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('overrides options', async () => {
    // Successive setOptions() calls merge: the second call must not drop the
    // 'quantize' option set by the first.
    await imageEmbedder.setOptions({quantize: true});
    await imageEmbedder.setOptions({l2Normalize: true});
    verifyGraph(
        imageEmbedder,
        ['embedderOptions', {'quantize': true, 'l2Normalize': true}]);
  });

  describe('transforms result', () => {
    beforeEach(() => {
      // Build a fake EmbeddingResult proto with one float embedding.
      const floatEmbedding = new FloatEmbedding();
      floatEmbedding.setValuesList([0.1, 0.9]);

      const embedding = new Embedding();
      embedding.setHeadIndex(1);
      embedding.setHeadName('headName');
      embedding.setFloatEmbedding(floatEmbedding);

      const resultProto = new EmbeddingResult();
      resultProto.addEmbeddings(embedding);
      resultProto.setTimestampMs(42);

      // Pass the test data to our listener
      imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
        verifyListenersRegistered(imageEmbedder);
        imageEmbedder.protoListener!(resultProto.serializeBinary());
      });
    });

    it('for image mode', async () => {
      // Invoke the image embedder
      const embeddingResult = imageEmbedder.embed({} as HTMLImageElement);
      expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
      expect(embeddingResult).toEqual({
        embeddings:
            [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}],
        timestampMs: 42
      });
    });

    it('for video mode', async () => {
      await imageEmbedder.setOptions({runningMode: 'video'});

      // Invoke the video embedder
      const embeddingResult =
          imageEmbedder.embedForVideo({} as HTMLImageElement, 42);
      expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
      expect(embeddingResult).toEqual({
        embeddings:
            [{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}],
        timestampMs: 42
      });
    });
  });
});

View File

@@ -4,6 +4,7 @@
# the detection results for one or more object categories, using Object Detector.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
@@ -41,3 +42,26 @@ mediapipe_ts_declaration(
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
# TypeScript sources for the ObjectDetector Web API unit tests.
mediapipe_ts_library(
    name = "object_detector_test_lib",
    testonly = True,
    srcs = [
        "object_detector_test.ts",
    ],
    deps = [
        ":object_detector",
        ":object_detector_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:detection_jspb_proto",
        "//mediapipe/framework/formats:location_data_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)
# Runs the object detector tests under Jasmine on Node.
# "nomsan": excluded from MemorySanitizer runs.
jasmine_node_test(
    name = "object_detector_test",
    tags = ["nomsan"],
    deps = [":object_detector_test_lib"],
)

View File

@@ -0,0 +1,229 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
import {LocationData} from '../../../../framework/formats/location_data_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {ObjectDetector} from './object_detector';
import {ObjectDetectorOptions} from './object_detector_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/**
 * Test double for ObjectDetector. It intercepts the graph-runner plumbing so
 * tests can inspect the generated CalculatorGraphConfig and inject fake
 * detection results through the captured proto-vector listener.
 *
 * Fix: removed the unused `lastSampleRate` field — a leftover from the
 * audio-task fakes that is never referenced in this file.
 */
class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  // Deserialized copy of the last graph passed to setGraph().
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  // Listener captured from attachProtoVectorListener for 'detections'; tests
  // call it to feed serialized Detection protos into the detector.
  protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    // Capture the result listener so tests can push fake detections.
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('detections');
              this.protoListener = listener;
            });
    // Record the binary graph so tests can verify its configuration.
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
  }
}
describe('ObjectDetector', () => {
  let objectDetector: ObjectDetectorFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    objectDetector = new ObjectDetectorFake();
    await objectDetector.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(objectDetector);
    verifyListenersRegistered(objectDetector);
  });

  it('reloads graph when settings are changed', async () => {
    // Each setOptions() call must rebuild the graph and re-register the
    // output listeners.
    await objectDetector.setOptions({maxResults: 1});
    verifyGraph(objectDetector, ['maxResults', 1]);
    verifyListenersRegistered(objectDetector);

    await objectDetector.setOptions({maxResults: 5});
    verifyGraph(objectDetector, ['maxResults', 5]);
    verifyListenersRegistered(objectDetector);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    // The model buffer is expected to be embedded base64-encoded in the
    // graph's base options.
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await objectDetector.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        objectDetector,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */
        [
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('merges options', async () => {
    // Successive setOptions() calls merge: the second call must not drop the
    // 'maxResults' option set by the first.
    await objectDetector.setOptions({maxResults: 1});
    await objectDetector.setOptions({displayNamesLocale: 'en'});
    verifyGraph(objectDetector, ['maxResults', 1]);
    verifyGraph(objectDetector, ['displayNamesLocale', 'en']);
  });

  describe('setOptions() ', () => {
    // Table-driven check that each public option maps onto the right proto
    // field and resets to its documented default when cleared.
    interface TestCase {
      optionName: keyof ObjectDetectorOptions;
      protoName: string;
      customValue: unknown;
      defaultValue: unknown;
    }

    const testCases: TestCase[] = [
      {
        optionName: 'maxResults',
        protoName: 'maxResults',
        customValue: 5,
        defaultValue: -1
      },
      {
        optionName: 'displayNamesLocale',
        protoName: 'displayNamesLocale',
        customValue: 'en',
        defaultValue: 'en'
      },
      {
        optionName: 'scoreThreshold',
        protoName: 'scoreThreshold',
        customValue: 0.1,
        defaultValue: undefined
      },
      {
        optionName: 'categoryAllowlist',
        protoName: 'categoryAllowlistList',
        customValue: ['foo'],
        defaultValue: []
      },
      {
        optionName: 'categoryDenylist',
        protoName: 'categoryDenylistList',
        customValue: ['bar'],
        defaultValue: []
      },
    ];

    for (const testCase of testCases) {
      it(`can set ${testCase.optionName}`, async () => {
        await objectDetector.setOptions(
            {[testCase.optionName]: testCase.customValue});
        verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]);
      });

      it(`can clear ${testCase.optionName}`, async () => {
        await objectDetector.setOptions(
            {[testCase.optionName]: testCase.customValue});
        verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]);
        // Passing undefined must restore the option's default value.
        await objectDetector.setOptions({[testCase.optionName]: undefined});
        verifyGraph(
            objectDetector, [testCase.protoName, testCase.defaultValue]);
      });
    }
  });

  it('transforms results', async () => {
    const detectionProtos: Uint8Array[] = [];

    // Add a detection with all optional properties
    let detection = new DetectionProto();
    detection.addScore(0.1);
    detection.addLabelId(1);
    detection.addLabel('foo');
    detection.addDisplayName('bar');
    let locationData = new LocationData();
    let boundingBox = new LocationData.BoundingBox();
    boundingBox.setXmin(1);
    boundingBox.setYmin(2);
    boundingBox.setWidth(3);
    boundingBox.setHeight(4);
    locationData.setBoundingBox(boundingBox);
    detection.setLocationData(locationData);
    detectionProtos.push(detection.serializeBinary());

    // Add a detection without optional properties
    detection = new DetectionProto();
    detection.addScore(0.2);
    locationData = new LocationData();
    boundingBox = new LocationData.BoundingBox();
    locationData.setBoundingBox(boundingBox);
    detection.setLocationData(locationData);
    detectionProtos.push(detection.serializeBinary());

    // Pass the test data to our listener
    objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(objectDetector);
      objectDetector.protoListener!(detectionProtos);
    });

    // Invoke the object detector
    const detections = objectDetector.detect({} as HTMLImageElement);
    expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(detections.length).toEqual(2);
    // First detection keeps all optional fields.
    expect(detections[0]).toEqual({
      categories: [{
        score: 0.1,
        index: 1,
        categoryName: 'foo',
        displayName: 'bar',
      }],
      boundingBox: {originX: 1, originY: 2, width: 3, height: 4}
    });
    // Second detection falls back to defaults for missing optional fields.
    expect(detections[1]).toEqual({
      categories: [{
        score: 0.2,
        index: -1,
        categoryName: '',
        displayName: '',
      }],
      boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
    });
  });
});