Internal change

PiperOrigin-RevId: 499283559
This commit is contained in:
Sebastian Schmidt 2023-01-03 12:09:59 -08:00 committed by Copybara-Service
parent 68f247a5c7
commit 75b87e0e32
3 changed files with 55 additions and 11 deletions

View File

@ -263,6 +263,15 @@ export class GestureRecognizer extends
NORM_RECT_STREAM, timestamp); NORM_RECT_STREAM, timestamp);
this.finishProcessing(); this.finishProcessing();
if (this.gestures.length === 0) {
// If no gestures are detected in the image, just return an empty list
return {
gestures: [],
landmarks: [],
worldLandmarks: [],
handednesses: [],
};
} else {
return { return {
gestures: this.gestures, gestures: this.gestures,
landmarks: this.landmarks, landmarks: this.landmarks,
@ -270,6 +279,7 @@ export class GestureRecognizer extends
handednesses: this.handednesses handednesses: this.handednesses
}; };
} }
}
/** Sets the default values for the graph. */ /** Sets the default values for the graph. */
private initDefaults(): void { private initDefaults(): void {
@ -283,15 +293,19 @@ export class GestureRecognizer extends
} }
/** Converts the proto data to a Category[][] structure. */ /** Converts the proto data to a Category[][] structure. */
private toJsCategories(data: Uint8Array[]): Category[][] { private toJsCategories(data: Uint8Array[], populateIndex = true):
Category[][] {
const result: Category[][] = []; const result: Category[][] = [];
for (const binaryProto of data) { for (const binaryProto of data) {
const inputList = ClassificationList.deserializeBinary(binaryProto); const inputList = ClassificationList.deserializeBinary(binaryProto);
const outputList: Category[] = []; const outputList: Category[] = [];
for (const classification of inputList.getClassificationList()) { for (const classification of inputList.getClassificationList()) {
const index = populateIndex && classification.hasIndex() ?
classification.getIndex()! :
DEFAULT_CATEGORY_INDEX;
outputList.push({ outputList.push({
score: classification.getScore() ?? 0, score: classification.getScore() ?? 0,
index: classification.getIndex() ?? DEFAULT_CATEGORY_INDEX, index,
categoryName: classification.getLabel() ?? '', categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '', displayName: classification.getDisplayName() ?? '',
}); });
@ -375,7 +389,10 @@ export class GestureRecognizer extends
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
HAND_GESTURES_STREAM, binaryProto => { HAND_GESTURES_STREAM, binaryProto => {
this.gestures.push(...this.toJsCategories(binaryProto)); // Gesture index is not used, because the final gesture result comes
// from multiple classifiers.
this.gestures.push(
...this.toJsCategories(binaryProto, /* populateIndex= */ false));
}); });
this.graphRunner.attachProtoVectorListener( this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, binaryProto => { HANDEDNESS_STREAM, binaryProto => {

View File

@ -17,6 +17,8 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
export {Category, Landmark, NormalizedLandmark};
/** /**
* Represents the gesture recognition results generated by `GestureRecognizer`. * Represents the gesture recognition results generated by `GestureRecognizer`.
*/ */
@ -30,6 +32,10 @@ export declare interface GestureRecognizerResult {
/** Handedness of detected hands. */ /** Handedness of detected hands. */
handednesses: Category[][]; handednesses: Category[][];
/** Recognized hand gestures of detected hands */ /**
* Recognized hand gestures of detected hands. Note that the index of the
* gesture is always -1, because the raw indices from multiple gesture
* classifiers cannot consolidate to a meaningful index.
*/
gestures: Category[][]; gestures: Category[][];
} }

View File

@ -272,7 +272,7 @@ describe('GestureRecognizer', () => {
expect(gestures).toEqual({ expect(gestures).toEqual({
'gestures': [[{ 'gestures': [[{
'score': 0.2, 'score': 0.2,
'index': 2, 'index': -1,
'categoryName': 'gesture_label', 'categoryName': 'gesture_label',
'displayName': 'gesture_display_name' 'displayName': 'gesture_display_name'
}]], }]],
@ -305,4 +305,25 @@ describe('GestureRecognizer', () => {
// gestures. // gestures.
expect(gestures2).toEqual(gestures1); expect(gestures2).toEqual(gestures1);
}); });
it('returns empty results when no gestures are detected', async () => {
// Pass the test data to our listener
gestureRecognizer.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(gestureRecognizer);
gestureRecognizer.listeners.get('hand_landmarks')!(createLandmarks());
gestureRecognizer.listeners.get('world_hand_landmarks')!
(createWorldLandmarks());
gestureRecognizer.listeners.get('handedness')!(createHandednesses());
gestureRecognizer.listeners.get('hand_gestures')!([]);
});
// Invoke the gesture recognizer
const gestures = gestureRecognizer.recognize({} as HTMLImageElement);
expect(gestures).toEqual({
'gestures': [],
'landmarks': [],
'worldLandmarks': [],
'handednesses': []
});
});
}); });