Remove auxiliary landmarks in PoseLandmarker API results.

PiperOrigin-RevId: 529989746
This commit is contained in:
MediaPipe Team 2023-05-06 12:55:55 -07:00 committed by Copybara-Service
parent ddb84702f6
commit 876987b389
12 changed files with 19 additions and 170 deletions

View File

@ -63,8 +63,6 @@ constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
constexpr char kNormLandmarksStreamName[] = "norm_landmarks";
constexpr char kPoseWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kPoseWorldLandmarksStreamName[] = "world_landmarks";
constexpr char kPoseAuxiliaryLandmarksTag[] = "AUXILIARY_LANDMARKS";
constexpr char kPoseAuxiliaryLandmarksStreamName[] = "auxiliary_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;
// Creates a MediaPipe graph config that contains a subgraph node of
@ -83,9 +81,6 @@ CalculatorGraphConfig CreateGraphConfig(
graph.Out(kNormLandmarksTag);
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
graph.Out(kPoseWorldLandmarksTag);
subgraph.Out(kPoseAuxiliaryLandmarksTag)
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
graph.Out(kPoseAuxiliaryLandmarksTag);
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
if (output_segmentation_masks) {
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
@ -163,8 +158,6 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
status_or_packets.value()[kNormLandmarksStreamName];
Packet pose_world_landmarks_packet =
status_or_packets.value()[kPoseWorldLandmarksStreamName];
Packet pose_auxiliary_landmarks_packet =
status_or_packets.value()[kPoseAuxiliaryLandmarksStreamName];
std::optional<std::vector<Image>> segmentation_mask = std::nullopt;
if (output_segmentation_masks) {
segmentation_mask = segmentation_mask_packet.Get<std::vector<Image>>();
@ -175,9 +168,7 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
/* pose_landmarks= */
pose_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
/* pose_world_landmarks= */
pose_world_landmarks_packet.Get<std::vector<LandmarkList>>(),
pose_auxiliary_landmarks_packet
.Get<std::vector<NormalizedLandmarkList>>()),
pose_world_landmarks_packet.Get<std::vector<LandmarkList>>()),
image_packet.Get<Image>(),
pose_landmarks_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
@ -234,10 +225,7 @@ absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::Detect(
.Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
/* pose_world_landmarks */
output_packets[kPoseWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>(),
/*pose_auxiliary_landmarks= */
output_packets[kPoseAuxiliaryLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>());
.Get<std::vector<mediapipe::LandmarkList>>());
}
absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo(
@ -277,10 +265,7 @@ absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo(
.Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
/* pose_world_landmarks */
output_packets[kPoseWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>(),
/* pose_auxiliary_landmarks= */
output_packets[kPoseAuxiliaryLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>());
.Get<std::vector<mediapipe::LandmarkList>>());
}
absl::Status PoseLandmarker::DetectAsync(

View File

@ -27,15 +27,12 @@ namespace pose_landmarker {
PoseLandmarkerResult ConvertToPoseLandmarkerResult(
std::optional<std::vector<mediapipe::Image>> segmentation_masks,
const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto,
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto,
const std::vector<mediapipe::NormalizedLandmarkList>&
pose_auxiliary_landmarks_proto) {
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto) {
PoseLandmarkerResult result;
result.segmentation_masks = segmentation_masks;
result.pose_landmarks.resize(pose_landmarks_proto.size());
result.pose_world_landmarks.resize(pose_world_landmarks_proto.size());
result.pose_auxiliary_landmarks.resize(pose_auxiliary_landmarks_proto.size());
std::transform(pose_landmarks_proto.begin(), pose_landmarks_proto.end(),
result.pose_landmarks.begin(),
components::containers::ConvertToNormalizedLandmarks);
@ -43,10 +40,6 @@ PoseLandmarkerResult ConvertToPoseLandmarkerResult(
pose_world_landmarks_proto.end(),
result.pose_world_landmarks.begin(),
components::containers::ConvertToLandmarks);
std::transform(pose_auxiliary_landmarks_proto.begin(),
pose_auxiliary_landmarks_proto.end(),
result.pose_auxiliary_landmarks.begin(),
components::containers::ConvertToNormalizedLandmarks);
return result;
}

View File

@ -37,17 +37,12 @@ struct PoseLandmarkerResult {
std::vector<components::containers::NormalizedLandmarks> pose_landmarks;
// Detected pose landmarks in world coordinates.
std::vector<components::containers::Landmarks> pose_world_landmarks;
// Detected auxiliary landmarks, used for deriving ROI for next frame.
std::vector<components::containers::NormalizedLandmarks>
pose_auxiliary_landmarks;
};
PoseLandmarkerResult ConvertToPoseLandmarkerResult(
std::optional<std::vector<mediapipe::Image>> segmentation_mask,
const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto,
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto,
const std::vector<mediapipe::NormalizedLandmarkList>&
pose_auxiliary_landmarks_proto);
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto);
} // namespace pose_landmarker
} // namespace vision

View File

@ -47,13 +47,6 @@ TEST(ConvertFromProto, Succeeds) {
landmark_proto.set_y(5.2);
landmark_proto.set_z(4.3);
mediapipe::NormalizedLandmarkList auxiliary_landmark_list_proto;
mediapipe::NormalizedLandmark& auxiliary_landmark_proto =
*auxiliary_landmark_list_proto.add_landmark();
auxiliary_landmark_proto.set_x(0.5);
auxiliary_landmark_proto.set_y(0.5);
auxiliary_landmark_proto.set_z(0.5);
std::vector<Image> segmentation_masks_lists = {segmentation_mask};
std::vector<mediapipe::NormalizedLandmarkList> normalized_landmarks_lists = {
@ -62,12 +55,9 @@ TEST(ConvertFromProto, Succeeds) {
std::vector<mediapipe::LandmarkList> world_landmarks_lists = {
world_landmark_list_proto};
std::vector<mediapipe::NormalizedLandmarkList> auxiliary_landmarks_lists = {
auxiliary_landmark_list_proto};
PoseLandmarkerResult pose_landmarker_result = ConvertToPoseLandmarkerResult(
segmentation_masks_lists, normalized_landmarks_lists,
world_landmarks_lists, auxiliary_landmarks_lists);
world_landmarks_lists);
EXPECT_EQ(pose_landmarker_result.pose_landmarks.size(), 1);
EXPECT_EQ(pose_landmarker_result.pose_landmarks[0].landmarks.size(), 1);
@ -82,14 +72,6 @@ TEST(ConvertFromProto, Succeeds) {
testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2),
testing::FloatEq(4.3), std::nullopt,
std::nullopt, std::nullopt));
EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks.size(), 1);
EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks.size(),
1);
EXPECT_THAT(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks[0],
testing::FieldsAre(testing::FloatEq(0.5), testing::FloatEq(0.5),
testing::FloatEq(0.5), std::nullopt,
std::nullopt, std::nullopt));
}
} // namespace pose_landmarker

View File

@ -79,8 +79,7 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
private static final int AUXILIARY_LANDMARKS_OUT_STREAM_INDEX = 2;
private static final int IMAGE_OUT_STREAM_INDEX = 3;
private static final int IMAGE_OUT_STREAM_INDEX = 2;
private static int segmentationMasksOutStreamIndex = -1;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph";
@ -145,7 +144,6 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
List<String> outputStreams = new ArrayList<>();
outputStreams.add("NORM_LANDMARKS:pose_landmarks");
outputStreams.add("WORLD_LANDMARKS:world_landmarks");
outputStreams.add("AUXILIARY_LANDMARKS:auxiliary_landmarks");
outputStreams.add("IMAGE:image_out");
if (landmarkerOptions.outputSegmentationMasks()) {
outputStreams.add("SEGMENTATION_MASK:segmentation_masks");
@ -161,7 +159,6 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
// If there is no poses detected in the image, just returns empty lists.
if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
return PoseLandmarkerResult.create(
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
Optional.empty(),
@ -179,9 +176,6 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
PacketGetter.getProtoVector(
packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
PacketGetter.getProtoVector(
packets.get(AUXILIARY_LANDMARKS_OUT_STREAM_INDEX),
NormalizedLandmarkList.parser()),
segmentedMasks,
BaseVisionTaskApi.generateResultTimestampMs(
landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX)));

View File

@ -40,7 +40,6 @@ public abstract class PoseLandmarkerResult implements TaskResult {
static PoseLandmarkerResult create(
List<LandmarkProto.NormalizedLandmarkList> landmarksProto,
List<LandmarkProto.LandmarkList> worldLandmarksProto,
List<LandmarkProto.NormalizedLandmarkList> auxiliaryLandmarksProto,
Optional<List<MPImage>> segmentationMasksData,
long timestampMs) {
@ -52,7 +51,6 @@ public abstract class PoseLandmarkerResult implements TaskResult {
List<List<NormalizedLandmark>> multiPoseLandmarks = new ArrayList<>();
List<List<Landmark>> multiPoseWorldLandmarks = new ArrayList<>();
List<List<NormalizedLandmark>> multiPoseAuxiliaryLandmarks = new ArrayList<>();
for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) {
List<NormalizedLandmark> poseLandmarks = new ArrayList<>();
multiPoseLandmarks.add(poseLandmarks);
@ -75,24 +73,10 @@ public abstract class PoseLandmarkerResult implements TaskResult {
poseWorldLandmarkProto.getZ()));
}
}
for (LandmarkProto.NormalizedLandmarkList poseAuxiliaryLandmarksProto :
auxiliaryLandmarksProto) {
List<NormalizedLandmark> poseAuxiliaryLandmarks = new ArrayList<>();
multiPoseAuxiliaryLandmarks.add(poseAuxiliaryLandmarks);
for (LandmarkProto.NormalizedLandmark poseAuxiliaryLandmarkProto :
poseAuxiliaryLandmarksProto.getLandmarkList()) {
poseAuxiliaryLandmarks.add(
NormalizedLandmark.create(
poseAuxiliaryLandmarkProto.getX(),
poseAuxiliaryLandmarkProto.getY(),
poseAuxiliaryLandmarkProto.getZ()));
}
}
return new AutoValue_PoseLandmarkerResult(
timestampMs,
Collections.unmodifiableList(multiPoseLandmarks),
Collections.unmodifiableList(multiPoseWorldLandmarks),
Collections.unmodifiableList(multiPoseAuxiliaryLandmarks),
multiPoseSegmentationMasks);
}
@ -105,9 +89,6 @@ public abstract class PoseLandmarkerResult implements TaskResult {
/** Pose landmarks in world coordniates of detected poses. */
public abstract List<List<Landmark>> worldLandmarks();
/** Pose auxiliary landmarks. */
public abstract List<List<NormalizedLandmark>> auxiliaryLandmarks();
/** Pose segmentation masks. */
public abstract Optional<List<MPImage>> segmentationMasks();
}

View File

@ -330,7 +330,6 @@ public class PoseLandmarkerTest {
return PoseLandmarkerResult.create(
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
Arrays.asList(),
Optional.empty(),
/* timestampMs= */ 0);
}

View File

@ -74,7 +74,6 @@ def _get_expected_pose_landmarker_result(
return PoseLandmarkerResult(
pose_landmarks=[landmarks_detection_result.landmarks],
pose_world_landmarks=[],
pose_auxiliary_landmarks=[],
)
@ -296,7 +295,6 @@ class PoseLandmarkerTest(parameterized.TestCase):
# Comparing results.
self.assertEmpty(detection_result.pose_landmarks)
self.assertEmpty(detection_result.pose_world_landmarks)
self.assertEmpty(detection_result.pose_auxiliary_landmarks)
def test_missing_result_callback(self):
options = _PoseLandmarkerOptions(
@ -391,7 +389,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
),
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [], [])),
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [])),
)
def test_detect_for_video(
self, image_path, rotation, output_segmentation_masks, expected_result
@ -473,7 +471,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
True,
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
),
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [], [])),
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [])),
)
def test_detect_async_calls(
self, image_path, rotation, output_segmentation_masks, expected_result

View File

@ -49,8 +49,6 @@ _NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks'
_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS'
_POSE_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_POSE_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_POSE_AUXILIARY_LANDMARKS_STREAM_NAME = 'auxiliary_landmarks'
_POSE_AUXILIARY_LANDMARKS_TAG = 'AUXILIARY_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@ -62,14 +60,11 @@ class PoseLandmarkerResult:
Attributes:
pose_landmarks: Detected pose landmarks in normalized image coordinates.
pose_world_landmarks: Detected pose landmarks in world coordinates.
pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving
ROI for next frame.
segmentation_masks: Optional segmentation masks for pose.
"""
pose_landmarks: List[List[landmark_module.NormalizedLandmark]]
pose_world_landmarks: List[List[landmark_module.Landmark]]
pose_auxiliary_landmarks: List[List[landmark_module.NormalizedLandmark]]
segmentation_masks: Optional[List[image_module.Image]] = None
@ -77,7 +72,7 @@ def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]
) -> PoseLandmarkerResult:
"""Constructs a `PoseLandmarkerResult` from output packets."""
pose_landmarker_result = PoseLandmarkerResult([], [], [])
pose_landmarker_result = PoseLandmarkerResult([], [])
if _SEGMENTATION_MASK_STREAM_NAME in output_packets:
pose_landmarker_result.segmentation_masks = packet_getter.get_image_list(
@ -90,9 +85,6 @@ def _build_landmarker_result(
pose_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME]
)
pose_auxiliary_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_POSE_AUXILIARY_LANDMARKS_STREAM_NAME]
)
for proto in pose_landmarks_proto_list:
pose_landmarks = landmark_pb2.NormalizedLandmarkList()
@ -116,19 +108,6 @@ def _build_landmarker_result(
pose_world_landmarks_list
)
for proto in pose_auxiliary_landmarks_proto_list:
pose_auxiliary_landmarks = landmark_pb2.NormalizedLandmarkList()
pose_auxiliary_landmarks.MergeFrom(proto)
pose_auxiliary_landmarks_list = []
for pose_auxiliary_landmark in pose_auxiliary_landmarks.landmark:
pose_auxiliary_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(
pose_auxiliary_landmark
)
)
pose_landmarker_result.pose_auxiliary_landmarks.append(
pose_auxiliary_landmarks_list
)
return pose_landmarker_result
@ -301,7 +280,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME]
options.result_callback(
PoseLandmarkerResult([], [], []),
PoseLandmarkerResult([], []),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
@ -320,10 +299,6 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
':'.join(
[_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([
_POSE_AUXILIARY_LANDMARKS_TAG,
_POSE_AUXILIARY_LANDMARKS_STREAM_NAME,
]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
@ -382,7 +357,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
})
if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
return PoseLandmarkerResult([], [], [])
return PoseLandmarkerResult([], [])
return _build_landmarker_result(output_packets)
@ -427,7 +402,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
})
if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
return PoseLandmarkerResult([], [], [])
return PoseLandmarkerResult([], [])
return _build_landmarker_result(output_packets)

View File

@ -43,7 +43,6 @@ const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const NORM_LANDMARKS_STREAM = 'normalized_landmarks';
const WORLD_LANDMARKS_STREAM = 'world_landmarks';
const AUXILIARY_LANDMARKS_STREAM = 'auxiliary_landmarks';
const SEGMENTATION_MASK_STREAM = 'segmentation_masks';
const POSE_LANDMARKER_GRAPH =
'mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph';
@ -371,9 +370,6 @@ export class PoseLandmarker extends VisionTaskRunner {
if (!('worldLandmarks' in this.result)) {
return;
}
if (!('auxilaryLandmarks' in this.result)) {
return;
}
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
return;
}
@ -419,20 +415,6 @@ export class PoseLandmarker extends VisionTaskRunner {
}
}
/**
* Converts raw data into a landmark, and adds it to our auxilary
* landmarks list.
*/
private addJsAuxiliaryLandmarks(data: Uint8Array[]): void {
this.result.auxilaryLandmarks = [];
for (const binaryProto of data) {
const auxiliaryLandmarksProto =
NormalizedLandmarkList.deserializeBinary(binaryProto);
this.result.auxilaryLandmarks.push(
convertToLandmarks(auxiliaryLandmarksProto));
}
}
/** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
@ -440,7 +422,6 @@ export class PoseLandmarker extends VisionTaskRunner {
graphConfig.addInputStream(NORM_RECT_STREAM);
graphConfig.addOutputStream(NORM_LANDMARKS_STREAM);
graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM);
graphConfig.addOutputStream(AUXILIARY_LANDMARKS_STREAM);
graphConfig.addOutputStream(SEGMENTATION_MASK_STREAM);
const calculatorOptions = new CalculatorOptions();
@ -453,8 +434,6 @@ export class PoseLandmarker extends VisionTaskRunner {
landmarkerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
landmarkerNode.addOutputStream('NORM_LANDMARKS:' + NORM_LANDMARKS_STREAM);
landmarkerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM);
landmarkerNode.addOutputStream(
'AUXILIARY_LANDMARKS:' + AUXILIARY_LANDMARKS_STREAM);
landmarkerNode.setOptions(calculatorOptions);
graphConfig.addNode(landmarkerNode);
@ -485,19 +464,6 @@ export class PoseLandmarker extends VisionTaskRunner {
this.maybeInvokeCallback();
});
this.graphRunner.attachProtoVectorListener(
AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsAuxiliaryLandmarks(binaryProto);
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
this.graphRunner.attachEmptyPacketListener(
AUXILIARY_LANDMARKS_STREAM, timestamp => {
this.result.auxilaryLandmarks = [];
this.setLatestOutputTimestamp(timestamp);
this.maybeInvokeCallback();
});
if (this.outputSegmentationMasks) {
landmarkerNode.addOutputStream(
'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);

View File

@ -31,9 +31,6 @@ export declare interface PoseLandmarkerResult {
/** Pose landmarks in world coordinates of detected poses. */
worldLandmarks: Landmark[][];
/** Detected auxiliary landmarks, used for deriving ROI for next frame. */
auxilaryLandmarks: NormalizedLandmark[][];
/** Segmentation mask for the detected pose. */
segmentationMasks?: MPMask[];
}

View File

@ -45,8 +45,7 @@ class PoseLandmarkerFake extends PoseLandmarker implements MediapipeTasksFake {
this.attachListenerSpies[0] =
spyOn(this.graphRunner, 'attachProtoVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toMatch(
/(normalized_landmarks|world_landmarks|auxiliary_landmarks)/);
expect(stream).toMatch(/(normalized_landmarks|world_landmarks)/);
this.listeners.set(stream, listener as PacketListener);
});
this.attachListenerSpies[1] =
@ -80,23 +79,23 @@ describe('PoseLandmarker', () => {
it('initializes graph', async () => {
verifyGraph(poseLandmarker);
expect(poseLandmarker.listeners).toHaveSize(3);
expect(poseLandmarker.listeners).toHaveSize(2);
});
it('reloads graph when settings are changed', async () => {
await poseLandmarker.setOptions({numPoses: 1});
verifyGraph(poseLandmarker, [['poseDetectorGraphOptions', 'numPoses'], 1]);
expect(poseLandmarker.listeners).toHaveSize(3);
expect(poseLandmarker.listeners).toHaveSize(2);
await poseLandmarker.setOptions({numPoses: 5});
verifyGraph(poseLandmarker, [['poseDetectorGraphOptions', 'numPoses'], 5]);
expect(poseLandmarker.listeners).toHaveSize(3);
expect(poseLandmarker.listeners).toHaveSize(2);
});
it('registers listener for segmentation masks', async () => {
expect(poseLandmarker.listeners).toHaveSize(3);
expect(poseLandmarker.listeners).toHaveSize(2);
await poseLandmarker.setOptions({outputSegmentationMasks: true});
expect(poseLandmarker.listeners).toHaveSize(4);
expect(poseLandmarker.listeners).toHaveSize(3);
});
it('merges options', async () => {
@ -209,8 +208,6 @@ describe('PoseLandmarker', () => {
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
});
@ -224,7 +221,6 @@ describe('PoseLandmarker', () => {
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.segmentationMasks![0]).toBeInstanceOf(MPMask);
done();
});
@ -240,8 +236,6 @@ describe('PoseLandmarker', () => {
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
});
// Invoke the pose landmarker twice
@ -279,8 +273,6 @@ describe('PoseLandmarker', () => {
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
});
// Invoke the pose landmarker
@ -291,9 +283,6 @@ describe('PoseLandmarker', () => {
expect(result.worldLandmarks).toEqual([
[{'x': 1, 'y': 2, 'z': 3}], [{'x': 4, 'y': 5, 'z': 6}]
]);
expect(result.auxilaryLandmarks).toEqual([
[{'x': 0.1, 'y': 0.2, 'z': 0.3}], [{'x': 0.4, 'y': 0.5, 'z': 0.6}]
]);
done();
});
});
@ -318,8 +307,6 @@ describe('PoseLandmarker', () => {
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
expect(listenerCalled).toBeFalse();
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
expect(listenerCalled).toBeTrue();
@ -342,8 +329,6 @@ describe('PoseLandmarker', () => {
(landmarksProto, 1337);
poseLandmarker.listeners.get('world_landmarks')!
(worldLandmarksProto, 1337);
poseLandmarker.listeners.get('auxiliary_landmarks')!
(landmarksProto, 1337);
});
// Invoke the pose landmarker
@ -351,6 +336,5 @@ describe('PoseLandmarker', () => {
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
});
});