Add Handedness to JS, C++ and Android API

PiperOrigin-RevId: 564559718
This commit is contained in:
Sebastian Schmidt 2023-09-11 18:20:30 -07:00 committed by Copybara-Service
parent 02bd0d95e7
commit 12502b6f96
13 changed files with 89 additions and 44 deletions

View File

@ -83,7 +83,7 @@ struct HandLandmarkerOutputs {
Stream<std::vector<NormalizedLandmarkList>> landmark_lists;
Stream<std::vector<LandmarkList>> world_landmark_lists;
Stream<std::vector<NormalizedRect>> hand_rects_next_frame;
Stream<std::vector<ClassificationList>> handednesses;
Stream<std::vector<ClassificationList>> handedness;
Stream<std::vector<NormalizedRect>> palm_rects;
Stream<std::vector<Detection>> palm_detections;
Stream<Image> image;
@ -241,7 +241,7 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
hand_landmarker_outputs.hand_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kHandRectNextFrameTag)];
hand_landmarker_outputs.handednesses >>
hand_landmarker_outputs.handedness >>
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
hand_landmarker_outputs.palm_rects >>
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];

View File

@ -93,7 +93,7 @@ struct HandLandmarkerOutputs {
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
Source<std::vector<bool>> presences;
Source<std::vector<float>> presence_scores;
Source<std::vector<ClassificationList>> handednesses;
Source<std::vector<ClassificationList>> handedness;
};
absl::Status SanityCheckOptions(
@ -478,7 +478,7 @@ class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
graph[Output<std::vector<bool>>(kPresenceTag)];
hand_landmark_detection_outputs.presence_scores >>
graph[Output<std::vector<float>>(kPresenceScoreTag)];
hand_landmark_detection_outputs.handednesses >>
hand_landmark_detection_outputs.handedness >>
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
return graph.GetConfig();
@ -562,7 +562,7 @@ class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
/* hand_rects_next_frame= */ hand_rects_next_frame,
/* presences= */ presences,
/* presence_scores= */ presence_scores,
/* handednesses= */ handednesses,
/* handedness= */ handednesses,
}};
}
};

View File

@ -319,15 +319,15 @@ TEST_P(MultiHandLandmarkerTest, Succeeds) {
const std::vector<bool>& presences =
(*output_packets)[kPresenceName].Get<std::vector<bool>>();
const std::vector<ClassificationList>& handednesses =
const std::vector<ClassificationList>& handedness =
(*output_packets)[kHandednessName].Get<std::vector<ClassificationList>>();
const std::vector<NormalizedLandmarkList>& landmark_lists =
(*output_packets)[kLandmarksName]
.Get<std::vector<NormalizedLandmarkList>>();
EXPECT_THAT(presences, ElementsAreArray(GetParam().expected_presences));
EXPECT_THAT(handednesses, Pointwise(Partially(EqualsProto()),
GetParam().expected_handedness));
EXPECT_THAT(handedness, Pointwise(Partially(EqualsProto()),
GetParam().expected_handedness));
EXPECT_THAT(
landmark_lists,
Pointwise(Approximately(Partially(EqualsProto()), /*margin=*/kAbsMargin,

View File

@ -114,11 +114,21 @@ public abstract class GestureRecognizerResult implements TaskResult {
/** Hand landmarks of detected hands. */
public abstract List<List<NormalizedLandmark>> landmarks();
/** Hand landmarks in world coordniates of detected hands. */
/** Hand landmarks in world coordinates of detected hands. */
public abstract List<List<Landmark>> worldLandmarks();
/**
* Handedness of detected hands.
*
* @deprecated Use {@link #handedness()} instead.
*/
@Deprecated
public List<List<Category>> handednesses() {
return handedness();
}
/** Handedness of detected hands. */
public abstract List<List<Category>> handednesses();
public abstract List<List<Category>> handedness();
/**
* Recognized hand gestures of detected hands. Note that the index of the gesture is always -1,

View File

@ -108,9 +108,19 @@ public abstract class HandLandmarkerResult implements TaskResult {
/** Hand landmarks of detected hands. */
public abstract List<List<NormalizedLandmark>> landmarks();
/** Hand landmarks in world coordniates of detected hands. */
/** Hand landmarks in world coordinates of detected hands. */
public abstract List<List<Landmark>> worldLandmarks();
/**
* Handedness of detected hands.
*
* @deprecated Use {@link #handedness()} instead.
*/
@Deprecated
public List<List<Category>> handednesses() {
return handedness();
}
/** Handedness of detected hands. */
public abstract List<List<Category>> handednesses();
public abstract List<List<Category>> handedness();
}

View File

@ -102,7 +102,7 @@ public class GestureRecognizerTest {
gestureRecognizer.recognize(getImageFromAsset(NO_HANDS_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.handedness()).isEmpty();
assertThat(actualResult.gestures()).isEmpty();
}
@ -143,7 +143,7 @@ public class GestureRecognizerTest {
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
GestureRecognizerResult actualResult =
gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE));
assertThat(actualResult.handednesses()).hasSize(2);
assertThat(actualResult.handedness()).hasSize(2);
}
@Test
@ -251,7 +251,7 @@ public class GestureRecognizerTest {
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.handedness()).isEmpty();
assertThat(actualResult.gestures()).isEmpty();
}
@ -284,7 +284,7 @@ public class GestureRecognizerTest {
gestureRecognizer.recognize(getImageFromAsset(FIST_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.handedness()).isEmpty();
assertThat(actualResult.gestures()).isEmpty();
}
@ -596,7 +596,7 @@ public class GestureRecognizerTest {
// Expects to have the same number of hands detected.
assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size());
assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size());
assertThat(actualResult.handedness()).hasSize(expectedResult.handedness().size());
assertThat(actualResult.gestures()).hasSize(expectedResult.gestures().size());
// Actual landmarks match expected landmarks.
@ -614,8 +614,8 @@ public class GestureRecognizerTest {
.containsExactlyElementsIn(expectedResult.landmarks().get(0));
// Actual handedness matches expected handedness.
Category actualTopHandedness = actualResult.handednesses().get(0).get(0);
Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0);
Category actualTopHandedness = actualResult.handedness().get(0).get(0);
Category expectedTopHandedness = expectedResult.handedness().get(0).get(0);
assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index());
assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName());

View File

@ -92,7 +92,7 @@ public class HandLandmarkerTest {
handLandmarker.detect(getImageFromAsset(NO_HANDS_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
assertThat(actualResult.handedness()).isEmpty();
}
@Test
@ -109,7 +109,7 @@ public class HandLandmarkerTest {
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
HandLandmarkerResult actualResult =
handLandmarker.detect(getImageFromAsset(TWO_HANDS_IMAGE));
assertThat(actualResult.handednesses()).hasSize(2);
assertThat(actualResult.handedness()).hasSize(2);
}
@Test
@ -393,7 +393,7 @@ public class HandLandmarkerTest {
// Expects to have the same number of hands detected.
assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size());
assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size());
assertThat(actualResult.handedness()).hasSize(expectedResult.handedness().size());
// Actual landmarks match expected landmarks.
assertThat(actualResult.landmarks().get(0))
@ -410,8 +410,8 @@ public class HandLandmarkerTest {
.containsExactlyElementsIn(expectedResult.landmarks().get(0));
// Actual handedness matches expected handedness.
Category actualTopHandedness = actualResult.handednesses().get(0).get(0);
Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0);
Category actualTopHandedness = actualResult.handedness().get(0).get(0);
Category expectedTopHandedness = expectedResult.handedness().get(0).get(0);
assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index());
assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName());
}

View File

@ -63,7 +63,7 @@ export class GestureRecognizer extends VisionTaskRunner {
private gestures: Category[][] = [];
private landmarks: NormalizedLandmark[][] = [];
private worldLandmarks: Landmark[][] = [];
private handednesses: Category[][] = [];
private handedness: Category[][] = [];
private readonly options: GestureRecognizerGraphOptions;
private readonly handLandmarkerGraphOptions: HandLandmarkerGraphOptions;
@ -273,7 +273,7 @@ export class GestureRecognizer extends VisionTaskRunner {
this.gestures = [];
this.landmarks = [];
this.worldLandmarks = [];
this.handednesses = [];
this.handedness = [];
}
private processResults(): GestureRecognizerResult {
@ -283,14 +283,16 @@ export class GestureRecognizer extends VisionTaskRunner {
gestures: [],
landmarks: [],
worldLandmarks: [],
handednesses: [],
handedness: [],
handednesses: []
};
} else {
return {
gestures: this.gestures,
landmarks: this.landmarks,
worldLandmarks: this.worldLandmarks,
handednesses: this.handednesses
handedness: this.handedness,
handednesses: this.handedness
};
}
}
@ -416,7 +418,7 @@ export class GestureRecognizer extends VisionTaskRunner {
this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, (binaryProto, timestamp) => {
this.handednesses.push(...this.toJsCategories(binaryProto));
this.handedness.push(...this.toJsCategories(binaryProto));
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(HANDEDNESS_STREAM, timestamp => {

View File

@ -30,6 +30,12 @@ export declare interface GestureRecognizerResult {
worldLandmarks: Landmark[][];
/** Handedness of detected hands. */
handedness: Category[][];
/**
* Handedness of detected hands.
* @deprecated Use `.handedness` instead.
*/
handednesses: Category[][];
/**

View File

@ -28,7 +28,7 @@ import {GestureRecognizer, GestureRecognizerOptions} from './gesture_recognizer'
type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void);
function createHandednesses(): Uint8Array[] {
function createHandedness(): Uint8Array[] {
const handsProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
@ -282,8 +282,7 @@ describe('GestureRecognizer', () => {
(createLandmarks(), 1337);
gestureRecognizer.listeners.get('world_hand_landmarks')!
(createWorldLandmarks(), 1337);
gestureRecognizer.listeners.get('handedness')!
(createHandednesses(), 1337);
gestureRecognizer.listeners.get('handedness')!(createHandedness(), 1337);
gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337);
});
@ -304,6 +303,12 @@ describe('GestureRecognizer', () => {
}]],
'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]],
'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]],
'handedness': [[{
'score': 0.1,
'index': 1,
'categoryName': 'handedness_label',
'displayName': 'handedness_display_name'
}]],
'handednesses': [[{
'score': 0.1,
'index': 1,
@ -320,8 +325,7 @@ describe('GestureRecognizer', () => {
(createLandmarks(), 1337);
gestureRecognizer.listeners.get('world_hand_landmarks')!
(createWorldLandmarks(), 1337);
gestureRecognizer.listeners.get('handedness')!
(createHandednesses(), 1337);
gestureRecognizer.listeners.get('handedness')!(createHandedness(), 1337);
gestureRecognizer.listeners.get('hand_gestures')!(createGestures(), 1337);
});
@ -342,8 +346,7 @@ describe('GestureRecognizer', () => {
(createLandmarks(), 1337);
gestureRecognizer.listeners.get('world_hand_landmarks')!
(createWorldLandmarks(), 1337);
gestureRecognizer.listeners.get('handedness')!
(createHandednesses(), 1337);
gestureRecognizer.listeners.get('handedness')!(createHandedness(), 1337);
gestureRecognizer.listeners.get('hand_gestures')!([], 1337);
});
@ -353,6 +356,7 @@ describe('GestureRecognizer', () => {
'gestures': [],
'landmarks': [],
'worldLandmarks': [],
'handedness': [],
'handednesses': []
});
});

View File

@ -58,7 +58,7 @@ const DEFAULT_CATEGORY_INDEX = -1;
export class HandLandmarker extends VisionTaskRunner {
private landmarks: NormalizedLandmark[][] = [];
private worldLandmarks: Landmark[][] = [];
private handednesses: Category[][] = [];
private handedness: Category[][] = [];
private readonly options: HandLandmarkerGraphOptions;
private readonly handLandmarksDetectorGraphOptions:
@ -222,14 +222,15 @@ export class HandLandmarker extends VisionTaskRunner {
private resetResults(): void {
this.landmarks = [];
this.worldLandmarks = [];
this.handednesses = [];
this.handedness = [];
}
private processResults(): HandLandmarkerResult {
return {
landmarks: this.landmarks,
worldLandmarks: this.worldLandmarks,
handednesses: this.handednesses
handednesses: this.handedness,
handedness: this.handedness,
};
}
@ -330,7 +331,7 @@ export class HandLandmarker extends VisionTaskRunner {
this.graphRunner.attachProtoVectorListener(
HANDEDNESS_STREAM, (binaryProto, timestamp) => {
this.handednesses.push(...this.toJsCategories(binaryProto));
this.handedness.push(...this.toJsCategories(binaryProto));
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(

View File

@ -29,6 +29,12 @@ export declare interface HandLandmarkerResult {
/** Hand landmarks in world coordinates of detected hands. */
worldLandmarks: Landmark[][];
/**
* Handedness of detected hands.
* @deprecated Use `.handedness` instead.
*/
handednesses: Category[][];
/** Handedness of detected hands. */
handedness: Category[][];
}

View File

@ -30,7 +30,7 @@ import {HandLandmarkerOptions} from './hand_landmarker_options';
type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void);
function createHandednesses(): ClassificationList {
function createHandedness(): ClassificationList {
const handsProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
@ -198,7 +198,7 @@ describe('HandLandmarker', () => {
it('transforms results', async () => {
const landmarksProto = [createLandmarks().serializeBinary()];
const worldLandmarksProto = [createWorldLandmarks().serializeBinary()];
const handednessProto = [createHandednesses().serializeBinary()];
const handednessProto = [createHandedness().serializeBinary()];
// Pass the test data to our listener
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
@ -220,6 +220,12 @@ describe('HandLandmarker', () => {
expect(landmarks).toEqual({
'landmarks': [[{'x': 0, 'y': 0, 'z': 0}]],
'worldLandmarks': [[{'x': 0, 'y': 0, 'z': 0}]],
'handedness': [[{
'score': 0.1,
'index': 1,
'categoryName': 'handedness_label',
'displayName': 'handedness_display_name'
}]],
'handednesses': [[{
'score': 0.1,
'index': 1,
@ -232,7 +238,7 @@ describe('HandLandmarker', () => {
it('clears results between invoations', async () => {
const landmarks = [createLandmarks().serializeBinary()];
const worldLandmarks = [createWorldLandmarks().serializeBinary()];
const handedness = [createHandednesses().serializeBinary()];
const handedness = [createHandedness().serializeBinary()];
// Pass the test data to our listener
handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {