posetracking solution api drawing working

Mautisim Munir 2022-09-26 17:08:09 +05:00
parent 6417c0abe0
commit bd57eb40b9
7 changed files with 447 additions and 112 deletions

View File

@@ -30,12 +30,16 @@ android_library(
         "//mediapipe/java/com/google/mediapipe/glutil",
         "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input",
         "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering",
+        "//mediapipe/java/com/google/mediapipe/solutioncore:video_input",
         "//mediapipe/java/com/google/mediapipe/solutions/facedetection",
         "//mediapipe/java/com/google/mediapipe/solutions/posetracking",
         "//third_party:androidx_appcompat",
         "//third_party:androidx_constraint_layout",
         "//third_party:opencv",
+        "@maven//:androidx_activity_activity",
         "@maven//:androidx_concurrent_concurrent_futures",
+        "@maven//:androidx_exifinterface_exifinterface",
+        "@maven//:androidx_fragment_fragment",
         "@maven//:com_google_guava_guava",
     ],
 )

View File

@@ -1,4 +1,4 @@
-// Copyright 2019 The MediaPipe Authors.
+// Copyright 2021 The MediaPipe Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,138 +14,355 @@
 package com.google.mediapipe.apps.posetrackingsolutiongpu;
 
-import android.content.pm.ApplicationInfo;
-import android.content.pm.PackageManager;
-import android.content.pm.PackageManager.NameNotFoundException;
-import android.graphics.SurfaceTexture;
+import android.content.Intent;
+import android.graphics.Bitmap;
+import android.graphics.Matrix;
+import android.media.Image;
 import android.os.Bundle;
+import android.provider.MediaStore;
 import androidx.appcompat.app.AppCompatActivity;
 import android.util.Log;
-import android.util.Size;
-import android.view.SurfaceHolder;
-import android.view.SurfaceView;
 import android.view.View;
-import android.view.ViewGroup;
+import android.widget.Button;
 import android.widget.FrameLayout;
-import com.google.mediapipe.components.CameraHelper;
-import com.google.mediapipe.components.CameraXPreviewHelper;
-import com.google.mediapipe.components.ExternalTextureConverter;
-import com.google.mediapipe.components.FrameProcessor;
-import com.google.mediapipe.components.PermissionHelper;
-import com.google.mediapipe.formats.proto.LocationDataProto;
-import com.google.mediapipe.framework.AndroidAssetUtil;
-import com.google.mediapipe.glutil.EglManager;
+import android.widget.ImageView;
+import androidx.activity.result.ActivityResultLauncher;
+import androidx.activity.result.contract.ActivityResultContracts;
+import androidx.exifinterface.media.ExifInterface;
+// ContentResolver dependency
 import com.google.mediapipe.solutioncore.CameraInput;
 import com.google.mediapipe.solutioncore.SolutionGlSurfaceView;
-import com.google.mediapipe.solutions.facedetection.FaceDetection;
-import com.google.mediapipe.solutions.facedetection.FaceDetectionOptions;
+import com.google.mediapipe.solutioncore.VideoInput;
 import com.google.mediapipe.solutions.posetracking.PoseTracking;
 import com.google.mediapipe.solutions.posetracking.PoseTrackingOptions;
 import com.google.mediapipe.solutions.posetracking.PoseTrackingResult;
-import java.util.ArrayList;
-
-/** Main activity of MediaPipe basic app. */
+//import com.google.mediapipe.solutions.posetracking.FaceKeypoint;
+import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint;
+import java.io.IOException;
+import java.io.InputStream;
+
+/** Main activity of MediaPipe Face Detection app. */
 public class MainActivity extends AppCompatActivity {
   private static final String TAG = "MainActivity";
-  // Flips the camera-preview frames vertically by default, before sending them into FrameProcessor
-  // to be processed in a MediaPipe graph, and flips the processed frames back when they are
-  // displayed. This maybe needed because OpenGL represents images assuming the image origin is at
-  // the bottom-left corner, whereas MediaPipe in general assumes the image origin is at the
-  // top-left corner.
-  // NOTE: use "flipFramesVertically" in manifest metadata to override this behavior.
-  private static final boolean FLIP_FRAMES_VERTICALLY = true;
-  // Number of output frames allocated in ExternalTextureConverter.
-  // NOTE: use "converterNumBuffers" in manifest metadata to override number of buffers. For
-  // example, when there is a FlowLimiterCalculator in the graph, number of buffers should be at
-  // least `max_in_flight + max_in_queue + 1` (where max_in_flight and max_in_queue are used in
-  // FlowLimiterCalculator options). That's because we need buffers for all the frames that are in
-  // flight/queue plus one for the next frame from the camera.
-  private static final int NUM_BUFFERS = 2;
-
-  static {
-    // Load all native libraries needed by the app.
-    System.loadLibrary("mediapipe_jni");
-    try {
-      System.loadLibrary("opencv_java3");
-    } catch (java.lang.UnsatisfiedLinkError e) {
-      // Some example apps (e.g. template matching) require OpenCV 4.
-      System.loadLibrary("opencv_java4");
-    }
+  private PoseTracking faceDetection;
+
+  private enum InputSource {
+    UNKNOWN,
+    IMAGE,
+    VIDEO,
+    CAMERA,
   }
+  private InputSource inputSource = InputSource.UNKNOWN;
+  // Image demo UI and image loader components.
+  private ActivityResultLauncher<Intent> imageGetter;
+  private ImageView imageView;
+  // Video demo UI and video loader components.
+  private VideoInput videoInput;
+  private ActivityResultLauncher<Intent> videoGetter;
+  // Live camera demo UI and camera components.
+  private CameraInput cameraInput;
+  private SolutionGlSurfaceView<PoseTrackingResult> glSurfaceView;
   @Override
   protected void onCreate(Bundle savedInstanceState) {
     super.onCreate(savedInstanceState);
-    setContentView(getContentViewLayoutResId());
+    setContentView(R.layout.activity_main);
+    setupStaticImageDemoUiComponents();
+    setupVideoDemoUiComponents();
+    setupLiveDemoUiComponents();
+  }
 
-    PoseTrackingOptions poseTrackingOptions = PoseTrackingOptions.builder()
-        .setStaticImageMode(false).build();
-    PoseTracking poseTracking = new PoseTracking(this,poseTrackingOptions);
+  @Override
+  protected void onResume() {
+    super.onResume();
+    if (inputSource == InputSource.CAMERA) {
+      // Restarts the camera and the opengl surface rendering.
+      cameraInput = new CameraInput(this);
+      cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame));
+      glSurfaceView.post(this::startCamera);
+      glSurfaceView.setVisibility(View.VISIBLE);
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput.resume();
+    }
+  }
 
-    poseTracking.setErrorListener(
-        (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
-    CameraInput cameraInput = new CameraInput(this);
+  @Override
+  protected void onPause() {
+    super.onPause();
+    if (inputSource == InputSource.CAMERA) {
+      glSurfaceView.setVisibility(View.GONE);
+      cameraInput.close();
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput.pause();
+    }
+  }
 
+  private Bitmap downscaleBitmap(Bitmap originalBitmap) {
+    double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
+    int width = imageView.getWidth();
+    int height = imageView.getHeight();
+    if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
+      width = (int) (height * aspectRatio);
+    } else {
+      height = (int) (width / aspectRatio);
+    }
+    return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
+  }
-    cameraInput.setNewFrameListener(
-        textureFrame -> poseTracking.send(textureFrame));
-    SolutionGlSurfaceView<PoseTrackingResult> glSurfaceView =
-        new SolutionGlSurfaceView<>(
-            this, poseTracking.getGlContext(), poseTracking.getGlMajorVersion());
-    glSurfaceView.setSolutionResultRenderer(new PoseTrackingResultGlRenderer());
-    glSurfaceView.setRenderInputImage(true);
+  private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
+    int orientation =
+        new ExifInterface(imageData)
+            .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
+    if (orientation == ExifInterface.ORIENTATION_NORMAL) {
+      return inputBitmap;
+    }
+    Matrix matrix = new Matrix();
+    switch (orientation) {
+      case ExifInterface.ORIENTATION_ROTATE_90:
+        matrix.postRotate(90);
+        break;
+      case ExifInterface.ORIENTATION_ROTATE_180:
+        matrix.postRotate(180);
+        break;
+      case ExifInterface.ORIENTATION_ROTATE_270:
+        matrix.postRotate(270);
+        break;
+      default:
+        matrix.postRotate(0);
+    }
+    return Bitmap.createBitmap(
+        inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
+  }
-    poseTracking.setResultListener(
+  /** Sets up the UI components for the static image demo. */
+  private void setupStaticImageDemoUiComponents() {
+    // The Intent to access gallery and read images as bitmap.
+    imageGetter =
+        registerForActivityResult(
+            new ActivityResultContracts.StartActivityForResult(),
+            result -> {
+              Intent resultIntent = result.getData();
+              if (resultIntent != null) {
+                if (result.getResultCode() == RESULT_OK) {
+                  Bitmap bitmap = null;
+                  try {
+                    bitmap =
+                        downscaleBitmap(
+                            MediaStore.Images.Media.getBitmap(
+                                this.getContentResolver(), resultIntent.getData()));
+                  } catch (IOException e) {
+                    Log.e(TAG, "Bitmap reading error:" + e);
+                  }
+                  try {
+                    InputStream imageData =
+                        this.getContentResolver().openInputStream(resultIntent.getData());
+                    bitmap = rotateBitmap(bitmap, imageData);
+                  } catch (IOException e) {
+                    Log.e(TAG, "Bitmap rotation error:" + e);
+                  }
+                  if (bitmap != null) {
+                    faceDetection.send(bitmap);
+                  }
+                }
+              }
+            });
+    Button loadImageButton = findViewById(R.id.button_load_picture);
+    loadImageButton.setOnClickListener(
+        v -> {
+          if (inputSource != InputSource.IMAGE) {
+            stopCurrentPipeline();
+            setupStaticImageModePipeline();
+          }
+          // Reads images from gallery.
+          Intent pickImageIntent = new Intent(Intent.ACTION_PICK);
+          pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*");
+          imageGetter.launch(pickImageIntent);
+        });
+    // imageView = new PoseTrackingResultImageView(this);
+    imageView = new ImageView(this);
+  }
+
+  /** Sets up core workflow for static image mode. */
+  private void setupStaticImageModePipeline() {
+    this.inputSource = InputSource.IMAGE;
+    // Initializes a new MediaPipe Face Detection solution instance in the static image mode.
+    faceDetection =
+        new PoseTracking(
+            this,
+            PoseTrackingOptions.builder()
+                .setStaticImageMode(true)
+                .setModelSelection(0)
+                .setMinDetectionConfidence(0.5f)
+                .build());
+    // Connects MediaPipe Face Detection solution to the user-defined PoseTrackingResultImageView.
+    faceDetection.setResultListener(
         faceDetectionResult -> {
-          if (faceDetectionResult.multiPoseTrackings().isEmpty()) {
+          logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ true);
+          // imageView.setPoseTrackingResult(faceDetectionResult);
+          // runOnUiThread(() -> imageView.update());
+        });
+    faceDetection.setErrorListener(
+        (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
+    // Updates the preview layout.
+    FrameLayout frameLayout = findViewById(R.id.preview_display_layout);
+    frameLayout.removeAllViewsInLayout();
+    imageView.setImageDrawable(null);
+    frameLayout.addView(imageView);
+    imageView.setVisibility(View.VISIBLE);
+  }
+
+  /** Sets up the UI components for the video demo. */
+  private void setupVideoDemoUiComponents() {
+    // The Intent to access gallery and read a video file.
+    videoGetter =
+        registerForActivityResult(
+            new ActivityResultContracts.StartActivityForResult(),
+            result -> {
+              Intent resultIntent = result.getData();
+              if (resultIntent != null) {
+                if (result.getResultCode() == RESULT_OK) {
+                  glSurfaceView.post(
+                      () ->
+                          videoInput.start(
+                              this,
+                              resultIntent.getData(),
+                              faceDetection.getGlContext(),
+                              glSurfaceView.getWidth(),
+                              glSurfaceView.getHeight()));
+                }
+              }
+            });
+    Button loadVideoButton = findViewById(R.id.button_load_video);
+    loadVideoButton.setOnClickListener(
+        v -> {
+          stopCurrentPipeline();
+          setupStreamingModePipeline(InputSource.VIDEO);
+          // Reads video from gallery.
+          Intent pickVideoIntent = new Intent(Intent.ACTION_PICK);
+          pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*");
+          videoGetter.launch(pickVideoIntent);
+        });
+  }
+  /** Sets up the UI components for the live demo with camera input. */
+  private void setupLiveDemoUiComponents() {
+    Button startCameraButton = findViewById(R.id.button_start_camera);
+    startCameraButton.setOnClickListener(
+        v -> {
+          if (inputSource == InputSource.CAMERA) {
             return;
           }
-          LocationDataProto.LocationData locationData = faceDetectionResult
-              .multiPoseTrackings()
-              .get(0)
-              .getLocationData();
-          // .getRelativeKeypoints(FaceKeypoint.NOSE_TIP);
-          Log.i(
-              TAG, locationData.toString());
-          // String.format(
-          //     "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]): x=%f, y=%f",
-          //     noseTip.getX(), noseTip.getY()));
-          // Request GL rendering.
+          stopCurrentPipeline();
+          setupStreamingModePipeline(InputSource.CAMERA);
+        });
+  }
+
+  /** Sets up core workflow for streaming mode. */
+  private void setupStreamingModePipeline(InputSource inputSource) {
+    this.inputSource = inputSource;
+    // Initializes a new MediaPipe Face Detection solution instance in the streaming mode.
+    faceDetection =
+        new PoseTracking(
+            this,
+            PoseTrackingOptions.builder().setStaticImageMode(false).setModelSelection(0).build());
+    faceDetection.setErrorListener(
+        (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
+    if (inputSource == InputSource.CAMERA) {
+      cameraInput = new CameraInput(this);
+      cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame));
+    } else if (inputSource == InputSource.VIDEO) {
+      videoInput = new VideoInput(this);
+      videoInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame));
+    }
+    // Initializes a new Gl surface view with a user-defined PoseTrackingResultGlRenderer.
+    glSurfaceView =
+        new SolutionGlSurfaceView<>(
+            this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion());
+    glSurfaceView.setSolutionResultRenderer(new PoseTrackingResultGlRenderer());
+    glSurfaceView.setRenderInputImage(true);
+    faceDetection.setResultListener(
+        faceDetectionResult -> {
+          logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ false);
           glSurfaceView.setRenderData(faceDetectionResult);
           glSurfaceView.requestRender();
         });
-    // The runnable to start camera after the GLSurfaceView is attached.
-    glSurfaceView.post(
-        () ->
-            cameraInput.start(
-                this,
-                poseTracking.getGlContext(),
-                CameraInput.CameraFacing.FRONT,
-                glSurfaceView.getWidth(),
-                glSurfaceView.getHeight()));
-    glSurfaceView.setVisibility(View.VISIBLE);
+
+    // The runnable to start camera after the gl surface view is attached.
+    // For video input source, videoInput.start() will be called when the video uri is available.
+    if (inputSource == InputSource.CAMERA) {
+      glSurfaceView.post(this::startCamera);
+    }
+
+    // Updates the preview layout.
     FrameLayout frameLayout = findViewById(R.id.preview_display_layout);
+    imageView.setVisibility(View.GONE);
     frameLayout.removeAllViewsInLayout();
     frameLayout.addView(glSurfaceView);
     glSurfaceView.setVisibility(View.VISIBLE);
     frameLayout.requestLayout();
   }
+  private void startCamera() {
-  // Used to obtain the content view for this application. If you are extending this class, and
-  // have a custom layout, override this method and return the custom layout.
-  protected int getContentViewLayoutResId() {
-    return R.layout.activity_main;
+    cameraInput.start(
+        this,
+        faceDetection.getGlContext(),
+        CameraInput.CameraFacing.FRONT,
+        glSurfaceView.getWidth(),
+        glSurfaceView.getHeight());
   }
+
+  private void stopCurrentPipeline() {
+    if (cameraInput != null) {
+      cameraInput.setNewFrameListener(null);
+      cameraInput.close();
+    }
+    if (videoInput != null) {
+      videoInput.setNewFrameListener(null);
+      videoInput.close();
+    }
+    if (glSurfaceView != null) {
+      glSurfaceView.setVisibility(View.GONE);
+    }
+    if (faceDetection != null) {
+      faceDetection.close();
+    }
+  }
+
+  private void logNoseTipKeypoint(
+      PoseTrackingResult result, int faceIndex, boolean showPixelValues) {
+    if (result.multiPoseTrackings().isEmpty()) {
+      return;
+    }
+    // RelativeKeypoint noseTip =
+    //     result
+    //         .multiPoseTrackings()
+    //         .get(faceIndex)
+    //         .getLocationData()
+    //         .getRelativeKeypoints(FaceKeypoint.NOSE_TIP);
+    // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates.
+    if (showPixelValues) {
+      int width = result.inputBitmap().getWidth();
+      int height = result.inputBitmap().getHeight();
+      // Log.i(
+      //     TAG,
+      //     String.format(
+      //         "MediaPipe Face Detection nose tip coordinates (pixel values): x=%f, y=%f",
+      //         noseTip.getX() * width, noseTip.getY() * height));
+    } else {
+      // Log.i(
+      //     TAG,
+      //     String.format(
+      //         "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]):"
+      //             + " x=%f, y=%f",
+      //         noseTip.getX(), noseTip.getY()));
+    }
+  }
 }

View File

@@ -1,20 +1,40 @@
 <?xml version="1.0" encoding="utf-8"?>
-<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
-    xmlns:app="http://schemas.android.com/apk/res-auto"
-    xmlns:tools="http://schemas.android.com/tools"
+<LinearLayout
+    xmlns:android="http://schemas.android.com/apk/res/android"
     android:layout_width="match_parent"
-    android:layout_height="match_parent">
+    android:layout_height="match_parent"
+    android:orientation="vertical">
+  <LinearLayout
+      android:id="@+id/buttons"
+      android:layout_width="match_parent"
+      android:layout_height="wrap_content"
+      style="?android:attr/buttonBarStyle" android:gravity="center"
+      android:orientation="horizontal">
+    <Button
+        android:id="@+id/button_load_picture"
+        android:layout_width="wrap_content"
+        style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
+        android:text="Load Image" />
+    <Button
+        android:id="@+id/button_load_video"
+        android:layout_width="wrap_content"
+        style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
+        android:text="Load Video" />
+    <Button
+        android:id="@+id/button_start_camera"
+        android:layout_width="wrap_content"
+        style="?android:attr/buttonBarButtonStyle" android:layout_height="wrap_content"
+        android:text="Start Camera" />
+  </LinearLayout>
   <FrameLayout
       android:id="@+id/preview_display_layout"
-      android:layout_width="fill_parent"
-      android:layout_height="fill_parent">
+      android:layout_width="match_parent"
+      android:layout_height="match_parent"
+      android:layout_weight="1">
     <TextView
-        android:id="@+id/no_camera_access_view"
-        android:layout_height="fill_parent"
-        android:layout_width="fill_parent"
+        android:id="@+id/no_view"
+        android:layout_width="match_parent"
+        android:layout_height="wrap_content"
         android:gravity="center"
-        android:text="@string/no_camera_access" />
+        android:text="No camera" />
   </FrameLayout>
-</androidx.constraintlayout.widget.ConstraintLayout>
+</LinearLayout>

View File

@@ -1,12 +1,16 @@
 # MediaPipe graph that performs pose tracking with TensorFlow Lite on GPU.
 
 # GPU buffer. (GpuBuffer)
-input_stream: "input_video"
+input_stream: "IMAGE:input_video"
 
 # Output image with rendered results. (GpuBuffer)
 output_stream: "output_video"
 # Pose landmarks. (NormalizedLandmarkList)
-output_stream: "pose_landmarks"
+#output_stream: "pose_landmarks"
+output_stream: "IMAGE:throttled_input_video"
+output_stream: "DETECTION:pose_detection"
+output_stream: "output_video"
 
 # Generates side packet to enable segmentation.
 node {
@@ -14,7 +18,7 @@ node {
   output_side_packet: "PACKET:enable_segmentation"
   node_options: {
     [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: {
-      packet { bool_value: true }
+      packet { bool_value: false }
     }
   }
 }

View File

@@ -0,0 +1,63 @@
+# MediaPipe graph that performs pose tracking with TensorFlow Lite on GPU.
+
+# GPU buffer. (GpuBuffer)
+input_stream: "input_video"
+
+# Output image with rendered results. (GpuBuffer)
+output_stream: "output_video"
+# Pose landmarks. (NormalizedLandmarkList)
+output_stream: "pose_landmarks"
+
+# Generates side packet to enable segmentation.
+node {
+  calculator: "ConstantSidePacketCalculator"
+  output_side_packet: "PACKET:enable_segmentation"
+  node_options: {
+    [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: {
+      packet { bool_value: true }
+    }
+  }
+}
+
+# Throttles the images flowing downstream for flow control. It passes through
+# the very first incoming image unaltered, and waits for downstream nodes
+# (calculators and subgraphs) in the graph to finish their tasks before it
+# passes through another image. All images that come in while waiting are
+# dropped, limiting the number of in-flight images in most part of the graph to
+# 1. This prevents the downstream nodes from queuing up incoming images and data
+# excessively, which leads to increased latency and memory usage, unwanted in
+# real-time mobile applications. It also eliminates unnecessarily computation,
+# e.g., the output produced by a node may get dropped downstream if the
+# subsequent nodes are still busy processing previous inputs.
+node {
+  calculator: "FlowLimiterCalculator"
+  input_stream: "input_video"
+  input_stream: "FINISHED:output_video"
+  input_stream_info: {
+    tag_index: "FINISHED"
+    back_edge: true
+  }
+  output_stream: "throttled_input_video"
+}
+
+# Subgraph that detects poses and corresponding landmarks.
+node {
+  calculator: "PoseLandmarkGpu"
+  input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation"
+  input_stream: "IMAGE:throttled_input_video"
+  output_stream: "LANDMARKS:pose_landmarks"
+  output_stream: "SEGMENTATION_MASK:segmentation_mask"
+  output_stream: "DETECTION:pose_detection"
+  output_stream: "ROI_FROM_LANDMARKS:roi_from_landmarks"
+}
+
+# Subgraph that renders pose-landmark annotation onto the input image.
+node {
+  calculator: "PoseRendererGpu"
+  input_stream: "IMAGE:throttled_input_video"
+  input_stream: "LANDMARKS:pose_landmarks"
+  input_stream: "SEGMENTATION_MASK:segmentation_mask"
+  input_stream: "DETECTION:pose_detection"
+  input_stream: "ROI:roi_from_landmarks"
+  output_stream: "IMAGE:output_video"
+}

View File

@@ -43,6 +43,7 @@ android_library(
         "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input",
         "//mediapipe/java/com/google/mediapipe/solutioncore:solution_base",
         "//third_party:autovalue",
+        "@com_google_protobuf//:protobuf_javalite",
         "@maven//:androidx_annotation_annotation",
        "@maven//:com_google_code_findbugs_jsr305",
        "@maven//:com_google_guava_guava",

View File

@@ -15,15 +15,21 @@
 package com.google.mediapipe.solutions.posetracking;
 
 import android.content.Context;
+import android.util.Log;
 import com.google.common.collect.ImmutableList;
+import com.google.mediapipe.formats.proto.DetectionProto;
 import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
 import com.google.mediapipe.solutioncore.ErrorListener;
 import com.google.mediapipe.solutioncore.ImageSolutionBase;
 import com.google.mediapipe.solutioncore.OutputHandler;
 import com.google.mediapipe.solutioncore.ResultListener;
 import com.google.mediapipe.solutioncore.SolutionInfo;
 import com.google.mediapipe.formats.proto.DetectionProto.Detection;
+import com.google.protobuf.InvalidProtocolBufferException;
 import java.util.HashMap;
 import java.util.Map;
 import javax.annotation.Nullable;
@@ -42,9 +48,10 @@ public class PoseTracking extends ImageSolutionBase {
   private static final String FULL_RANGE_GRAPH_NAME = "face_detection_full_range_image.binarypb";
   private static final String IMAGE_INPUT_STREAM = "input_video";
   private static final ImmutableList<String> OUTPUT_STREAMS =
-      ImmutableList.of("pose_detection", "throttled_input_video");
+      ImmutableList.of("pose_detection", "throttled_input_video","output_video");
   private static final int DETECTIONS_INDEX = 0;
   private static final int INPUT_IMAGE_INDEX = 1;
+  private static final int OUTPUT_IMAGE_INDEX = 2;
 
   private final OutputHandler<PoseTrackingResult> outputHandler;
 
   /**
@@ -59,15 +66,34 @@ public class PoseTracking extends ImageSolutionBase {
         packets -> {
           PoseTrackingResult.Builder poseTrackingResultBuilder = PoseTrackingResult.builder();
           try {
+            Packet packet = packets.get(DETECTIONS_INDEX);
+            if (!packet.isEmpty()){
+              try {
+                byte[] bytes = PacketGetter.getProtoBytes(packet);
+                Detection det = Detection.parseFrom(bytes);
+                poseTrackingResultBuilder.setMultiPoseTrackings(
+                    ImmutableList.<Detection>of(det));
+                // Detection det = PacketGetter.getProto(packet, Detection.getDefaultInstance());
+                Log.v(TAG,"Packet not empty");
+              }catch (InvalidProtocolBufferException e){
+                Log.e(TAG,e.getMessage());
+                poseTrackingResultBuilder.setMultiPoseTrackings(
+                    ImmutableList.<Detection>of());
+              }
+            }else {
               poseTrackingResultBuilder.setMultiPoseTrackings(
                   getProtoVector(packets.get(DETECTIONS_INDEX), Detection.parser()));
+            }
           } catch (MediaPipeException e) {
             reportError("Error occurs while getting MediaPipe pose tracking results.", e);
           }
           return poseTrackingResultBuilder
-              .setImagePacket(packets.get(INPUT_IMAGE_INDEX))
+              .setImagePacket(packets.get(OUTPUT_IMAGE_INDEX))
               .setTimestamp(
-                  staticImageMode ? Long.MIN_VALUE : packets.get(INPUT_IMAGE_INDEX).getTimestamp())
+                  staticImageMode ? Long.MIN_VALUE : packets.get(OUTPUT_IMAGE_INDEX).getTimestamp())
               .build();
         });