Remove auxiliary landmarks in PoseLandmarker API results.
PiperOrigin-RevId: 529989746
This commit is contained in:
parent
ddb84702f6
commit
876987b389
|
@ -63,8 +63,6 @@ constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS";
|
|||
constexpr char kNormLandmarksStreamName[] = "norm_landmarks";
|
||||
constexpr char kPoseWorldLandmarksTag[] = "WORLD_LANDMARKS";
|
||||
constexpr char kPoseWorldLandmarksStreamName[] = "world_landmarks";
|
||||
constexpr char kPoseAuxiliaryLandmarksTag[] = "AUXILIARY_LANDMARKS";
|
||||
constexpr char kPoseAuxiliaryLandmarksStreamName[] = "auxiliary_landmarks";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
|
@ -83,9 +81,6 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
graph.Out(kNormLandmarksTag);
|
||||
subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >>
|
||||
graph.Out(kPoseWorldLandmarksTag);
|
||||
subgraph.Out(kPoseAuxiliaryLandmarksTag)
|
||||
.SetName(kPoseAuxiliaryLandmarksStreamName) >>
|
||||
graph.Out(kPoseAuxiliaryLandmarksTag);
|
||||
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
|
||||
if (output_segmentation_masks) {
|
||||
subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >>
|
||||
|
@ -163,8 +158,6 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
|
|||
status_or_packets.value()[kNormLandmarksStreamName];
|
||||
Packet pose_world_landmarks_packet =
|
||||
status_or_packets.value()[kPoseWorldLandmarksStreamName];
|
||||
Packet pose_auxiliary_landmarks_packet =
|
||||
status_or_packets.value()[kPoseAuxiliaryLandmarksStreamName];
|
||||
std::optional<std::vector<Image>> segmentation_mask = std::nullopt;
|
||||
if (output_segmentation_masks) {
|
||||
segmentation_mask = segmentation_mask_packet.Get<std::vector<Image>>();
|
||||
|
@ -175,9 +168,7 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create(
|
|||
/* pose_landmarks= */
|
||||
pose_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
|
||||
/* pose_world_landmarks= */
|
||||
pose_world_landmarks_packet.Get<std::vector<LandmarkList>>(),
|
||||
pose_auxiliary_landmarks_packet
|
||||
.Get<std::vector<NormalizedLandmarkList>>()),
|
||||
pose_world_landmarks_packet.Get<std::vector<LandmarkList>>()),
|
||||
image_packet.Get<Image>(),
|
||||
pose_landmarks_packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond);
|
||||
|
@ -234,10 +225,7 @@ absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::Detect(
|
|||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
|
||||
/* pose_world_landmarks */
|
||||
output_packets[kPoseWorldLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::LandmarkList>>(),
|
||||
/*pose_auxiliary_landmarks= */
|
||||
output_packets[kPoseAuxiliaryLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>());
|
||||
.Get<std::vector<mediapipe::LandmarkList>>());
|
||||
}
|
||||
|
||||
absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo(
|
||||
|
@ -277,10 +265,7 @@ absl::StatusOr<PoseLandmarkerResult> PoseLandmarker::DetectForVideo(
|
|||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>(),
|
||||
/* pose_world_landmarks */
|
||||
output_packets[kPoseWorldLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::LandmarkList>>(),
|
||||
/* pose_auxiliary_landmarks= */
|
||||
output_packets[kPoseAuxiliaryLandmarksStreamName]
|
||||
.Get<std::vector<mediapipe::NormalizedLandmarkList>>());
|
||||
.Get<std::vector<mediapipe::LandmarkList>>());
|
||||
}
|
||||
|
||||
absl::Status PoseLandmarker::DetectAsync(
|
||||
|
|
|
@ -27,15 +27,12 @@ namespace pose_landmarker {
|
|||
PoseLandmarkerResult ConvertToPoseLandmarkerResult(
|
||||
std::optional<std::vector<mediapipe::Image>> segmentation_masks,
|
||||
const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto,
|
||||
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto,
|
||||
const std::vector<mediapipe::NormalizedLandmarkList>&
|
||||
pose_auxiliary_landmarks_proto) {
|
||||
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto) {
|
||||
PoseLandmarkerResult result;
|
||||
result.segmentation_masks = segmentation_masks;
|
||||
|
||||
result.pose_landmarks.resize(pose_landmarks_proto.size());
|
||||
result.pose_world_landmarks.resize(pose_world_landmarks_proto.size());
|
||||
result.pose_auxiliary_landmarks.resize(pose_auxiliary_landmarks_proto.size());
|
||||
std::transform(pose_landmarks_proto.begin(), pose_landmarks_proto.end(),
|
||||
result.pose_landmarks.begin(),
|
||||
components::containers::ConvertToNormalizedLandmarks);
|
||||
|
@ -43,10 +40,6 @@ PoseLandmarkerResult ConvertToPoseLandmarkerResult(
|
|||
pose_world_landmarks_proto.end(),
|
||||
result.pose_world_landmarks.begin(),
|
||||
components::containers::ConvertToLandmarks);
|
||||
std::transform(pose_auxiliary_landmarks_proto.begin(),
|
||||
pose_auxiliary_landmarks_proto.end(),
|
||||
result.pose_auxiliary_landmarks.begin(),
|
||||
components::containers::ConvertToNormalizedLandmarks);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,17 +37,12 @@ struct PoseLandmarkerResult {
|
|||
std::vector<components::containers::NormalizedLandmarks> pose_landmarks;
|
||||
// Detected pose landmarks in world coordinates.
|
||||
std::vector<components::containers::Landmarks> pose_world_landmarks;
|
||||
// Detected auxiliary landmarks, used for deriving ROI for next frame.
|
||||
std::vector<components::containers::NormalizedLandmarks>
|
||||
pose_auxiliary_landmarks;
|
||||
};
|
||||
|
||||
PoseLandmarkerResult ConvertToPoseLandmarkerResult(
|
||||
std::optional<std::vector<mediapipe::Image>> segmentation_mask,
|
||||
const std::vector<mediapipe::NormalizedLandmarkList>& pose_landmarks_proto,
|
||||
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto,
|
||||
const std::vector<mediapipe::NormalizedLandmarkList>&
|
||||
pose_auxiliary_landmarks_proto);
|
||||
const std::vector<mediapipe::LandmarkList>& pose_world_landmarks_proto);
|
||||
|
||||
} // namespace pose_landmarker
|
||||
} // namespace vision
|
||||
|
|
|
@ -47,13 +47,6 @@ TEST(ConvertFromProto, Succeeds) {
|
|||
landmark_proto.set_y(5.2);
|
||||
landmark_proto.set_z(4.3);
|
||||
|
||||
mediapipe::NormalizedLandmarkList auxiliary_landmark_list_proto;
|
||||
mediapipe::NormalizedLandmark& auxiliary_landmark_proto =
|
||||
*auxiliary_landmark_list_proto.add_landmark();
|
||||
auxiliary_landmark_proto.set_x(0.5);
|
||||
auxiliary_landmark_proto.set_y(0.5);
|
||||
auxiliary_landmark_proto.set_z(0.5);
|
||||
|
||||
std::vector<Image> segmentation_masks_lists = {segmentation_mask};
|
||||
|
||||
std::vector<mediapipe::NormalizedLandmarkList> normalized_landmarks_lists = {
|
||||
|
@ -62,12 +55,9 @@ TEST(ConvertFromProto, Succeeds) {
|
|||
std::vector<mediapipe::LandmarkList> world_landmarks_lists = {
|
||||
world_landmark_list_proto};
|
||||
|
||||
std::vector<mediapipe::NormalizedLandmarkList> auxiliary_landmarks_lists = {
|
||||
auxiliary_landmark_list_proto};
|
||||
|
||||
PoseLandmarkerResult pose_landmarker_result = ConvertToPoseLandmarkerResult(
|
||||
segmentation_masks_lists, normalized_landmarks_lists,
|
||||
world_landmarks_lists, auxiliary_landmarks_lists);
|
||||
world_landmarks_lists);
|
||||
|
||||
EXPECT_EQ(pose_landmarker_result.pose_landmarks.size(), 1);
|
||||
EXPECT_EQ(pose_landmarker_result.pose_landmarks[0].landmarks.size(), 1);
|
||||
|
@ -82,14 +72,6 @@ TEST(ConvertFromProto, Succeeds) {
|
|||
testing::FieldsAre(testing::FloatEq(3.1), testing::FloatEq(5.2),
|
||||
testing::FloatEq(4.3), std::nullopt,
|
||||
std::nullopt, std::nullopt));
|
||||
|
||||
EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks.size(), 1);
|
||||
EXPECT_EQ(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks.size(),
|
||||
1);
|
||||
EXPECT_THAT(pose_landmarker_result.pose_auxiliary_landmarks[0].landmarks[0],
|
||||
testing::FieldsAre(testing::FloatEq(0.5), testing::FloatEq(0.5),
|
||||
testing::FloatEq(0.5), std::nullopt,
|
||||
std::nullopt, std::nullopt));
|
||||
}
|
||||
|
||||
} // namespace pose_landmarker
|
||||
|
|
|
@ -79,8 +79,7 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
|
|||
|
||||
private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
|
||||
private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
|
||||
private static final int AUXILIARY_LANDMARKS_OUT_STREAM_INDEX = 2;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 3;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 2;
|
||||
private static int segmentationMasksOutStreamIndex = -1;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph";
|
||||
|
@ -145,7 +144,6 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
|
|||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("NORM_LANDMARKS:pose_landmarks");
|
||||
outputStreams.add("WORLD_LANDMARKS:world_landmarks");
|
||||
outputStreams.add("AUXILIARY_LANDMARKS:auxiliary_landmarks");
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
if (landmarkerOptions.outputSegmentationMasks()) {
|
||||
outputStreams.add("SEGMENTATION_MASK:segmentation_masks");
|
||||
|
@ -161,7 +159,6 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
|
|||
// If there is no poses detected in the image, just returns empty lists.
|
||||
if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return PoseLandmarkerResult.create(
|
||||
new ArrayList<>(),
|
||||
new ArrayList<>(),
|
||||
new ArrayList<>(),
|
||||
Optional.empty(),
|
||||
|
@ -179,9 +176,6 @@ public final class PoseLandmarker extends BaseVisionTaskApi {
|
|||
packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
|
||||
PacketGetter.getProtoVector(
|
||||
packets.get(AUXILIARY_LANDMARKS_OUT_STREAM_INDEX),
|
||||
NormalizedLandmarkList.parser()),
|
||||
segmentedMasks,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX)));
|
||||
|
|
|
@ -40,7 +40,6 @@ public abstract class PoseLandmarkerResult implements TaskResult {
|
|||
static PoseLandmarkerResult create(
|
||||
List<LandmarkProto.NormalizedLandmarkList> landmarksProto,
|
||||
List<LandmarkProto.LandmarkList> worldLandmarksProto,
|
||||
List<LandmarkProto.NormalizedLandmarkList> auxiliaryLandmarksProto,
|
||||
Optional<List<MPImage>> segmentationMasksData,
|
||||
long timestampMs) {
|
||||
|
||||
|
@ -52,7 +51,6 @@ public abstract class PoseLandmarkerResult implements TaskResult {
|
|||
|
||||
List<List<NormalizedLandmark>> multiPoseLandmarks = new ArrayList<>();
|
||||
List<List<Landmark>> multiPoseWorldLandmarks = new ArrayList<>();
|
||||
List<List<NormalizedLandmark>> multiPoseAuxiliaryLandmarks = new ArrayList<>();
|
||||
for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) {
|
||||
List<NormalizedLandmark> poseLandmarks = new ArrayList<>();
|
||||
multiPoseLandmarks.add(poseLandmarks);
|
||||
|
@ -75,24 +73,10 @@ public abstract class PoseLandmarkerResult implements TaskResult {
|
|||
poseWorldLandmarkProto.getZ()));
|
||||
}
|
||||
}
|
||||
for (LandmarkProto.NormalizedLandmarkList poseAuxiliaryLandmarksProto :
|
||||
auxiliaryLandmarksProto) {
|
||||
List<NormalizedLandmark> poseAuxiliaryLandmarks = new ArrayList<>();
|
||||
multiPoseAuxiliaryLandmarks.add(poseAuxiliaryLandmarks);
|
||||
for (LandmarkProto.NormalizedLandmark poseAuxiliaryLandmarkProto :
|
||||
poseAuxiliaryLandmarksProto.getLandmarkList()) {
|
||||
poseAuxiliaryLandmarks.add(
|
||||
NormalizedLandmark.create(
|
||||
poseAuxiliaryLandmarkProto.getX(),
|
||||
poseAuxiliaryLandmarkProto.getY(),
|
||||
poseAuxiliaryLandmarkProto.getZ()));
|
||||
}
|
||||
}
|
||||
return new AutoValue_PoseLandmarkerResult(
|
||||
timestampMs,
|
||||
Collections.unmodifiableList(multiPoseLandmarks),
|
||||
Collections.unmodifiableList(multiPoseWorldLandmarks),
|
||||
Collections.unmodifiableList(multiPoseAuxiliaryLandmarks),
|
||||
multiPoseSegmentationMasks);
|
||||
}
|
||||
|
||||
|
@ -105,9 +89,6 @@ public abstract class PoseLandmarkerResult implements TaskResult {
|
|||
/** Pose landmarks in world coordniates of detected poses. */
|
||||
public abstract List<List<Landmark>> worldLandmarks();
|
||||
|
||||
/** Pose auxiliary landmarks. */
|
||||
public abstract List<List<NormalizedLandmark>> auxiliaryLandmarks();
|
||||
|
||||
/** Pose segmentation masks. */
|
||||
public abstract Optional<List<MPImage>> segmentationMasks();
|
||||
}
|
||||
|
|
|
@ -330,7 +330,6 @@ public class PoseLandmarkerTest {
|
|||
return PoseLandmarkerResult.create(
|
||||
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
|
||||
Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
|
||||
Arrays.asList(),
|
||||
Optional.empty(),
|
||||
/* timestampMs= */ 0);
|
||||
}
|
||||
|
|
|
@ -74,7 +74,6 @@ def _get_expected_pose_landmarker_result(
|
|||
return PoseLandmarkerResult(
|
||||
pose_landmarks=[landmarks_detection_result.landmarks],
|
||||
pose_world_landmarks=[],
|
||||
pose_auxiliary_landmarks=[],
|
||||
)
|
||||
|
||||
|
||||
|
@ -296,7 +295,6 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
|||
# Comparing results.
|
||||
self.assertEmpty(detection_result.pose_landmarks)
|
||||
self.assertEmpty(detection_result.pose_world_landmarks)
|
||||
self.assertEmpty(detection_result.pose_auxiliary_landmarks)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
|
@ -391,7 +389,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
|||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [], [])),
|
||||
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [])),
|
||||
)
|
||||
def test_detect_for_video(
|
||||
self, image_path, rotation, output_segmentation_masks, expected_result
|
||||
|
@ -473,7 +471,7 @@ class PoseLandmarkerTest(parameterized.TestCase):
|
|||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [], [])),
|
||||
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [])),
|
||||
)
|
||||
def test_detect_async_calls(
|
||||
self, image_path, rotation, output_segmentation_masks, expected_result
|
||||
|
|
|
@ -49,8 +49,6 @@ _NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks'
|
|||
_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS'
|
||||
_POSE_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
|
||||
_POSE_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
|
||||
_POSE_AUXILIARY_LANDMARKS_STREAM_NAME = 'auxiliary_landmarks'
|
||||
_POSE_AUXILIARY_LANDMARKS_TAG = 'AUXILIARY_LANDMARKS'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
@ -62,14 +60,11 @@ class PoseLandmarkerResult:
|
|||
Attributes:
|
||||
pose_landmarks: Detected pose landmarks in normalized image coordinates.
|
||||
pose_world_landmarks: Detected pose landmarks in world coordinates.
|
||||
pose_auxiliary_landmarks: Detected auxiliary landmarks, used for deriving
|
||||
ROI for next frame.
|
||||
segmentation_masks: Optional segmentation masks for pose.
|
||||
"""
|
||||
|
||||
pose_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||
pose_world_landmarks: List[List[landmark_module.Landmark]]
|
||||
pose_auxiliary_landmarks: List[List[landmark_module.NormalizedLandmark]]
|
||||
segmentation_masks: Optional[List[image_module.Image]] = None
|
||||
|
||||
|
||||
|
@ -77,7 +72,7 @@ def _build_landmarker_result(
|
|||
output_packets: Mapping[str, packet_module.Packet]
|
||||
) -> PoseLandmarkerResult:
|
||||
"""Constructs a `PoseLandmarkerResult` from output packets."""
|
||||
pose_landmarker_result = PoseLandmarkerResult([], [], [])
|
||||
pose_landmarker_result = PoseLandmarkerResult([], [])
|
||||
|
||||
if _SEGMENTATION_MASK_STREAM_NAME in output_packets:
|
||||
pose_landmarker_result.segmentation_masks = packet_getter.get_image_list(
|
||||
|
@ -90,9 +85,6 @@ def _build_landmarker_result(
|
|||
pose_world_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME]
|
||||
)
|
||||
pose_auxiliary_landmarks_proto_list = packet_getter.get_proto_list(
|
||||
output_packets[_POSE_AUXILIARY_LANDMARKS_STREAM_NAME]
|
||||
)
|
||||
|
||||
for proto in pose_landmarks_proto_list:
|
||||
pose_landmarks = landmark_pb2.NormalizedLandmarkList()
|
||||
|
@ -116,19 +108,6 @@ def _build_landmarker_result(
|
|||
pose_world_landmarks_list
|
||||
)
|
||||
|
||||
for proto in pose_auxiliary_landmarks_proto_list:
|
||||
pose_auxiliary_landmarks = landmark_pb2.NormalizedLandmarkList()
|
||||
pose_auxiliary_landmarks.MergeFrom(proto)
|
||||
pose_auxiliary_landmarks_list = []
|
||||
for pose_auxiliary_landmark in pose_auxiliary_landmarks.landmark:
|
||||
pose_auxiliary_landmarks_list.append(
|
||||
landmark_module.NormalizedLandmark.create_from_pb2(
|
||||
pose_auxiliary_landmark
|
||||
)
|
||||
)
|
||||
pose_landmarker_result.pose_auxiliary_landmarks.append(
|
||||
pose_auxiliary_landmarks_list
|
||||
)
|
||||
return pose_landmarker_result
|
||||
|
||||
|
||||
|
@ -301,7 +280,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
|||
if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
|
||||
empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME]
|
||||
options.result_callback(
|
||||
PoseLandmarkerResult([], [], []),
|
||||
PoseLandmarkerResult([], []),
|
||||
image,
|
||||
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
|
||||
)
|
||||
|
@ -320,10 +299,6 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
|||
':'.join(
|
||||
[_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME]
|
||||
),
|
||||
':'.join([
|
||||
_POSE_AUXILIARY_LANDMARKS_TAG,
|
||||
_POSE_AUXILIARY_LANDMARKS_STREAM_NAME,
|
||||
]),
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
|
||||
]
|
||||
|
||||
|
@ -382,7 +357,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
|||
})
|
||||
|
||||
if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
|
||||
return PoseLandmarkerResult([], [], [])
|
||||
return PoseLandmarkerResult([], [])
|
||||
|
||||
return _build_landmarker_result(output_packets)
|
||||
|
||||
|
@ -427,7 +402,7 @@ class PoseLandmarker(base_vision_task_api.BaseVisionTaskApi):
|
|||
})
|
||||
|
||||
if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty():
|
||||
return PoseLandmarkerResult([], [], [])
|
||||
return PoseLandmarkerResult([], [])
|
||||
|
||||
return _build_landmarker_result(output_packets)
|
||||
|
||||
|
|
|
@ -43,7 +43,6 @@ const IMAGE_STREAM = 'image_in';
|
|||
const NORM_RECT_STREAM = 'norm_rect';
|
||||
const NORM_LANDMARKS_STREAM = 'normalized_landmarks';
|
||||
const WORLD_LANDMARKS_STREAM = 'world_landmarks';
|
||||
const AUXILIARY_LANDMARKS_STREAM = 'auxiliary_landmarks';
|
||||
const SEGMENTATION_MASK_STREAM = 'segmentation_masks';
|
||||
const POSE_LANDMARKER_GRAPH =
|
||||
'mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph';
|
||||
|
@ -371,9 +370,6 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
if (!('worldLandmarks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
if (!('auxilaryLandmarks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
if (this.outputSegmentationMasks && !('segmentationMasks' in this.result)) {
|
||||
return;
|
||||
}
|
||||
|
@ -419,20 +415,6 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts raw data into a landmark, and adds it to our auxilary
|
||||
* landmarks list.
|
||||
*/
|
||||
private addJsAuxiliaryLandmarks(data: Uint8Array[]): void {
|
||||
this.result.auxilaryLandmarks = [];
|
||||
for (const binaryProto of data) {
|
||||
const auxiliaryLandmarksProto =
|
||||
NormalizedLandmarkList.deserializeBinary(binaryProto);
|
||||
this.result.auxilaryLandmarks.push(
|
||||
convertToLandmarks(auxiliaryLandmarksProto));
|
||||
}
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
|
@ -440,7 +422,6 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
graphConfig.addOutputStream(NORM_LANDMARKS_STREAM);
|
||||
graphConfig.addOutputStream(WORLD_LANDMARKS_STREAM);
|
||||
graphConfig.addOutputStream(AUXILIARY_LANDMARKS_STREAM);
|
||||
graphConfig.addOutputStream(SEGMENTATION_MASK_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
|
@ -453,8 +434,6 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
landmarkerNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
|
||||
landmarkerNode.addOutputStream('NORM_LANDMARKS:' + NORM_LANDMARKS_STREAM);
|
||||
landmarkerNode.addOutputStream('WORLD_LANDMARKS:' + WORLD_LANDMARKS_STREAM);
|
||||
landmarkerNode.addOutputStream(
|
||||
'AUXILIARY_LANDMARKS:' + AUXILIARY_LANDMARKS_STREAM);
|
||||
landmarkerNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(landmarkerNode);
|
||||
|
@ -485,19 +464,6 @@ export class PoseLandmarker extends VisionTaskRunner {
|
|||
this.maybeInvokeCallback();
|
||||
});
|
||||
|
||||
this.graphRunner.attachProtoVectorListener(
|
||||
AUXILIARY_LANDMARKS_STREAM, (binaryProto, timestamp) => {
|
||||
this.addJsAuxiliaryLandmarks(binaryProto);
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
AUXILIARY_LANDMARKS_STREAM, timestamp => {
|
||||
this.result.auxilaryLandmarks = [];
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
this.maybeInvokeCallback();
|
||||
});
|
||||
|
||||
if (this.outputSegmentationMasks) {
|
||||
landmarkerNode.addOutputStream(
|
||||
'SEGMENTATION_MASK:' + SEGMENTATION_MASK_STREAM);
|
||||
|
|
|
@ -31,9 +31,6 @@ export declare interface PoseLandmarkerResult {
|
|||
/** Pose landmarks in world coordinates of detected poses. */
|
||||
worldLandmarks: Landmark[][];
|
||||
|
||||
/** Detected auxiliary landmarks, used for deriving ROI for next frame. */
|
||||
auxilaryLandmarks: NormalizedLandmark[][];
|
||||
|
||||
/** Segmentation mask for the detected pose. */
|
||||
segmentationMasks?: MPMask[];
|
||||
}
|
||||
|
|
|
@ -45,8 +45,7 @@ class PoseLandmarkerFake extends PoseLandmarker implements MediapipeTasksFake {
|
|||
this.attachListenerSpies[0] =
|
||||
spyOn(this.graphRunner, 'attachProtoVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toMatch(
|
||||
/(normalized_landmarks|world_landmarks|auxiliary_landmarks)/);
|
||||
expect(stream).toMatch(/(normalized_landmarks|world_landmarks)/);
|
||||
this.listeners.set(stream, listener as PacketListener);
|
||||
});
|
||||
this.attachListenerSpies[1] =
|
||||
|
@ -80,23 +79,23 @@ describe('PoseLandmarker', () => {
|
|||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(poseLandmarker);
|
||||
expect(poseLandmarker.listeners).toHaveSize(3);
|
||||
expect(poseLandmarker.listeners).toHaveSize(2);
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await poseLandmarker.setOptions({numPoses: 1});
|
||||
verifyGraph(poseLandmarker, [['poseDetectorGraphOptions', 'numPoses'], 1]);
|
||||
expect(poseLandmarker.listeners).toHaveSize(3);
|
||||
expect(poseLandmarker.listeners).toHaveSize(2);
|
||||
|
||||
await poseLandmarker.setOptions({numPoses: 5});
|
||||
verifyGraph(poseLandmarker, [['poseDetectorGraphOptions', 'numPoses'], 5]);
|
||||
expect(poseLandmarker.listeners).toHaveSize(3);
|
||||
expect(poseLandmarker.listeners).toHaveSize(2);
|
||||
});
|
||||
|
||||
it('registers listener for segmentation masks', async () => {
|
||||
expect(poseLandmarker.listeners).toHaveSize(3);
|
||||
expect(poseLandmarker.listeners).toHaveSize(2);
|
||||
await poseLandmarker.setOptions({outputSegmentationMasks: true});
|
||||
expect(poseLandmarker.listeners).toHaveSize(4);
|
||||
expect(poseLandmarker.listeners).toHaveSize(3);
|
||||
});
|
||||
|
||||
it('merges options', async () => {
|
||||
|
@ -209,8 +208,6 @@ describe('PoseLandmarker', () => {
|
|||
(landmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('world_landmarks')!
|
||||
(worldLandmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('auxiliary_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
|
||||
});
|
||||
|
||||
|
@ -224,7 +221,6 @@ describe('PoseLandmarker', () => {
|
|||
|
||||
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||
expect(result.segmentationMasks![0]).toBeInstanceOf(MPMask);
|
||||
done();
|
||||
});
|
||||
|
@ -240,8 +236,6 @@ describe('PoseLandmarker', () => {
|
|||
(landmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('world_landmarks')!
|
||||
(worldLandmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('auxiliary_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
});
|
||||
|
||||
// Invoke the pose landmarker twice
|
||||
|
@ -279,8 +273,6 @@ describe('PoseLandmarker', () => {
|
|||
(landmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('world_landmarks')!
|
||||
(worldLandmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('auxiliary_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
});
|
||||
|
||||
// Invoke the pose landmarker
|
||||
|
@ -291,9 +283,6 @@ describe('PoseLandmarker', () => {
|
|||
expect(result.worldLandmarks).toEqual([
|
||||
[{'x': 1, 'y': 2, 'z': 3}], [{'x': 4, 'y': 5, 'z': 6}]
|
||||
]);
|
||||
expect(result.auxilaryLandmarks).toEqual([
|
||||
[{'x': 0.1, 'y': 0.2, 'z': 0.3}], [{'x': 0.4, 'y': 0.5, 'z': 0.6}]
|
||||
]);
|
||||
done();
|
||||
});
|
||||
});
|
||||
|
@ -318,8 +307,6 @@ describe('PoseLandmarker', () => {
|
|||
poseLandmarker.listeners.get('world_landmarks')!
|
||||
(worldLandmarksProto, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
poseLandmarker.listeners.get('auxiliary_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
expect(listenerCalled).toBeFalse();
|
||||
poseLandmarker.listeners.get('segmentation_masks')!(masks, 1337);
|
||||
expect(listenerCalled).toBeTrue();
|
||||
|
@ -342,8 +329,6 @@ describe('PoseLandmarker', () => {
|
|||
(landmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('world_landmarks')!
|
||||
(worldLandmarksProto, 1337);
|
||||
poseLandmarker.listeners.get('auxiliary_landmarks')!
|
||||
(landmarksProto, 1337);
|
||||
});
|
||||
|
||||
// Invoke the pose landmarker
|
||||
|
@ -351,6 +336,5 @@ describe('PoseLandmarker', () => {
|
|||
expect(poseLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(result.landmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||
expect(result.worldLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||
expect(result.auxilaryLandmarks).toEqual([[{'x': 0, 'y': 0, 'z': 0}]]);
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue
Block a user