added model complexity support
This commit is contained in:
parent
a31b61f15a
commit
65a944e3f9
|
@ -195,7 +195,7 @@ public class MainActivity extends AppCompatActivity {
|
||||||
this,
|
this,
|
||||||
PoseTrackingOptions.builder()
|
PoseTrackingOptions.builder()
|
||||||
.setStaticImageMode(true)
|
.setStaticImageMode(true)
|
||||||
.setModelSelection(0)
|
.setModelComplexity(2)
|
||||||
.setMinDetectionConfidence(0.5f)
|
.setMinDetectionConfidence(0.5f)
|
||||||
.setLandmarkVisibility(true)
|
.setLandmarkVisibility(true)
|
||||||
.build());
|
.build());
|
||||||
|
@ -274,7 +274,7 @@ public class MainActivity extends AppCompatActivity {
|
||||||
PoseTrackingOptions.builder()
|
PoseTrackingOptions.builder()
|
||||||
.setStaticImageMode(false)
|
.setStaticImageMode(false)
|
||||||
.setLandmarkVisibility(true)
|
.setLandmarkVisibility(true)
|
||||||
.setModelSelection(0)
|
.setModelComplexity(0)
|
||||||
.build());
|
.build());
|
||||||
poseTracking.setErrorListener(
|
poseTracking.setErrorListener(
|
||||||
(message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
|
(message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
|
||||||
|
|
|
@ -62,6 +62,7 @@ node {
|
||||||
node {
|
node {
|
||||||
calculator: "PoseLandmarkGpu"
|
calculator: "PoseLandmarkGpu"
|
||||||
input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation"
|
input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation"
|
||||||
|
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
|
||||||
input_stream: "IMAGE:throttled_input_video"
|
input_stream: "IMAGE:throttled_input_video"
|
||||||
output_stream: "LANDMARKS:pose_landmarks"
|
output_stream: "LANDMARKS:pose_landmarks"
|
||||||
output_stream: "SEGMENTATION_MASK:segmentation_mask"
|
output_stream: "SEGMENTATION_MASK:segmentation_mask"
|
||||||
|
|
|
@ -18,7 +18,6 @@ import android.content.Context;
|
||||||
import android.util.Log;
|
import android.util.Log;
|
||||||
|
|
||||||
import com.google.common.collect.ImmutableList;
|
import com.google.common.collect.ImmutableList;
|
||||||
import com.google.mediapipe.formats.proto.DetectionProto;
|
|
||||||
import com.google.mediapipe.formats.proto.LandmarkProto;
|
import com.google.mediapipe.formats.proto.LandmarkProto;
|
||||||
import com.google.mediapipe.framework.MediaPipeException;
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
import com.google.mediapipe.framework.Packet;
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
@ -32,7 +31,6 @@ import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||||
import com.google.protobuf.InvalidProtocolBufferException;
|
import com.google.protobuf.InvalidProtocolBufferException;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import javax.annotation.Nullable;
|
import javax.annotation.Nullable;
|
||||||
|
|
||||||
|
@ -46,8 +44,7 @@ import javax.annotation.Nullable;
|
||||||
public class PoseTracking extends ImageSolutionBase {
|
public class PoseTracking extends ImageSolutionBase {
|
||||||
private static final String TAG = "PoseTracking";
|
private static final String TAG = "PoseTracking";
|
||||||
|
|
||||||
private static final String SHORT_RANGE_GRAPH_NAME = "pose_tracking_gpu_image.binarypb";
|
private static final String GRAPH_NAME = "pose_tracking_gpu_image.binarypb";
|
||||||
private static final String FULL_RANGE_GRAPH_NAME = "face_detection_full_range_image.binarypb";
|
|
||||||
private static final String IMAGE_INPUT_STREAM = "input_video";
|
private static final String IMAGE_INPUT_STREAM = "input_video";
|
||||||
private static final ImmutableList<String> OUTPUT_STREAMS =
|
private static final ImmutableList<String> OUTPUT_STREAMS =
|
||||||
ImmutableList.of("pose_detection", "throttled_input_video","output_video","pose_landmarks");
|
ImmutableList.of("pose_detection", "throttled_input_video","output_video","pose_landmarks");
|
||||||
|
@ -121,15 +118,17 @@ public class PoseTracking extends ImageSolutionBase {
|
||||||
SolutionInfo solutionInfo =
|
SolutionInfo solutionInfo =
|
||||||
SolutionInfo.builder()
|
SolutionInfo.builder()
|
||||||
.setBinaryGraphPath(
|
.setBinaryGraphPath(
|
||||||
options.modelSelection() == 0 ? SHORT_RANGE_GRAPH_NAME : FULL_RANGE_GRAPH_NAME)
|
GRAPH_NAME)
|
||||||
.setImageInputStreamName(IMAGE_INPUT_STREAM)
|
.setImageInputStreamName(IMAGE_INPUT_STREAM)
|
||||||
.setOutputStreamNames(OUTPUT_STREAMS)
|
.setOutputStreamNames(OUTPUT_STREAMS)
|
||||||
.setStaticImageMode(options.staticImageMode())
|
.setStaticImageMode(options.staticImageMode())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
initialize(context, solutionInfo, outputHandler);
|
initialize(context, solutionInfo, outputHandler);
|
||||||
Map<String, Packet> emptyInputSidePackets = new HashMap<>();
|
Map<String, Packet> inputSidePackets = new HashMap<>();
|
||||||
start(emptyInputSidePackets);
|
// inputSidePackets.put("enable_segmentation", packetCreator.createBool(false));
|
||||||
|
inputSidePackets.put("model_complexity",packetCreator.createInt32(options.modelComplexity()));
|
||||||
|
start(inputSidePackets);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -35,7 +35,7 @@ import com.google.auto.value.AutoValue;
|
||||||
public abstract class PoseTrackingOptions {
|
public abstract class PoseTrackingOptions {
|
||||||
public abstract boolean staticImageMode();
|
public abstract boolean staticImageMode();
|
||||||
|
|
||||||
public abstract int modelSelection();
|
public abstract int modelComplexity();
|
||||||
|
|
||||||
public abstract float minDetectionConfidence();
|
public abstract float minDetectionConfidence();
|
||||||
|
|
||||||
|
@ -49,12 +49,12 @@ public abstract class PoseTrackingOptions {
|
||||||
@AutoValue.Builder
|
@AutoValue.Builder
|
||||||
public abstract static class Builder {
|
public abstract static class Builder {
|
||||||
public Builder withDefaultValues() {
|
public Builder withDefaultValues() {
|
||||||
return setStaticImageMode(false).setModelSelection(0).setMinDetectionConfidence(0.5f);
|
return setStaticImageMode(false).setModelComplexity(0).setMinDetectionConfidence(0.5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Builder setStaticImageMode(boolean value);
|
public abstract Builder setStaticImageMode(boolean value);
|
||||||
|
|
||||||
public abstract Builder setModelSelection(int value);
|
public abstract Builder setModelComplexity(int value);
|
||||||
|
|
||||||
public abstract Builder setMinDetectionConfidence(float value);
|
public abstract Builder setMinDetectionConfidence(float value);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user