added model complexity support

This commit is contained in:
Mautisim Munir 2022-09-26 22:16:07 +05:00
parent a31b61f15a
commit 65a944e3f9
4 changed files with 12 additions and 12 deletions

View File

@ -195,7 +195,7 @@ public class MainActivity extends AppCompatActivity {
this,
PoseTrackingOptions.builder()
.setStaticImageMode(true)
.setModelSelection(0)
.setModelComplexity(2)
.setMinDetectionConfidence(0.5f)
.setLandmarkVisibility(true)
.build());
@ -274,7 +274,7 @@ public class MainActivity extends AppCompatActivity {
PoseTrackingOptions.builder()
.setStaticImageMode(false)
.setLandmarkVisibility(true)
.setModelSelection(0)
.setModelComplexity(0)
.build());
poseTracking.setErrorListener(
(message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));

View File

@ -62,6 +62,7 @@ node {
node {
calculator: "PoseLandmarkGpu"
input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_stream: "IMAGE:throttled_input_video"
output_stream: "LANDMARKS:pose_landmarks"
output_stream: "SEGMENTATION_MASK:segmentation_mask"

View File

@ -18,7 +18,6 @@ import android.content.Context;
import android.util.Log;
import com.google.common.collect.ImmutableList;
import com.google.mediapipe.formats.proto.DetectionProto;
import com.google.mediapipe.formats.proto.LandmarkProto;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
@ -32,7 +31,6 @@ import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
@ -46,8 +44,7 @@ import javax.annotation.Nullable;
public class PoseTracking extends ImageSolutionBase {
private static final String TAG = "PoseTracking";
private static final String SHORT_RANGE_GRAPH_NAME = "pose_tracking_gpu_image.binarypb";
private static final String FULL_RANGE_GRAPH_NAME = "face_detection_full_range_image.binarypb";
private static final String GRAPH_NAME = "pose_tracking_gpu_image.binarypb";
private static final String IMAGE_INPUT_STREAM = "input_video";
private static final ImmutableList<String> OUTPUT_STREAMS =
ImmutableList.of("pose_detection", "throttled_input_video","output_video","pose_landmarks");
@ -121,15 +118,17 @@ public class PoseTracking extends ImageSolutionBase {
SolutionInfo solutionInfo =
SolutionInfo.builder()
.setBinaryGraphPath(
options.modelSelection() == 0 ? SHORT_RANGE_GRAPH_NAME : FULL_RANGE_GRAPH_NAME)
GRAPH_NAME)
.setImageInputStreamName(IMAGE_INPUT_STREAM)
.setOutputStreamNames(OUTPUT_STREAMS)
.setStaticImageMode(options.staticImageMode())
.build();
initialize(context, solutionInfo, outputHandler);
Map<String, Packet> emptyInputSidePackets = new HashMap<>();
start(emptyInputSidePackets);
Map<String, Packet> inputSidePackets = new HashMap<>();
// inputSidePackets.put("enable_segmentation", packetCreator.createBool(false));
inputSidePackets.put("model_complexity",packetCreator.createInt32(options.modelComplexity()));
start(inputSidePackets);
}
/**

View File

@ -35,7 +35,7 @@ import com.google.auto.value.AutoValue;
public abstract class PoseTrackingOptions {
public abstract boolean staticImageMode();
public abstract int modelSelection();
public abstract int modelComplexity();
public abstract float minDetectionConfidence();
@ -49,12 +49,12 @@ public abstract class PoseTrackingOptions {
@AutoValue.Builder
public abstract static class Builder {
public Builder withDefaultValues() {
return setStaticImageMode(false).setModelSelection(0).setMinDetectionConfidence(0.5f);
return setStaticImageMode(false).setModelComplexity(0).setMinDetectionConfidence(0.5f);
}
public abstract Builder setStaticImageMode(boolean value);
public abstract Builder setModelSelection(int value);
public abstract Builder setModelComplexity(int value);
public abstract Builder setMinDetectionConfidence(float value);