diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackingsolutiongpu/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackingsolutiongpu/MainActivity.java index d92eb758e..d2a6002b6 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackingsolutiongpu/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackingsolutiongpu/MainActivity.java @@ -195,7 +195,7 @@ public class MainActivity extends AppCompatActivity { this, PoseTrackingOptions.builder() .setStaticImageMode(true) - .setModelSelection(0) + .setModelComplexity(2) .setMinDetectionConfidence(0.5f) .setLandmarkVisibility(true) .build()); @@ -274,7 +274,7 @@ public class MainActivity extends AppCompatActivity { PoseTrackingOptions.builder() .setStaticImageMode(false) .setLandmarkVisibility(true) - .setModelSelection(0) + .setModelComplexity(0) .build()); poseTracking.setErrorListener( (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); diff --git a/mediapipe/graphs/pose_tracking/pose_tracking_gpu_image.pbtxt b/mediapipe/graphs/pose_tracking/pose_tracking_gpu_image.pbtxt index 9bd7c09be..a606b8fe9 100644 --- a/mediapipe/graphs/pose_tracking/pose_tracking_gpu_image.pbtxt +++ b/mediapipe/graphs/pose_tracking/pose_tracking_gpu_image.pbtxt @@ -62,6 +62,7 @@ node { node { calculator: "PoseLandmarkGpu" input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" + input_side_packet: "MODEL_COMPLEXITY:model_complexity" input_stream: "IMAGE:throttled_input_video" output_stream: "LANDMARKS:pose_landmarks" output_stream: "SEGMENTATION_MASK:segmentation_mask" diff --git a/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTracking.java b/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTracking.java index 8f2a03b0b..02dc69604 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTracking.java +++ b/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTracking.java @@ -18,7 +18,6 @@ import android.content.Context; import android.util.Log; import com.google.common.collect.ImmutableList; -import com.google.mediapipe.formats.proto.DetectionProto; import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; @@ -32,7 +31,6 @@ import com.google.mediapipe.formats.proto.DetectionProto.Detection; import com.google.protobuf.InvalidProtocolBufferException; import java.util.HashMap; -import java.util.List; import java.util.Map; import javax.annotation.Nullable; @@ -46,8 +44,7 @@ import javax.annotation.Nullable; public class PoseTracking extends ImageSolutionBase { private static final String TAG = "PoseTracking"; - private static final String SHORT_RANGE_GRAPH_NAME = "pose_tracking_gpu_image.binarypb"; - private static final String FULL_RANGE_GRAPH_NAME = "face_detection_full_range_image.binarypb"; + private static final String GRAPH_NAME = "pose_tracking_gpu_image.binarypb"; private static final String IMAGE_INPUT_STREAM = "input_video"; private static final ImmutableList OUTPUT_STREAMS = ImmutableList.of("pose_detection", "throttled_input_video","output_video","pose_landmarks"); @@ -121,15 +118,17 @@ public class PoseTracking extends ImageSolutionBase { SolutionInfo solutionInfo = SolutionInfo.builder() .setBinaryGraphPath( - options.modelSelection() == 0 ? SHORT_RANGE_GRAPH_NAME : FULL_RANGE_GRAPH_NAME) + GRAPH_NAME) .setImageInputStreamName(IMAGE_INPUT_STREAM) .setOutputStreamNames(OUTPUT_STREAMS) .setStaticImageMode(options.staticImageMode()) .build(); initialize(context, solutionInfo, outputHandler); - Map emptyInputSidePackets = new HashMap<>(); - start(emptyInputSidePackets); + Map inputSidePackets = new HashMap<>(); +// inputSidePackets.put("enable_segmentation", packetCreator.createBool(false)); + inputSidePackets.put("model_complexity",packetCreator.createInt32(options.modelComplexity())); + start(inputSidePackets); } /** diff --git a/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTrackingOptions.java b/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTrackingOptions.java index 8fe703187..7c5931bb0 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTrackingOptions.java +++ b/mediapipe/java/com/google/mediapipe/solutions/posetracking/PoseTrackingOptions.java @@ -35,7 +35,7 @@ import com.google.auto.value.AutoValue; public abstract class PoseTrackingOptions { public abstract boolean staticImageMode(); - public abstract int modelSelection(); + public abstract int modelComplexity(); public abstract float minDetectionConfidence(); @@ -49,12 +49,12 @@ public abstract class PoseTrackingOptions { @AutoValue.Builder public abstract static class Builder { public Builder withDefaultValues() { - return setStaticImageMode(false).setModelSelection(0).setMinDetectionConfidence(0.5f); + return setStaticImageMode(false).setModelComplexity(0).setMinDetectionConfidence(0.5f); } public abstract Builder setStaticImageMode(boolean value); - public abstract Builder setModelSelection(int value); + public abstract Builder setModelComplexity(int value); public abstract Builder setMinDetectionConfidence(float value);