Open Source the remaining MediaPipe Tasks tests for Web
PiperOrigin-RevId: 493769657
commit 9ae2e43b70 (parent 24c8fa97e9)
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.audio.audio_classifier.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.audio.audio_embedder.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.text.text_classifier.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.text.text_embedder.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.gesture_recognizer.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.gesture_recognizer.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
 option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto";
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.gesture_recognizer.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto";
 import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto";
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.gesture_recognizer.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
 import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.hand_detector.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
 option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto";
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.hand_landmarker.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto";
 import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto";
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.hand_landmarker.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
 option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto";
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.image_classifier.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.image_embedder.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.image_segmenter.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto";
 
@@ -18,6 +18,7 @@ syntax = "proto2";
 package mediapipe.tasks.vision.object_detector.proto;
 
 import "mediapipe/framework/calculator.proto";
+import "mediapipe/framework/calculator_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";
 
 option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto";
@@ -2,6 +2,7 @@
 #
 # This task takes audio data and outputs the classification result.
 
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
@@ -44,3 +45,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:classifier_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "audio_classifier_test_lib",
+    testonly = True,
+    srcs = [
+        "audio_classifier_test.ts",
+    ],
+    deps = [
+        ":audio_classifier",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "audio_classifier_test",
+    deps = [":audio_classifier_test_lib"],
+)
@@ -0,0 +1,208 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';

import {AudioClassifier} from './audio_classifier';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

class AudioClassifierFake extends AudioClassifier implements
    MediapipeTasksFake {
  lastSampleRate: number|undefined;
  calculatorName =
      'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  graph: CalculatorGraphConfig|undefined;

  private protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined;
  private resultProtoVector: ClassificationResult[] = [];

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);

    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('timestamped_classifications');
              this.protoVectorListener = listener;
            });
    spyOn(this.graphRunner, 'addDoubleToStream')
        .and.callFake((sampleRate, streamName, timestamp) => {
          if (streamName === 'sample_rate') {
            this.lastSampleRate = sampleRate;
          }
        });
    spyOn(this.graphRunner, 'addAudioToStreamWithShape')
        .and.callFake(
            (audioData, numChannels, numSamples, streamName, timestamp) => {
              expect(numChannels).toBe(1);
            });
    spyOn(this.graphRunner, 'finishProcessing').and.callFake(() => {
      if (!this.protoVectorListener) return;
      this.protoVectorListener(this.resultProtoVector.map(
          classificationResult => classificationResult.serializeBinary()));
    });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
  }

  /** Sets the Protobuf that will be sent to the API. */
  setResults(results: ClassificationResult[]): void {
    this.resultProtoVector = results;
  }
}

describe('AudioClassifier', () => {
  let audioClassifier: AudioClassifierFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    audioClassifier = new AudioClassifierFake();
    await audioClassifier.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(audioClassifier);
    verifyListenersRegistered(audioClassifier);
  });

  it('reloads graph when settings are changed', async () => {
    await audioClassifier.setOptions({maxResults: 1});
    verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 1]);
    verifyListenersRegistered(audioClassifier);

    await audioClassifier.setOptions({maxResults: 5});
    verifyGraph(audioClassifier, [['classifierOptions', 'maxResults'], 5]);
    verifyListenersRegistered(audioClassifier);
  });

  it('merges options', async () => {
    await audioClassifier.setOptions({maxResults: 1});
    await audioClassifier.setOptions({displayNamesLocale: 'en'});
    verifyGraph(audioClassifier, [
      'classifierOptions', {
        maxResults: 1,
        displayNamesLocale: 'en',
        scoreThreshold: undefined,
        categoryAllowlistList: [],
        categoryDenylistList: []
      }
    ]);
  });

  it('uses a sample rate of 48000 by default', async () => {
    audioClassifier.classify(new Float32Array([]));
    expect(audioClassifier.lastSampleRate).toEqual(48000);
  });

  it('uses default sample rate if none provided', async () => {
    audioClassifier.setDefaultSampleRate(16000);
    audioClassifier.classify(new Float32Array([]));
    expect(audioClassifier.lastSampleRate).toEqual(16000);
  });

  it('uses custom sample rate if provided', async () => {
    audioClassifier.setDefaultSampleRate(16000);
    audioClassifier.classify(new Float32Array([]), 44100);
    expect(audioClassifier.lastSampleRate).toEqual(44100);
  });

  it('transforms results', async () => {
    const resultProtoVector: ClassificationResult[] = [];

    let classificationResult = new ClassificationResult();
    classificationResult.setTimestampMs(0);
    let classifcations = new Classifications();
    classifcations.setHeadIndex(1);
    classifcations.setHeadName('headName');
    let classificationList = new ClassificationList();
    let clasification = new Classification();
    clasification.setIndex(1);
    clasification.setScore(0.2);
    clasification.setDisplayName('displayName');
    clasification.setLabel('categoryName');
    classificationList.addClassification(clasification);
    classifcations.setClassificationList(classificationList);
    classificationResult.addClassifications(classifcations);
    resultProtoVector.push(classificationResult);

    classificationResult = new ClassificationResult();
    classificationResult.setTimestampMs(1);
    classifcations = new Classifications();
    classificationList = new ClassificationList();
    clasification = new Classification();
    clasification.setIndex(2);
    clasification.setScore(0.3);
    classificationList.addClassification(clasification);
    classifcations.setClassificationList(classificationList);
    classificationResult.addClassifications(classifcations);
    resultProtoVector.push(classificationResult);

    // Invoke the audio classifier
    audioClassifier.setResults(resultProtoVector);
    const results = audioClassifier.classify(new Float32Array([]));
    expect(results.length).toEqual(2);
    expect(results[0]).toEqual({
      classifications: [{
        categories: [{
          index: 1,
          score: 0.2,
          displayName: 'displayName',
          categoryName: 'categoryName'
        }],
        headIndex: 1,
        headName: 'headName'
      }],
      timestampMs: 0
    });
    expect(results[1]).toEqual({
      classifications: [{
        categories: [{index: 2, score: 0.3, displayName: '', categoryName: ''}],
        headIndex: 0,
        headName: ''
      }],
      timestampMs: 1
    });
  });

  it('clears results between invocations', async () => {
    const classificationResult = new ClassificationResult();
    const classifcations = new Classifications();
    const classificationList = new ClassificationList();
    const clasification = new Classification();
    classificationList.addClassification(clasification);
    classifcations.setClassificationList(classificationList);
    classificationResult.addClassifications(classifcations);

    audioClassifier.setResults([classificationResult]);

    // Invoke the audio classifier twice
    const classifications1 = audioClassifier.classify(new Float32Array([]));
    const classifications2 = audioClassifier.classify(new Float32Array([]));

    // Verify that classifications2 is not a concatenation of all previously
    // returned classifications.
    expect(classifications1).toEqual(classifications2);
  });
});
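The fakes above never load a real WASM module: every graph-runner method the task would call is replaced with a Jasmine spy that records its arguments. A minimal, self-contained sketch of that pattern, separate from the commit itself; FakeGraphRunner and captured are illustrative names, and only spyOn(...).and.callFake(...) is real Jasmine API:

import 'jasmine';

// Stand-in for the graph runner; only the method shape matters for the sketch.
class FakeGraphRunner {
  setGraph(binaryGraph: Uint8Array): void {}
}

describe('spy-based graph capture', () => {
  it('records the binary graph instead of running it', () => {
    const runner = new FakeGraphRunner();
    let captured: Uint8Array|undefined;
    // callFake swaps the real implementation for one that just stores the arg.
    spyOn(runner, 'setGraph').and.callFake(graph => {
      captured = graph;
    });

    runner.setGraph(new Uint8Array([1, 2, 3]));
    expect(captured).toEqual(new Uint8Array([1, 2, 3]));
  });
});

AudioClassifierFake applies the same idea to attachProtoVectorListener, addDoubleToStream, and finishProcessing, which is what lets the tests inject ClassificationResult protos and inspect the sample rate without running the graph.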
@@ -3,6 +3,7 @@
 # This task takes audio input and performs embedding.
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -43,3 +44,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:embedder_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "audio_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "audio_embedder_test.ts",
+    ],
+    deps = [
+        ":audio_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "audio_embedder_test",
+    deps = [":audio_embedder_test_lib"],
+)
mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts (new file)
@@ -0,0 +1,185 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Embedding, EmbeddingResult as EmbeddingResultProto, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';

import {AudioEmbedder, AudioEmbedderResult} from './audio_embedder';


// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

class AudioEmbedderFake extends AudioEmbedder implements MediapipeTasksFake {
  lastSampleRate: number|undefined;
  calculatorName = 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph';
  graph: CalculatorGraphConfig|undefined;
  attachListenerSpies: jasmine.Spy[] = [];
  fakeWasmModule: SpyWasmModule;

  protoListener: ((binaryProto: Uint8Array) => void)|undefined;
  protoVectorListener: ((binaryProtos: Uint8Array[]) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;

    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('embeddings_out');
              this.protoListener = listener;
            });
    this.attachListenerSpies[1] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('timestamped_embeddings_out');
              this.protoVectorListener = listener;
            });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
    spyOn(this.graphRunner, 'addDoubleToStream').and.callFake(sampleRate => {
      this.lastSampleRate = sampleRate;
    });
    spyOn(this.graphRunner, 'addAudioToStreamWithShape');
  }
}

describe('AudioEmbedder', () => {
  let audioEmbedder: AudioEmbedderFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    audioEmbedder = new AudioEmbedderFake();
    await audioEmbedder.setOptions({});  // Initialize graph
  });

  it('initializes graph', () => {
    verifyGraph(audioEmbedder);
    verifyListenersRegistered(audioEmbedder);
  });

  it('reloads graph when settings are changed', async () => {
    await audioEmbedder.setOptions({quantize: true});
    verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], true]);
    verifyListenersRegistered(audioEmbedder);

    await audioEmbedder.setOptions({quantize: undefined});
    verifyGraph(audioEmbedder, [['embedderOptions', 'quantize'], undefined]);
    verifyListenersRegistered(audioEmbedder);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await audioEmbedder.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        audioEmbedder,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */[
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('combines options', async () => {
    await audioEmbedder.setOptions({quantize: true});
    await audioEmbedder.setOptions({l2Normalize: true});
    verifyGraph(
        audioEmbedder,
        ['embedderOptions', {'quantize': true, 'l2Normalize': true}]);
  });

  it('uses a sample rate of 48000 by default', async () => {
    audioEmbedder.embed(new Float32Array([]));
    expect(audioEmbedder.lastSampleRate).toEqual(48000);
  });

  it('uses default sample rate if none provided', async () => {
    audioEmbedder.setDefaultSampleRate(16000);
    audioEmbedder.embed(new Float32Array([]));
    expect(audioEmbedder.lastSampleRate).toEqual(16000);
  });

  it('uses custom sample rate if provided', async () => {
    audioEmbedder.setDefaultSampleRate(16000);
    audioEmbedder.embed(new Float32Array([]), 44100);
    expect(audioEmbedder.lastSampleRate).toEqual(44100);
  });

  describe('transforms results', () => {
    const embedding = new Embedding();
    embedding.setHeadIndex(1);
    embedding.setHeadName('headName');

    const floatEmbedding = new FloatEmbedding();
    floatEmbedding.setValuesList([0.1, 0.9]);

    embedding.setFloatEmbedding(floatEmbedding);
    const resultProto = new EmbeddingResultProto();
    resultProto.addEmbeddings(embedding);

    function validateEmbeddingResult(
        expectedEmbeddingResult: AudioEmbedderResult[]) {
      expect(expectedEmbeddingResult.length).toEqual(1);

      const [embeddingResult] = expectedEmbeddingResult;
      expect(embeddingResult.embeddings.length).toEqual(1);
      expect(embeddingResult.embeddings[0])
          .toEqual(
              {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'});
    }

    it('from embeddings stream', async () => {
      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
        verifyListenersRegistered(audioEmbedder);
        // Pass the test data to our listener
        audioEmbedder.protoListener!(resultProto.serializeBinary());
      });

      // Invoke the audio embedder
      const embeddingResults = audioEmbedder.embed(new Float32Array([]));
      validateEmbeddingResult(embeddingResults);
    });

    it('from timestamped embeddings stream', async () => {
      audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
        verifyListenersRegistered(audioEmbedder);
        // Pass the test data to our listener
        audioEmbedder.protoVectorListener!([resultProto.serializeBinary()]);
      });

      // Invoke the audio embedder
      const embeddingResults = audioEmbedder.embed(new Float32Array([]), 42);
      validateEmbeddingResult(embeddingResults);
    });
  });
});
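The assertions above depend on the task flattening the jspb Embedding proto into a plain object. A hypothetical converter that mirrors that mapping, assuming the standard jspb getters generated alongside the setters used in the test; toJsEmbedding itself is not a MediaPipe API:

import {Embedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';

// Mirrors the object shape asserted above:
// {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}.
function toJsEmbedding(proto: Embedding) {
  return {
    floatEmbedding: proto.getFloatEmbedding()?.getValuesList(),
    headIndex: proto.getHeadIndex(),
    headName: proto.getHeadName(),
  };
}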
@@ -4,6 +4,7 @@
 # BERT-based text classification).
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -45,3 +46,24 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:classifier_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_classifier_test_lib",
+    testonly = True,
+    srcs = [
+        "text_classifier_test.ts",
+    ],
+    deps = [
+        ":text_classifier",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_classifier_test",
+    deps = [":text_classifier_test_lib"],
+)
mediapipe/tasks/web/text/text_classifier/text_classifier_test.ts (new file)
@@ -0,0 +1,152 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';

import {TextClassifier} from './text_classifier';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

class TextClassifierFake extends TextClassifier implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  protoListener: ((binaryProto: Uint8Array) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('classifications_out');
              this.protoListener = listener;
            });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
  }
}

describe('TextClassifier', () => {
  let textClassifier: TextClassifierFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    textClassifier = new TextClassifierFake();
    await textClassifier.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(textClassifier);
    verifyListenersRegistered(textClassifier);
  });

  it('reloads graph when settings are changed', async () => {
    await textClassifier.setOptions({maxResults: 1});
    verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 1]);
    verifyListenersRegistered(textClassifier);

    await textClassifier.setOptions({maxResults: 5});
    verifyGraph(textClassifier, [['classifierOptions', 'maxResults'], 5]);
    verifyListenersRegistered(textClassifier);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await textClassifier.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        textClassifier,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */
        [
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('merges options', async () => {
    await textClassifier.setOptions({maxResults: 1});
    await textClassifier.setOptions({displayNamesLocale: 'en'});
    verifyGraph(textClassifier, [
      'classifierOptions', {
        maxResults: 1,
        displayNamesLocale: 'en',
        scoreThreshold: undefined,
        categoryAllowlistList: [],
        categoryDenylistList: []
      }
    ]);
  });

  it('transforms results', async () => {
    const classificationResult = new ClassificationResult();
    const classifcations = new Classifications();
    classifcations.setHeadIndex(1);
    classifcations.setHeadName('headName');
    const classificationList = new ClassificationList();
    const clasification = new Classification();
    clasification.setIndex(1);
    clasification.setScore(0.2);
    clasification.setDisplayName('displayName');
    clasification.setLabel('categoryName');
    classificationList.addClassification(clasification);
    classifcations.setClassificationList(classificationList);
    classificationResult.addClassifications(classifcations);

    // Pass the test data to our listener
    textClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(textClassifier);
      textClassifier.protoListener!(classificationResult.serializeBinary());
    });

    // Invoke the text classifier
    const result = textClassifier.classify('foo');

    expect(textClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(result).toEqual({
      classifications: [{
        categories: [{
          index: 1,
          score: 0.2,
          displayName: 'displayName',
          categoryName: 'categoryName'
        }],
        headIndex: 1,
        headName: 'headName'
      }]
    });
  });
});
@@ -4,6 +4,7 @@
 #
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -44,3 +45,23 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/core:embedder_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "text_embedder_test_lib",
+    testonly = True,
+    srcs = [
+        "text_embedder_test.ts",
+    ],
+    deps = [
+        ":text_embedder",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "text_embedder_test",
+    deps = [":text_embedder_test_lib"],
+)
mediapipe/tasks/web/text/text_embedder/text_embedder_test.ts (new file)
@@ -0,0 +1,165 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Embedding, EmbeddingResult, FloatEmbedding, QuantizedEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';

import {TextEmbedder} from './text_embedder';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

class TextEmbedderFake extends TextEmbedder implements MediapipeTasksFake {
  calculatorName = 'mediapipe.tasks.text.text_embedder.TextEmbedderGraph';
  graph: CalculatorGraphConfig|undefined;
  attachListenerSpies: jasmine.Spy[] = [];
  fakeWasmModule: SpyWasmModule;
  protoListener: ((binaryProtos: Uint8Array) => void)|undefined;

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;

    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoListener')
            .and.callFake((stream, listener) => {
              expect(stream).toEqual('embeddings_out');
              this.protoListener = listener;
            });
    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
    });
  }
}

describe('TextEmbedder', () => {
  let textEmbedder: TextEmbedderFake;

  beforeEach(async () => {
    addJasmineCustomFloatEqualityTester();
    textEmbedder = new TextEmbedderFake();
    await textEmbedder.setOptions({});  // Initialize graph
  });

  it('initializes graph', async () => {
    verifyGraph(textEmbedder);
    verifyListenersRegistered(textEmbedder);
  });

  it('reloads graph when settings are changed', async () => {
    await textEmbedder.setOptions({quantize: true});
    verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], true]);
    verifyListenersRegistered(textEmbedder);

    await textEmbedder.setOptions({quantize: undefined});
    verifyGraph(textEmbedder, [['embedderOptions', 'quantize'], undefined]);
    verifyListenersRegistered(textEmbedder);
  });

  it('can use custom models', async () => {
    const newModel = new Uint8Array([0, 1, 2, 3, 4]);
    const newModelBase64 = Buffer.from(newModel).toString('base64');
    await textEmbedder.setOptions({
      baseOptions: {
        modelAssetBuffer: newModel,
      }
    });

    verifyGraph(
        textEmbedder,
        /* expectedCalculatorOptions= */ undefined,
        /* expectedBaseOptions= */[
          'modelAsset', {
            fileContent: newModelBase64,
            fileName: undefined,
            fileDescriptorMeta: undefined,
            filePointerMeta: undefined
          }
        ]);
  });

  it('combines options', async () => {
    await textEmbedder.setOptions({quantize: true});
    await textEmbedder.setOptions({l2Normalize: true});
    verifyGraph(
        textEmbedder,
        ['embedderOptions', {'quantize': true, 'l2Normalize': true}]);
  });

  it('transforms results', async () => {
    const embedding = new Embedding();
    embedding.setHeadIndex(1);
    embedding.setHeadName('headName');

    const floatEmbedding = new FloatEmbedding();
    floatEmbedding.setValuesList([0.1, 0.9]);

    embedding.setFloatEmbedding(floatEmbedding);
    const resultProto = new EmbeddingResult();
    resultProto.addEmbeddings(embedding);

    // Pass the test data to our listener
    textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(textEmbedder);
      textEmbedder.protoListener!(resultProto.serializeBinary());
    });

    // Invoke the text embedder
    const embeddingResult = textEmbedder.embed('foo');

    expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(embeddingResult.embeddings.length).toEqual(1);
    expect(embeddingResult.embeddings[0])
        .toEqual(
            {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'});
  });

  it('transforms custom quantized values', async () => {
    const embedding = new Embedding();
    embedding.setHeadIndex(1);
    embedding.setHeadName('headName');

    const quantizedEmbedding = new QuantizedEmbedding();
    const quantizedValues = new Uint8Array([1, 2, 3]);
    quantizedEmbedding.setValues(quantizedValues);

    embedding.setQuantizedEmbedding(quantizedEmbedding);
    const resultProto = new EmbeddingResult();
    resultProto.addEmbeddings(embedding);

    // Pass the test data to our listener
    textEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
      verifyListenersRegistered(textEmbedder);
      textEmbedder.protoListener!(resultProto.serializeBinary());
    });

    // Invoke the text embedder
    const embeddingsResult = textEmbedder.embed('foo');

    expect(textEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
    expect(embeddingsResult.embeddings.length).toEqual(1);
    expect(embeddingsResult.embeddings[0]).toEqual({
      quantizedEmbedding: new Uint8Array([1, 2, 3]),
      headIndex: 1,
      headName: 'headName'
    });
  });
});
@@ -1,5 +1,6 @@
 # This package contains options shared by all MediaPipe Vision Tasks for Web.
 
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
@@ -22,3 +23,20 @@ mediapipe_ts_library(
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
 )
+
+mediapipe_ts_library(
+    name = "vision_task_runner_test_lib",
+    testonly = True,
+    srcs = ["vision_task_runner.test.ts"],
+    deps = [
+        ":vision_task_runner",
+        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+        "//mediapipe/web/graph_runner:graph_runner_ts",
+    ],
+)
+
+jasmine_node_test(
+    name = "vision_task_runner_test",
+    deps = [":vision_task_runner_test_lib"],
+)
mediapipe/tasks/web/vision/core/vision_task_runner.test.ts (new file)
@@ -0,0 +1,99 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import 'jasmine';

import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {createSpyWasmModule} from '../../../../tasks/web/core/task_runner_test_utils';
import {ImageSource} from '../../../../web/graph_runner/graph_runner';

import {VisionTaskRunner} from './vision_task_runner';

class VisionTaskRunnerFake extends VisionTaskRunner<void> {
  baseOptions = new BaseOptionsProto();

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
  }

  protected override process(): void {}

  override processImageData(image: ImageSource): void {
    super.processImageData(image);
  }

  override processVideoData(imageFrame: ImageSource, timestamp: number): void {
    super.processVideoData(imageFrame, timestamp);
  }
}

describe('VisionTaskRunner', () => {
  const streamMode = {
    modelAsset: undefined,
    useStreamMode: true,
    acceleration: undefined,
  };

  const imageMode = {
    modelAsset: undefined,
    useStreamMode: false,
    acceleration: undefined,
  };

  let visionTaskRunner: VisionTaskRunnerFake;

  beforeEach(() => {
    visionTaskRunner = new VisionTaskRunnerFake();
  });

  it('can enable image mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'image'});
    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
  });

  it('can enable video mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'video'});
    expect(visionTaskRunner.baseOptions.toObject()).toEqual(streamMode);
  });

  it('can clear running mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'video'});

    // Clear running mode
    await visionTaskRunner.setOptions({runningMode: undefined});
    expect(visionTaskRunner.baseOptions.toObject()).toEqual(imageMode);
  });

  it('cannot process images with video mode', async () => {
    await visionTaskRunner.setOptions({runningMode: 'video'});
    expect(() => {
      visionTaskRunner.processImageData({} as HTMLImageElement);
    }).toThrowError(/Task is not initialized with image mode./);
  });

  it('cannot process video with image mode', async () => {
    // Use default for `useStreamMode`
    expect(() => {
      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
    }).toThrowError(/Task is not initialized with video mode./);

    // Explicitly set to image mode
    await visionTaskRunner.setOptions({runningMode: 'image'});
    expect(() => {
      visionTaskRunner.processVideoData({} as HTMLImageElement, 42);
    }).toThrowError(/Task is not initialized with video mode./);
  });
});
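The last two specs exercise a guard that rejects frames whose kind does not match the configured running mode. A minimal sketch of such a guard, assuming runningMode: 'video' maps onto the useStreamMode flag visible in the streamMode/imageMode constants above; ModeGuard is illustrative and not the real VisionTaskRunner:

class ModeGuard {
  // useStreamMode = false corresponds to runningMode: 'image' (the default).
  constructor(private useStreamMode = false) {}

  processImage(): void {
    if (this.useStreamMode) {
      throw new Error('Task is not initialized with image mode.');
    }
    // ...forward the frame to the graph...
  }

  processVideo(): void {
    if (!this.useStreamMode) {
      throw new Error('Task is not initialized with video mode.');
    }
    // ...forward the frame and its timestamp to the graph...
  }
}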
@@ -4,6 +4,7 @@
 # the detection results for one or more gesture categories, using Gesture Recognizer.
 
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
+load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
@@ -52,3 +53,27 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/vision/core:vision_task_options",
     ],
 )
+
+mediapipe_ts_library(
+    name = "gesture_recognizer_test_lib",
+    testonly = True,
+    srcs = [
+        "gesture_recognizer_test.ts",
+    ],
+    deps = [
+        ":gesture_recognizer",
+        ":gesture_recognizer_types",
+        "//mediapipe/framework:calculator_jspb_proto",
+        "//mediapipe/framework/formats:classification_jspb_proto",
+        "//mediapipe/framework/formats:landmark_jspb_proto",
+        "//mediapipe/tasks/web/core",
+        "//mediapipe/tasks/web/core:task_runner",
+        "//mediapipe/tasks/web/core:task_runner_test_utils",
+    ],
+)
+
+jasmine_node_test(
+    name = "gesture_recognizer_test",
+    tags = ["nomsan"],
+    deps = [":gesture_recognizer_test_lib"],
+)
@@ -0,0 +1,307 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import 'jasmine';

import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';

import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

type ProtoListener = ((binaryProtos: Uint8Array[]) => void);

function createHandednesses(): Uint8Array[] {
  const handsProto = new ClassificationList();
  const classification = new Classification();
  classification.setScore(0.1);
  classification.setIndex(1);
  classification.setLabel('handedness_label');
  classification.setDisplayName('handedness_display_name');
  handsProto.addClassification(classification);
  return [handsProto.serializeBinary()];
}

function createGestures(): Uint8Array[] {
  const gesturesProto = new ClassificationList();
  const classification = new Classification();
  classification.setScore(0.2);
  classification.setIndex(2);
  classification.setLabel('gesture_label');
  classification.setDisplayName('gesture_display_name');
  gesturesProto.addClassification(classification);
  return [gesturesProto.serializeBinary()];
}

function createLandmarks(): Uint8Array[] {
  const handLandmarksProto = new NormalizedLandmarkList();
  const landmark = new NormalizedLandmark();
  landmark.setX(0.3);
  landmark.setY(0.4);
  landmark.setZ(0.5);
  handLandmarksProto.addLandmark(landmark);
  return [handLandmarksProto.serializeBinary()];
}

function createWorldLandmarks(): Uint8Array[] {
  const handLandmarksProto = new LandmarkList();
  const landmark = new Landmark();
  landmark.setX(21);
  landmark.setY(22);
  landmark.setZ(23);
  handLandmarksProto.addLandmark(landmark);
  return [handLandmarksProto.serializeBinary()];
}

class GestureRecognizerFake extends GestureRecognizer implements
    MediapipeTasksFake {
  calculatorName =
      'mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph';
  attachListenerSpies: jasmine.Spy[] = [];
  graph: CalculatorGraphConfig|undefined;
  fakeWasmModule: SpyWasmModule;
  listeners = new Map<string, ProtoListener>();

  constructor() {
    super(createSpyWasmModule(), /* glCanvas= */ null);
    this.fakeWasmModule =
        this.graphRunner.wasmModule as unknown as SpyWasmModule;
    this.attachListenerSpies[0] =
        spyOn(this.graphRunner, 'attachProtoVectorListener')
            .and.callFake((stream, listener) => {
              expect(stream).toMatch(
                  /(hand_landmarks|world_hand_landmarks|handedness|hand_gestures)/);
              this.listeners.set(stream, listener);
            });

    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||||
|
spyOn(this.graphRunner, 'addProtoToStream');
|
||||||
|
}
|
||||||
|
|
||||||
|
getGraphRunner(): GraphRunnerImageLib {
|
||||||
|
return this.graphRunner;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('GestureRecognizer', () => {
|
||||||
|
let gestureRecognizer: GestureRecognizerFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
gestureRecognizer = new GestureRecognizerFake();
|
||||||
|
await gestureRecognizer.setOptions({}); // Initialize graph
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(gestureRecognizer);
|
||||||
|
verifyListenersRegistered(gestureRecognizer);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
await gestureRecognizer.setOptions({numHands: 1});
|
||||||
|
verifyGraph(gestureRecognizer, [
|
||||||
|
['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
|
||||||
|
]);
|
||||||
|
verifyListenersRegistered(gestureRecognizer);
|
||||||
|
|
||||||
|
await gestureRecognizer.setOptions({numHands: 5});
|
||||||
|
verifyGraph(gestureRecognizer, [
|
||||||
|
['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 5
|
||||||
|
]);
|
||||||
|
verifyListenersRegistered(gestureRecognizer);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges options', async () => {
|
||||||
|
await gestureRecognizer.setOptions({numHands: 1});
|
||||||
|
await gestureRecognizer.setOptions({minHandDetectionConfidence: 0.5});
|
||||||
|
verifyGraph(gestureRecognizer, [
|
||||||
|
['handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'], 1
|
||||||
|
]);
|
||||||
|
verifyGraph(gestureRecognizer, [
|
||||||
|
[
|
||||||
|
'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
|
||||||
|
'minDetectionConfidence'
|
||||||
|
],
|
||||||
|
0.5
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setOptions() ', () => {
|
||||||
|
interface TestCase {
|
||||||
|
optionPath: [keyof GestureRecognizerOptions, ...string[]];
|
||||||
|
fieldPath: string[];
|
||||||
|
customValue: unknown;
|
||||||
|
defaultValue: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
const testCases: TestCase[] = [
|
||||||
|
{
|
||||||
|
optionPath: ['numHands'],
|
||||||
|
fieldPath: [
|
||||||
|
'handLandmarkerGraphOptions', 'handDetectorGraphOptions', 'numHands'
|
||||||
|
],
|
||||||
|
customValue: 5,
|
||||||
|
defaultValue: 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['minHandDetectionConfidence'],
|
||||||
|
fieldPath: [
|
||||||
|
'handLandmarkerGraphOptions', 'handDetectorGraphOptions',
|
||||||
|
'minDetectionConfidence'
|
||||||
|
],
|
||||||
|
customValue: 0.1,
|
||||||
|
defaultValue: 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['minHandPresenceConfidence'],
|
||||||
|
fieldPath: [
|
||||||
|
'handLandmarkerGraphOptions', 'handLandmarksDetectorGraphOptions',
|
||||||
|
'minDetectionConfidence'
|
||||||
|
],
|
||||||
|
customValue: 0.2,
|
||||||
|
defaultValue: 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['minTrackingConfidence'],
|
||||||
|
fieldPath: ['handLandmarkerGraphOptions', 'minTrackingConfidence'],
|
||||||
|
customValue: 0.3,
|
||||||
|
defaultValue: 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['cannedGesturesClassifierOptions', 'scoreThreshold'],
|
||||||
|
fieldPath: [
|
||||||
|
'handGestureRecognizerGraphOptions',
|
||||||
|
'cannedGestureClassifierGraphOptions', 'classifierOptions',
|
||||||
|
'scoreThreshold'
|
||||||
|
],
|
||||||
|
customValue: 0.4,
|
||||||
|
defaultValue: undefined
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['customGesturesClassifierOptions', 'scoreThreshold'],
|
||||||
|
fieldPath: [
|
||||||
|
'handGestureRecognizerGraphOptions',
|
||||||
|
'customGestureClassifierGraphOptions', 'classifierOptions',
|
||||||
|
'scoreThreshold'
|
||||||
|
],
|
||||||
|
customValue: 0.5,
|
||||||
|
defaultValue: undefined,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
/** Creates an options object that can be passed to setOptions() */
|
||||||
|
function createOptions(
|
||||||
|
path: string[], value: unknown): GestureRecognizerOptions {
|
||||||
|
const options: Record<string, unknown> = {};
|
||||||
|
let currentLevel = options;
|
||||||
|
for (const element of path.slice(0, -1)) {
|
||||||
|
currentLevel[element] = {};
|
||||||
|
currentLevel = currentLevel[element] as Record<string, unknown>;
|
||||||
|
}
|
||||||
|
currentLevel[path[path.length - 1]] = value;
|
||||||
|
return options;
|
||||||
|
}
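
  // For illustration: createOptions(['cannedGesturesClassifierOptions', 'scoreThreshold'], 0.4)
  // builds {cannedGesturesClassifierOptions: {scoreThreshold: 0.4}}, i.e. the nested
  // options object that the parameterized cases below pass to setOptions().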
|
||||||
|
|
||||||
|
for (const testCase of testCases) {
|
||||||
|
it(`uses default value for ${testCase.optionPath[0]}`, async () => {
|
||||||
|
verifyGraph(
|
||||||
|
gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it(`can set ${testCase.optionPath[0]}`, async () => {
|
||||||
|
await gestureRecognizer.setOptions(
|
||||||
|
createOptions(testCase.optionPath, testCase.customValue));
|
||||||
|
verifyGraph(
|
||||||
|
gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it(`can clear ${testCase.optionPath[0]}`, async () => {
|
||||||
|
await gestureRecognizer.setOptions(
|
||||||
|
createOptions(testCase.optionPath, testCase.customValue));
|
||||||
|
verifyGraph(
|
||||||
|
gestureRecognizer, [testCase.fieldPath, testCase.customValue]);
|
||||||
|
|
||||||
|
await gestureRecognizer.setOptions(
|
||||||
|
createOptions(testCase.optionPath, undefined));
|
||||||
|
verifyGraph(
|
||||||
|
gestureRecognizer, [testCase.fieldPath, testCase.defaultValue]);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('transforms results', async () => {
|
||||||
|
// Pass the test data to our listener
|
||||||
|
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(gestureRecognizer);
|
||||||
|
gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
|
||||||
|
gestureRecognizer.listeners.get('world_hand_landmarks')!
|
||||||
|
(createWorldLandmarks());
|
||||||
|
gestureRecognizer.listeners.get('handedness')!(createHandednesses());
|
||||||
|
gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the gesture recognizer
|
||||||
|
const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
|
||||||
|
expect(gestureRecognizer.getGraphRunner().addProtoToStream)
|
||||||
|
.toHaveBeenCalledTimes(1);
|
||||||
|
expect(gestureRecognizer.getGraphRunner().addGpuBufferAsImageToStream)
|
||||||
|
.toHaveBeenCalledTimes(1);
|
||||||
|
expect(gestureRecognizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
|
||||||
|
expect(gestures).toEqual({
|
||||||
|
'gestures': [[{
|
||||||
|
'score': 0.2,
|
||||||
|
'index': 2,
|
||||||
|
'categoryName': 'gesture_label',
|
||||||
|
'displayName': 'gesture_display_name'
|
||||||
|
}]],
|
||||||
|
'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
|
||||||
|
'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
|
||||||
|
'handednesses': [[{
|
||||||
|
'score': 0.1,
|
||||||
|
'index': 1,
|
||||||
|
'categoryName': 'handedness_label',
|
||||||
|
'displayName': 'handedness_display_name'
|
||||||
|
}]]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
  it('clears results between invocations', async () => {
|
||||||
|
// Pass the test data to our listener
|
||||||
|
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
|
||||||
|
gestureRecognizer.listeners.get('world_hand_landmarks')!
|
||||||
|
(createWorldLandmarks());
|
||||||
|
gestureRecognizer.listeners.get('handedness')!(createHandednesses());
|
||||||
|
gestureRecognizer.listeners.get('hand_gestures')!(createGestures());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the gesture recognizer twice
|
||||||
|
const gestures1 = gestureRecognizer.recognize({} as HTMLImageElement);
|
||||||
|
const gestures2 = gestureRecognizer.recognize({} as HTMLImageElement);
|
||||||
|
|
||||||
|
// Verify that gestures2 is not a concatenation of all previously returned
|
||||||
|
// gestures.
|
||||||
|
expect(gestures2).toEqual(gestures1);
|
||||||
|
});
|
||||||
|
});
|
|
@@ -4,6 +4,7 @@
# the detection results for one or more hand categories, using Hand Landmarker.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -47,3 +48,27 @@ mediapipe_ts_declaration(
        "//mediapipe/tasks/web/vision/core:vision_task_options",
    ],
)

mediapipe_ts_library(
    name = "hand_landmarker_test_lib",
    testonly = True,
    srcs = [
        "hand_landmarker_test.ts",
    ],
    deps = [
        ":hand_landmarker",
        ":hand_landmarker_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:classification_jspb_proto",
        "//mediapipe/framework/formats:landmark_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

jasmine_node_test(
    name = "hand_landmarker_test",
    tags = ["nomsan"],
    deps = [":hand_landmarker_test_lib"],
)
@@ -0,0 +1,251 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
|
||||||
|
import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
|
||||||
|
import {GraphRunnerImageLib} from '../../../../tasks/web/core/task_runner';
|
||||||
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
|
||||||
|
import {HandLandmarker} from './hand_landmarker';
|
||||||
|
import {HandLandmarkerOptions} from './hand_landmarker_options';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
type ProtoListener = ((binaryProtos: Uint8Array[]) => void);
|
||||||
|
|
||||||
|
function createHandednesses(): Uint8Array[] {
|
||||||
|
const handsProto = new ClassificationList();
|
||||||
|
const classification = new Classification();
|
||||||
|
classification.setScore(0.1);
|
||||||
|
classification.setIndex(1);
|
||||||
|
classification.setLabel('handedness_label');
|
||||||
|
classification.setDisplayName('handedness_display_name');
|
||||||
|
handsProto.addClassification(classification);
|
||||||
|
return [handsProto.serializeBinary()];
|
||||||
|
}
|
||||||
|
|
||||||
|
function createLandmarks(): Uint8Array[] {
|
||||||
|
const handLandmarksProto = new NormalizedLandmarkList();
|
||||||
|
const landmark = new NormalizedLandmark();
|
||||||
|
landmark.setX(0.3);
|
||||||
|
landmark.setY(0.4);
|
||||||
|
landmark.setZ(0.5);
|
||||||
|
handLandmarksProto.addLandmark(landmark);
|
||||||
|
return [handLandmarksProto.serializeBinary()];
|
||||||
|
}
|
||||||
|
|
||||||
|
function createWorldLandmarks(): Uint8Array[] {
|
||||||
|
const handLandmarksProto = new LandmarkList();
|
||||||
|
const landmark = new Landmark();
|
||||||
|
landmark.setX(21);
|
||||||
|
landmark.setY(22);
|
||||||
|
landmark.setZ(23);
|
||||||
|
handLandmarksProto.addLandmark(landmark);
|
||||||
|
return [handLandmarksProto.serializeBinary()];
|
||||||
|
}
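
// The factories above produce the serialized protos that the fake listeners
// replay; e.g. the world landmark set here is expected to surface from
// detect() as {x: 21, y: 22, z: 23} in the assertions further below.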
|
||||||
|
|
||||||
|
class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake {
|
||||||
|
calculatorName = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph';
|
||||||
|
attachListenerSpies: jasmine.Spy[] = [];
|
||||||
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
fakeWasmModule: SpyWasmModule;
|
||||||
|
listeners = new Map<string, ProtoListener>();
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
this.fakeWasmModule =
|
||||||
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
|
this.attachListenerSpies[0] =
|
||||||
|
spyOn(this.graphRunner, 'attachProtoVectorListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toMatch(
|
||||||
|
/(hand_landmarks|world_hand_landmarks|handedness|hand_hands)/);
|
||||||
|
this.listeners.set(stream, listener);
|
||||||
|
});
|
||||||
|
|
||||||
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||||
|
spyOn(this.graphRunner, 'addProtoToStream');
|
||||||
|
}
|
||||||
|
|
||||||
|
getGraphRunner(): GraphRunnerImageLib {
|
||||||
|
return this.graphRunner;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('HandLandmarker', () => {
|
||||||
|
let handLandmarker: HandLandmarkerFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
handLandmarker = new HandLandmarkerFake();
|
||||||
|
await handLandmarker.setOptions({}); // Initialize graph
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(handLandmarker);
|
||||||
|
verifyListenersRegistered(handLandmarker);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
verifyListenersRegistered(handLandmarker);
|
||||||
|
|
||||||
|
await handLandmarker.setOptions({numHands: 1});
|
||||||
|
verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 1]);
|
||||||
|
verifyListenersRegistered(handLandmarker);
|
||||||
|
|
||||||
|
await handLandmarker.setOptions({numHands: 5});
|
||||||
|
verifyGraph(handLandmarker, [['handDetectorGraphOptions', 'numHands'], 5]);
|
||||||
|
verifyListenersRegistered(handLandmarker);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges options', async () => {
|
||||||
|
await handLandmarker.setOptions({numHands: 1});
|
||||||
|
await handLandmarker.setOptions({minHandDetectionConfidence: 0.5});
|
||||||
|
verifyGraph(handLandmarker, [
|
||||||
|
'handDetectorGraphOptions',
|
||||||
|
{numHands: 1, baseOptions: undefined, minDetectionConfidence: 0.5}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setOptions() ', () => {
|
||||||
|
interface TestCase {
|
||||||
|
optionPath: [keyof HandLandmarkerOptions, ...string[]];
|
||||||
|
fieldPath: string[];
|
||||||
|
customValue: unknown;
|
||||||
|
defaultValue: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
const testCases: TestCase[] = [
|
||||||
|
{
|
||||||
|
optionPath: ['numHands'],
|
||||||
|
fieldPath: ['handDetectorGraphOptions', 'numHands'],
|
||||||
|
customValue: 5,
|
||||||
|
defaultValue: 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['minHandDetectionConfidence'],
|
||||||
|
fieldPath: ['handDetectorGraphOptions', 'minDetectionConfidence'],
|
||||||
|
customValue: 0.1,
|
||||||
|
defaultValue: 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['minHandPresenceConfidence'],
|
||||||
|
fieldPath:
|
||||||
|
['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
|
||||||
|
customValue: 0.2,
|
||||||
|
defaultValue: 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionPath: ['minTrackingConfidence'],
|
||||||
|
fieldPath: ['minTrackingConfidence'],
|
||||||
|
customValue: 0.3,
|
||||||
|
defaultValue: 0.5
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
/** Creates an options object that can be passed to setOptions() */
|
||||||
|
function createOptions(
|
||||||
|
path: string[], value: unknown): HandLandmarkerOptions {
|
||||||
|
const options: Record<string, unknown> = {};
|
||||||
|
let currentLevel = options;
|
||||||
|
for (const element of path.slice(0, -1)) {
|
||||||
|
currentLevel[element] = {};
|
||||||
|
currentLevel = currentLevel[element] as Record<string, unknown>;
|
||||||
|
}
|
||||||
|
currentLevel[path[path.length - 1]] = value;
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const testCase of testCases) {
|
||||||
|
it(`uses default value for ${testCase.optionPath[0]}`, async () => {
|
||||||
|
verifyGraph(
|
||||||
|
handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it(`can set ${testCase.optionPath[0]}`, async () => {
|
||||||
|
await handLandmarker.setOptions(
|
||||||
|
createOptions(testCase.optionPath, testCase.customValue));
|
||||||
|
verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it(`can clear ${testCase.optionPath[0]}`, async () => {
|
||||||
|
await handLandmarker.setOptions(
|
||||||
|
createOptions(testCase.optionPath, testCase.customValue));
|
||||||
|
verifyGraph(handLandmarker, [testCase.fieldPath, testCase.customValue]);
|
||||||
|
|
||||||
|
await handLandmarker.setOptions(
|
||||||
|
createOptions(testCase.optionPath, undefined));
|
||||||
|
verifyGraph(
|
||||||
|
handLandmarker, [testCase.fieldPath, testCase.defaultValue]);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('transforms results', async () => {
|
||||||
|
// Pass the test data to our listener
|
||||||
|
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(handLandmarker);
|
||||||
|
handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
|
||||||
|
handLandmarker.listeners.get('world_hand_landmarks')!
|
||||||
|
(createWorldLandmarks());
|
||||||
|
handLandmarker.listeners.get('handedness')!(createHandednesses());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the hand landmarker
|
||||||
|
const landmarks = handLandmarker.detect({} as HTMLImageElement);
|
||||||
|
expect(handLandmarker.getGraphRunner().addProtoToStream)
|
||||||
|
.toHaveBeenCalledTimes(1);
|
||||||
|
expect(handLandmarker.getGraphRunner().addGpuBufferAsImageToStream)
|
||||||
|
.toHaveBeenCalledTimes(1);
|
||||||
|
expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
|
||||||
|
expect(landmarks).toEqual({
|
||||||
|
'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
|
||||||
|
'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
|
||||||
|
'handednesses': [[{
|
||||||
|
'score': 0.1,
|
||||||
|
'index': 1,
|
||||||
|
'categoryName': 'handedness_label',
|
||||||
|
'displayName': 'handedness_display_name'
|
||||||
|
}]]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
  it('clears results between invocations', async () => {
|
||||||
|
// Pass the test data to our listener
|
||||||
|
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
handLandmarker.listeners.get('hand_landmarks')!(createLandmarks());
|
||||||
|
handLandmarker.listeners.get('world_hand_landmarks')!
|
||||||
|
(createWorldLandmarks());
|
||||||
|
handLandmarker.listeners.get('handedness')!(createHandednesses());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the hand landmarker twice
|
||||||
|
const landmarks1 = handLandmarker.detect({} as HTMLImageElement);
|
||||||
|
const landmarks2 = handLandmarker.detect({} as HTMLImageElement);
|
||||||
|
|
||||||
|
    // Verify that landmarks2 is not a concatenation of all previously
    // returned landmarks.
    expect(landmarks1).toEqual(landmarks2);
|
||||||
|
});
|
||||||
|
});
|
|
@@ -3,6 +3,7 @@
# This task takes video or image frames and outputs the classification result.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -44,3 +45,26 @@ mediapipe_ts_declaration(
        "//mediapipe/tasks/web/vision/core:vision_task_options",
    ],
)

mediapipe_ts_library(
    name = "image_classifier_test_lib",
    testonly = True,
    srcs = [
        "image_classifier_test.ts",
    ],
    deps = [
        ":image_classifier",
        ":image_classifier_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:classification_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

jasmine_node_test(
    name = "image_classifier_test",
    tags = ["nomsan"],
    deps = [":image_classifier_test_lib"],
)
@@ -0,0 +1,150 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
|
||||||
|
import {ClassificationResult, Classifications} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
|
||||||
|
import {ImageClassifier} from './image_classifier';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
class ImageClassifierFake extends ImageClassifier implements
|
||||||
|
MediapipeTasksFake {
|
||||||
|
calculatorName =
|
||||||
|
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
|
||||||
|
attachListenerSpies: jasmine.Spy[] = [];
|
||||||
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
|
||||||
|
fakeWasmModule: SpyWasmModule;
|
||||||
|
protoListener: ((binaryProto: Uint8Array) => void)|undefined;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
this.fakeWasmModule =
|
||||||
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
|
this.attachListenerSpies[0] =
|
||||||
|
spyOn(this.graphRunner, 'attachProtoListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('classifications');
|
||||||
|
this.protoListener = listener;
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('ImageClassifier', () => {
|
||||||
|
let imageClassifier: ImageClassifierFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
imageClassifier = new ImageClassifierFake();
|
||||||
|
await imageClassifier.setOptions({}); // Initialize graph
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(imageClassifier);
|
||||||
|
verifyListenersRegistered(imageClassifier);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
await imageClassifier.setOptions({maxResults: 1});
|
||||||
|
verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
|
||||||
|
verifyListenersRegistered(imageClassifier);
|
||||||
|
|
||||||
|
await imageClassifier.setOptions({maxResults: 5});
|
||||||
|
verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 5]);
|
||||||
|
verifyListenersRegistered(imageClassifier);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can use custom models', async () => {
|
||||||
|
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||||
|
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||||
|
await imageClassifier.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: newModel,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
verifyGraph(
|
||||||
|
imageClassifier,
|
||||||
|
/* expectedCalculatorOptions= */ undefined,
|
||||||
|
/* expectedBaseOptions= */[
|
||||||
|
'modelAsset', {
|
||||||
|
fileContent: newModelBase64,
|
||||||
|
fileName: undefined,
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges options', async () => {
|
||||||
|
await imageClassifier.setOptions({maxResults: 1});
|
||||||
|
await imageClassifier.setOptions({displayNamesLocale: 'en'});
|
||||||
|
verifyGraph(imageClassifier, [['classifierOptions', 'maxResults'], 1]);
|
||||||
|
verifyGraph(
|
||||||
|
imageClassifier, [['classifierOptions', 'displayNamesLocale'], 'en']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('transforms results', async () => {
|
||||||
|
    const classificationResult = new ClassificationResult();
    const classifications = new Classifications();
    classifications.setHeadIndex(1);
    classifications.setHeadName('headName');
    const classificationList = new ClassificationList();
    const classification = new Classification();
    classification.setIndex(1);
    classification.setScore(0.2);
    classification.setDisplayName('displayName');
    classification.setLabel('categoryName');
    classificationList.addClassification(classification);
    classifications.setClassificationList(classificationList);
    classificationResult.addClassifications(classifications);
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
imageClassifier.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(imageClassifier);
|
||||||
|
imageClassifier.protoListener!(classificationResult.serializeBinary());
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the image classifier
|
||||||
|
const result = imageClassifier.classify({} as HTMLImageElement);
|
||||||
|
|
||||||
|
expect(imageClassifier.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(result).toEqual({
|
||||||
|
classifications: [{
|
||||||
|
categories: [{
|
||||||
|
index: 1,
|
||||||
|
score: 0.2,
|
||||||
|
displayName: 'displayName',
|
||||||
|
categoryName: 'categoryName'
|
||||||
|
}],
|
||||||
|
headIndex: 1,
|
||||||
|
headName: 'headName'
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
|
@@ -3,6 +3,7 @@
# This task performs embedding extraction on images.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -45,3 +46,23 @@ mediapipe_ts_declaration(
        "//mediapipe/tasks/web/vision/core:vision_task_options",
    ],
)

mediapipe_ts_library(
    name = "image_embedder_test_lib",
    testonly = True,
    srcs = [
        "image_embedder_test.ts",
    ],
    deps = [
        ":image_embedder",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

jasmine_node_test(
    name = "image_embedder_test",
    deps = [":image_embedder_test_lib"],
)
mediapipe/tasks/web/vision/image_embedder/image_embedder_test.ts
@@ -0,0 +1,158 @@
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {Embedding, EmbeddingResult, FloatEmbedding} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
|
||||||
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
|
||||||
|
import {ImageEmbedder} from './image_embedder';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
class ImageEmbedderFake extends ImageEmbedder implements MediapipeTasksFake {
|
||||||
|
calculatorName = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph';
|
||||||
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
attachListenerSpies: jasmine.Spy[] = [];
|
||||||
|
fakeWasmModule: SpyWasmModule;
|
||||||
|
protoListener: ((binaryProtos: Uint8Array) => void)|undefined;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
this.fakeWasmModule =
|
||||||
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
|
this.attachListenerSpies[0] =
|
||||||
|
spyOn(this.graphRunner, 'attachProtoListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('embeddings_out');
|
||||||
|
this.protoListener = listener;
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('ImageEmbedder', () => {
|
||||||
|
let imageEmbedder: ImageEmbedderFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
imageEmbedder = new ImageEmbedderFake();
|
||||||
|
await imageEmbedder.setOptions({}); // Initialize graph
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(imageEmbedder);
|
||||||
|
verifyListenersRegistered(imageEmbedder);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
verifyListenersRegistered(imageEmbedder);
|
||||||
|
|
||||||
|
await imageEmbedder.setOptions({quantize: true});
|
||||||
|
verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], true]);
|
||||||
|
verifyListenersRegistered(imageEmbedder);
|
||||||
|
|
||||||
|
await imageEmbedder.setOptions({quantize: undefined});
|
||||||
|
verifyGraph(imageEmbedder, [['embedderOptions', 'quantize'], undefined]);
|
||||||
|
verifyListenersRegistered(imageEmbedder);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can use custom models', async () => {
|
||||||
|
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||||
|
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||||
|
await imageEmbedder.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: newModel,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
verifyGraph(
|
||||||
|
imageEmbedder,
|
||||||
|
/* expectedCalculatorOptions= */ undefined,
|
||||||
|
/* expectedBaseOptions= */[
|
||||||
|
'modelAsset', {
|
||||||
|
fileContent: newModelBase64,
|
||||||
|
fileName: undefined,
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('overrides options', async () => {
|
||||||
|
await imageEmbedder.setOptions({quantize: true});
|
||||||
|
await imageEmbedder.setOptions({l2Normalize: true});
|
||||||
|
verifyGraph(
|
||||||
|
imageEmbedder,
|
||||||
|
['embedderOptions', {'quantize': true, 'l2Normalize': true}]);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('transforms result', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
const floatEmbedding = new FloatEmbedding();
|
||||||
|
floatEmbedding.setValuesList([0.1, 0.9]);
|
||||||
|
|
||||||
|
const embedding = new Embedding();
|
||||||
|
embedding.setHeadIndex(1);
|
||||||
|
embedding.setHeadName('headName');
|
||||||
|
embedding.setFloatEmbedding(floatEmbedding);
|
||||||
|
|
||||||
|
const resultProto = new EmbeddingResult();
|
||||||
|
resultProto.addEmbeddings(embedding);
|
||||||
|
resultProto.setTimestampMs(42);
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
imageEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(imageEmbedder);
|
||||||
|
imageEmbedder.protoListener!(resultProto.serializeBinary());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('for image mode', async () => {
|
||||||
|
// Invoke the image embedder
|
||||||
|
const embeddingResult = imageEmbedder.embed({} as HTMLImageElement);
|
||||||
|
|
||||||
|
expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(embeddingResult).toEqual({
|
||||||
|
embeddings:
|
||||||
|
[{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}],
|
||||||
|
timestampMs: 42
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('for video mode', async () => {
|
||||||
|
await imageEmbedder.setOptions({runningMode: 'video'});
|
||||||
|
|
||||||
|
// Invoke the video embedder
|
||||||
|
const embeddingResult =
|
||||||
|
imageEmbedder.embedForVideo({} as HTMLImageElement, 42);
|
||||||
|
|
||||||
|
expect(imageEmbedder.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(embeddingResult).toEqual({
|
||||||
|
embeddings:
|
||||||
|
[{headIndex: 1, headName: 'headName', floatEmbedding: [0.1, 0.9]}],
|
||||||
|
timestampMs: 42
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
|
@@ -4,6 +4,7 @@
# the detection results for one or more object categories, using Object Detector.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")

package(default_visibility = ["//mediapipe/tasks:internal"])

@@ -41,3 +42,26 @@ mediapipe_ts_declaration(
        "//mediapipe/tasks/web/vision/core:vision_task_options",
    ],
)

mediapipe_ts_library(
    name = "object_detector_test_lib",
    testonly = True,
    srcs = [
        "object_detector_test.ts",
    ],
    deps = [
        ":object_detector",
        ":object_detector_types",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/framework/formats:detection_jspb_proto",
        "//mediapipe/framework/formats:location_data_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/core:task_runner_test_utils",
    ],
)

jasmine_node_test(
    name = "object_detector_test",
    tags = ["nomsan"],
    deps = [":object_detector_test_lib"],
)
@@ -0,0 +1,229 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import 'jasmine';
|
||||||
|
|
||||||
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
|
import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb';
|
||||||
|
import {LocationData} from '../../../../framework/formats/location_data_pb';
|
||||||
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
|
|
||||||
|
import {ObjectDetector} from './object_detector';
|
||||||
|
import {ObjectDetectorOptions} from './object_detector_options';
|
||||||
|
|
||||||
|
// The OSS JS API does not support the builder pattern.
|
||||||
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
|
class ObjectDetectorFake extends ObjectDetector implements MediapipeTasksFake {
|
||||||
|
lastSampleRate: number|undefined;
|
||||||
|
calculatorName = 'mediapipe.tasks.vision.ObjectDetectorGraph';
|
||||||
|
attachListenerSpies: jasmine.Spy[] = [];
|
||||||
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
|
||||||
|
fakeWasmModule: SpyWasmModule;
|
||||||
|
protoListener: ((binaryProtos: Uint8Array[]) => void)|undefined;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||||
|
this.fakeWasmModule =
|
||||||
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
|
this.attachListenerSpies[0] =
|
||||||
|
spyOn(this.graphRunner, 'attachProtoVectorListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('detections');
|
||||||
|
this.protoListener = listener;
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
});
|
||||||
|
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('ObjectDetector', () => {
|
||||||
|
let objectDetector: ObjectDetectorFake;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
addJasmineCustomFloatEqualityTester();
|
||||||
|
objectDetector = new ObjectDetectorFake();
|
||||||
|
await objectDetector.setOptions({}); // Initialize graph
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes graph', async () => {
|
||||||
|
verifyGraph(objectDetector);
|
||||||
|
verifyListenersRegistered(objectDetector);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('reloads graph when settings are changed', async () => {
|
||||||
|
await objectDetector.setOptions({maxResults: 1});
|
||||||
|
verifyGraph(objectDetector, ['maxResults', 1]);
|
||||||
|
verifyListenersRegistered(objectDetector);
|
||||||
|
|
||||||
|
await objectDetector.setOptions({maxResults: 5});
|
||||||
|
verifyGraph(objectDetector, ['maxResults', 5]);
|
||||||
|
verifyListenersRegistered(objectDetector);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can use custom models', async () => {
|
||||||
|
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||||
|
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||||
|
await objectDetector.setOptions({
|
||||||
|
baseOptions: {
|
||||||
|
modelAssetBuffer: newModel,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
verifyGraph(
|
||||||
|
objectDetector,
|
||||||
|
/* expectedCalculatorOptions= */ undefined,
|
||||||
|
/* expectedBaseOptions= */
|
||||||
|
[
|
||||||
|
'modelAsset', {
|
||||||
|
fileContent: newModelBase64,
|
||||||
|
fileName: undefined,
|
||||||
|
fileDescriptorMeta: undefined,
|
||||||
|
filePointerMeta: undefined
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('merges options', async () => {
|
||||||
|
await objectDetector.setOptions({maxResults: 1});
|
||||||
|
await objectDetector.setOptions({displayNamesLocale: 'en'});
|
||||||
|
verifyGraph(objectDetector, ['maxResults', 1]);
|
||||||
|
verifyGraph(objectDetector, ['displayNamesLocale', 'en']);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('setOptions() ', () => {
|
||||||
|
interface TestCase {
|
||||||
|
optionName: keyof ObjectDetectorOptions;
|
||||||
|
protoName: string;
|
||||||
|
customValue: unknown;
|
||||||
|
defaultValue: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
const testCases: TestCase[] = [
|
||||||
|
{
|
||||||
|
optionName: 'maxResults',
|
||||||
|
protoName: 'maxResults',
|
||||||
|
customValue: 5,
|
||||||
|
defaultValue: -1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionName: 'displayNamesLocale',
|
||||||
|
protoName: 'displayNamesLocale',
|
||||||
|
customValue: 'en',
|
||||||
|
defaultValue: 'en'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionName: 'scoreThreshold',
|
||||||
|
protoName: 'scoreThreshold',
|
||||||
|
customValue: 0.1,
|
||||||
|
defaultValue: undefined
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionName: 'categoryAllowlist',
|
||||||
|
protoName: 'categoryAllowlistList',
|
||||||
|
customValue: ['foo'],
|
||||||
|
defaultValue: []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
optionName: 'categoryDenylist',
|
||||||
|
protoName: 'categoryDenylistList',
|
||||||
|
customValue: ['bar'],
|
||||||
|
defaultValue: []
|
||||||
|
},
|
||||||
|
];
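
    // Each case is exercised generically below: setOptions({[optionName]: customValue})
    // should show up as [protoName, customValue] in the serialized graph, and
    // clearing the option with `undefined` should restore defaultValue.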
|
||||||
|
|
||||||
|
for (const testCase of testCases) {
|
||||||
|
it(`can set ${testCase.optionName}`, async () => {
|
||||||
|
await objectDetector.setOptions(
|
||||||
|
{[testCase.optionName]: testCase.customValue});
|
||||||
|
verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it(`can clear ${testCase.optionName}`, async () => {
|
||||||
|
await objectDetector.setOptions(
|
||||||
|
{[testCase.optionName]: testCase.customValue});
|
||||||
|
verifyGraph(objectDetector, [testCase.protoName, testCase.customValue]);
|
||||||
|
await objectDetector.setOptions({[testCase.optionName]: undefined});
|
||||||
|
verifyGraph(
|
||||||
|
objectDetector, [testCase.protoName, testCase.defaultValue]);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('transforms results', async () => {
|
||||||
|
const detectionProtos: Uint8Array[] = [];
|
||||||
|
|
||||||
|
// Add a detection with all optional properties
|
||||||
|
let detection = new DetectionProto();
|
||||||
|
detection.addScore(0.1);
|
||||||
|
detection.addLabelId(1);
|
||||||
|
detection.addLabel('foo');
|
||||||
|
detection.addDisplayName('bar');
|
||||||
|
let locationData = new LocationData();
|
||||||
|
let boundingBox = new LocationData.BoundingBox();
|
||||||
|
boundingBox.setXmin(1);
|
||||||
|
boundingBox.setYmin(2);
|
||||||
|
boundingBox.setWidth(3);
|
||||||
|
boundingBox.setHeight(4);
|
||||||
|
locationData.setBoundingBox(boundingBox);
|
||||||
|
detection.setLocationData(locationData);
|
||||||
|
detectionProtos.push(detection.serializeBinary());
|
||||||
|
|
||||||
|
// Add a detection without optional properties
|
||||||
|
detection = new DetectionProto();
|
||||||
|
detection.addScore(0.2);
|
||||||
|
locationData = new LocationData();
|
||||||
|
boundingBox = new LocationData.BoundingBox();
|
||||||
|
locationData.setBoundingBox(boundingBox);
|
||||||
|
detection.setLocationData(locationData);
|
||||||
|
detectionProtos.push(detection.serializeBinary());
|
||||||
|
|
||||||
|
// Pass the test data to our listener
|
||||||
|
objectDetector.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
|
verifyListenersRegistered(objectDetector);
|
||||||
|
objectDetector.protoListener!(detectionProtos);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Invoke the object detector
|
||||||
|
const detections = objectDetector.detect({} as HTMLImageElement);
|
||||||
|
|
||||||
|
expect(objectDetector.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
|
expect(detections.length).toEqual(2);
|
||||||
|
expect(detections[0]).toEqual({
|
||||||
|
categories: [{
|
||||||
|
score: 0.1,
|
||||||
|
index: 1,
|
||||||
|
categoryName: 'foo',
|
||||||
|
displayName: 'bar',
|
||||||
|
}],
|
||||||
|
boundingBox: {originX: 1, originY: 2, width: 3, height: 4}
|
||||||
|
});
|
||||||
|
expect(detections[1]).toEqual({
|
||||||
|
categories: [{
|
||||||
|
score: 0.2,
|
||||||
|
index: -1,
|
||||||
|
categoryName: '',
|
||||||
|
displayName: '',
|
||||||
|
}],
|
||||||
|
boundingBox: {originX: 0, originY: 0, width: 0, height: 0}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});