cleaned dependencies and added landmark visibility modifier

This commit is contained in:
Mautisim Munir 2022-09-26 17:29:13 +05:00
parent bd57eb40b9
commit 517bb70212
6 changed files with 44 additions and 145 deletions

View File

@ -31,7 +31,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input",
"//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering",
"//mediapipe/java/com/google/mediapipe/solutioncore:video_input", "//mediapipe/java/com/google/mediapipe/solutioncore:video_input",
"//mediapipe/java/com/google/mediapipe/solutions/facedetection", # "//mediapipe/java/com/google/mediapipe/solutions/facedetection",
"//mediapipe/java/com/google/mediapipe/solutions/posetracking", "//mediapipe/java/com/google/mediapipe/solutions/posetracking",
"//third_party:androidx_appcompat", "//third_party:androidx_appcompat",
"//third_party:androidx_constraint_layout", "//third_party:androidx_constraint_layout",
@ -55,11 +55,11 @@ cc_binary(
linkshared = 1, linkshared = 1,
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
"//mediapipe/graphs/edge_detection:mobile_calculators", # "//mediapipe/graphs/edge_detection:mobile_calculators",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
#facedetection deps #facedetection deps
"//mediapipe/graphs/face_detection:face_detection_full_range_mobile_gpu_deps", # "//mediapipe/graphs/face_detection:face_detection_full_range_mobile_gpu_deps",
"//mediapipe/graphs/face_detection:mobile_calculators", # "//mediapipe/graphs/face_detection:mobile_calculators",
#pose tracking deps #pose tracking deps
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps", "//mediapipe/graphs/pose_tracking:pose_tracking_gpu_deps",
], ],

View File

@ -1,110 +0,0 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.apps.posetrackingsolutiongpu;
import static java.lang.Math.min;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import androidx.appcompat.widget.AppCompatImageView;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.solutions.facedetection.FaceDetectionResult;
import com.google.mediapipe.solutions.facedetection.FaceKeypoint;
/** An ImageView implementation for displaying {@link FaceDetectionResult}. */
/** An ImageView implementation for displaying {@link FaceDetectionResult}. */
public class FaceDetectionResultImageView extends AppCompatImageView {
  private static final String TAG = "FaceDetectionResultImageView";
  private static final int KEYPOINT_COLOR = Color.RED;
  private static final int KEYPOINT_RADIUS = 8; // Pixels
  private static final int BBOX_COLOR = Color.GREEN;
  private static final int BBOX_THICKNESS = 5; // Pixels
  // Most recently rendered frame; stays null until the first result arrives.
  private Bitmap renderedFrame;

  public FaceDetectionResultImageView(Context context) {
    super(context);
    setScaleType(ScaleType.FIT_CENTER);
  }

  /**
   * Sets a {@link FaceDetectionResult} to render.
   *
   * @param result a {@link FaceDetectionResult} object that contains the solution outputs and the
   *     input {@link Bitmap}.
   */
  public void setFaceDetectionResult(FaceDetectionResult result) {
    if (result == null) {
      return;
    }
    Bitmap input = result.inputBitmap();
    int width = input.getWidth();
    int height = input.getHeight();
    // Render onto a fresh bitmap copy so the input frame is never mutated.
    renderedFrame = Bitmap.createBitmap(width, height, input.getConfig());
    Canvas canvas = new Canvas(renderedFrame);
    canvas.drawBitmap(input, new Matrix(), null);
    for (Detection detection : result.multiFaceDetections()) {
      drawDetectionOnCanvas(detection, canvas, width, height);
    }
  }

  /** Updates the image view with the latest {@link FaceDetectionResult}. */
  public void update() {
    postInvalidate();
    if (renderedFrame != null) {
      setImageBitmap(renderedFrame);
    }
  }

  // Draws the detection's keypoints and (if present) bounding box onto the canvas.
  // Relative coordinates are scaled to pixel space and clamped to the frame bounds.
  private void drawDetectionOnCanvas(Detection detection, Canvas canvas, int width, int height) {
    if (!detection.hasLocationData()) {
      return;
    }
    // Keypoints.
    Paint dotPaint = new Paint();
    dotPaint.setColor(KEYPOINT_COLOR);
    for (int i = 0; i < FaceKeypoint.NUM_KEY_POINTS; ++i) {
      float relX = detection.getLocationData().getRelativeKeypoints(i).getX();
      float relY = detection.getLocationData().getRelativeKeypoints(i).getY();
      int xPixel = min((int) (relX * width), width - 1);
      int yPixel = min((int) (relY * height), height - 1);
      canvas.drawCircle(xPixel, yPixel, KEYPOINT_RADIUS, dotPaint);
    }
    if (!detection.getLocationData().hasRelativeBoundingBox()) {
      return;
    }
    // Bounding box.
    Paint boxPaint = new Paint();
    boxPaint.setColor(BBOX_COLOR);
    boxPaint.setStyle(Paint.Style.STROKE);
    boxPaint.setStrokeWidth(BBOX_THICKNESS);
    float left = detection.getLocationData().getRelativeBoundingBox().getXmin() * width;
    float top = detection.getLocationData().getRelativeBoundingBox().getYmin() * height;
    float right = left + detection.getLocationData().getRelativeBoundingBox().getWidth() * width;
    float bottom = top + detection.getLocationData().getRelativeBoundingBox().getHeight() * height;
    canvas.drawRect(left, top, right, bottom, boxPaint);
  }
}

View File

@ -17,7 +17,6 @@ package com.google.mediapipe.apps.posetrackingsolutiongpu;
import android.content.Intent; import android.content.Intent;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import android.graphics.Matrix; import android.graphics.Matrix;
import android.media.Image;
import android.os.Bundle; import android.os.Bundle;
import android.provider.MediaStore; import android.provider.MediaStore;
import androidx.appcompat.app.AppCompatActivity; import androidx.appcompat.app.AppCompatActivity;
@ -38,7 +37,6 @@ import com.google.mediapipe.solutions.posetracking.PoseTracking;
import com.google.mediapipe.solutions.posetracking.PoseTrackingOptions; import com.google.mediapipe.solutions.posetracking.PoseTrackingOptions;
import com.google.mediapipe.solutions.posetracking.PoseTrackingResult; import com.google.mediapipe.solutions.posetracking.PoseTrackingResult;
//import com.google.mediapipe.solutions.posetracking.FaceKeypoint; //import com.google.mediapipe.solutions.posetracking.FaceKeypoint;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
@ -46,7 +44,7 @@ import java.io.InputStream;
public class MainActivity extends AppCompatActivity { public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity"; private static final String TAG = "MainActivity";
private PoseTracking faceDetection; private PoseTracking poseTracking;
private enum InputSource { private enum InputSource {
UNKNOWN, UNKNOWN,
@ -82,7 +80,7 @@ public class MainActivity extends AppCompatActivity {
if (inputSource == InputSource.CAMERA) { if (inputSource == InputSource.CAMERA) {
// Restarts the camera and the opengl surface rendering. // Restarts the camera and the opengl surface rendering.
cameraInput = new CameraInput(this); cameraInput = new CameraInput(this);
cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); cameraInput.setNewFrameListener(textureFrame -> poseTracking.send(textureFrame));
glSurfaceView.post(this::startCamera); glSurfaceView.post(this::startCamera);
glSurfaceView.setVisibility(View.VISIBLE); glSurfaceView.setVisibility(View.VISIBLE);
} else if (inputSource == InputSource.VIDEO) { } else if (inputSource == InputSource.VIDEO) {
@ -165,7 +163,7 @@ public class MainActivity extends AppCompatActivity {
Log.e(TAG, "Bitmap rotation error:" + e); Log.e(TAG, "Bitmap rotation error:" + e);
} }
if (bitmap != null) { if (bitmap != null) {
faceDetection.send(bitmap); poseTracking.send(bitmap);
} }
} }
} }
@ -190,23 +188,24 @@ public class MainActivity extends AppCompatActivity {
private void setupStaticImageModePipeline() { private void setupStaticImageModePipeline() {
this.inputSource = InputSource.IMAGE; this.inputSource = InputSource.IMAGE;
// Initializes a new MediaPipe Face Detection solution instance in the static image mode. // Initializes a new MediaPipe Face Detection solution instance in the static image mode.
faceDetection = poseTracking =
new PoseTracking( new PoseTracking(
this, this,
PoseTrackingOptions.builder() PoseTrackingOptions.builder()
.setStaticImageMode(true) .setStaticImageMode(true)
.setModelSelection(0) .setModelSelection(0)
.setMinDetectionConfidence(0.5f) .setMinDetectionConfidence(0.5f)
.setLandmarkVisibility(true)
.build()); .build());
// Connects MediaPipe Face Detection solution to the user-defined PoseTrackingResultImageView. // Connects MediaPipe Face Detection solution to the user-defined PoseTrackingResultImageView.
faceDetection.setResultListener( poseTracking.setResultListener(
faceDetectionResult -> { poseTrackingResult -> {
logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ true); logNoseTipKeypoint(poseTrackingResult, /*faceIndex=*/ 0, /*showPixelValues=*/ true);
// imageView.setPoseTrackingResult(faceDetectionResult); // imageView.setPoseTrackingResult(poseTrackingResult);
// runOnUiThread(() -> imageView.update()); // runOnUiThread(() -> imageView.update());
}); });
faceDetection.setErrorListener( poseTracking.setErrorListener(
(message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
// Updates the preview layout. // Updates the preview layout.
@ -232,7 +231,7 @@ public class MainActivity extends AppCompatActivity {
videoInput.start( videoInput.start(
this, this,
resultIntent.getData(), resultIntent.getData(),
faceDetection.getGlContext(), poseTracking.getGlContext(),
glSurfaceView.getWidth(), glSurfaceView.getWidth(),
glSurfaceView.getHeight())); glSurfaceView.getHeight()));
} }
@ -267,31 +266,35 @@ public class MainActivity extends AppCompatActivity {
private void setupStreamingModePipeline(InputSource inputSource) { private void setupStreamingModePipeline(InputSource inputSource) {
this.inputSource = inputSource; this.inputSource = inputSource;
// Initializes a new MediaPipe Face Detection solution instance in the streaming mode. // Initializes a new MediaPipe Face Detection solution instance in the streaming mode.
faceDetection = poseTracking =
new PoseTracking( new PoseTracking(
this, this,
PoseTrackingOptions.builder().setStaticImageMode(false).setModelSelection(0).build()); PoseTrackingOptions.builder()
faceDetection.setErrorListener( .setStaticImageMode(false)
.setLandmarkVisibility(false)
.setModelSelection(0)
.build());
poseTracking.setErrorListener(
(message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
if (inputSource == InputSource.CAMERA) { if (inputSource == InputSource.CAMERA) {
cameraInput = new CameraInput(this); cameraInput = new CameraInput(this);
cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); cameraInput.setNewFrameListener(textureFrame -> poseTracking.send(textureFrame));
} else if (inputSource == InputSource.VIDEO) { } else if (inputSource == InputSource.VIDEO) {
videoInput = new VideoInput(this); videoInput = new VideoInput(this);
videoInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); videoInput.setNewFrameListener(textureFrame -> poseTracking.send(textureFrame));
} }
// Initializes a new Gl surface view with a user-defined PoseTrackingResultGlRenderer. // Initializes a new Gl surface view with a user-defined PoseTrackingResultGlRenderer.
glSurfaceView = glSurfaceView =
new SolutionGlSurfaceView<>( new SolutionGlSurfaceView<>(
this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion()); this, poseTracking.getGlContext(), poseTracking.getGlMajorVersion());
glSurfaceView.setSolutionResultRenderer(new PoseTrackingResultGlRenderer()); glSurfaceView.setSolutionResultRenderer(new PoseTrackingResultGlRenderer());
glSurfaceView.setRenderInputImage(true); glSurfaceView.setRenderInputImage(true);
faceDetection.setResultListener( poseTracking.setResultListener(
faceDetectionResult -> { poseTrackingResult -> {
logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ false); logNoseTipKeypoint(poseTrackingResult, /*faceIndex=*/ 0, /*showPixelValues=*/ false);
glSurfaceView.setRenderData(faceDetectionResult); glSurfaceView.setRenderData(poseTrackingResult);
glSurfaceView.requestRender(); glSurfaceView.requestRender();
}); });
@ -313,7 +316,7 @@ public class MainActivity extends AppCompatActivity {
private void startCamera() { private void startCamera() {
cameraInput.start( cameraInput.start(
this, this,
faceDetection.getGlContext(), poseTracking.getGlContext(),
CameraInput.CameraFacing.FRONT, CameraInput.CameraFacing.FRONT,
glSurfaceView.getWidth(), glSurfaceView.getWidth(),
glSurfaceView.getHeight()); glSurfaceView.getHeight());
@ -331,8 +334,8 @@ public class MainActivity extends AppCompatActivity {
if (glSurfaceView != null) { if (glSurfaceView != null) {
glSurfaceView.setVisibility(View.GONE); glSurfaceView.setVisibility(View.GONE);
} }
if (faceDetection != null) { if (poseTracking != null) {
faceDetection.close(); poseTracking.close();
} }
} }

View File

@ -23,10 +23,10 @@ android_library(
"PoseTrackingResult.java", "PoseTrackingResult.java",
], ],
assets = [ assets = [
"//mediapipe/modules/face_detection:face_detection_full_range_image.binarypb", # "//mediapipe/modules/face_detection:face_detection_full_range_image.binarypb",
"//mediapipe/modules/face_detection:face_detection_full_range_sparse.tflite", # "//mediapipe/modules/face_detection:face_detection_full_range_sparse.tflite",
"//mediapipe/modules/face_detection:face_detection_short_range.tflite", # "//mediapipe/modules/face_detection:face_detection_short_range.tflite",
"//mediapipe/modules/face_detection:face_detection_short_range_image.binarypb", # "//mediapipe/modules/face_detection:face_detection_short_range_image.binarypb",
"//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb", "//mediapipe/graphs/pose_tracking:pose_tracking_gpu.binarypb",
"//mediapipe/modules/pose_landmark:pose_landmark_heavy.tflite", "//mediapipe/modules/pose_landmark:pose_landmark_heavy.tflite",
"//mediapipe/modules/pose_landmark:pose_landmark_full.tflite", "//mediapipe/modules/pose_landmark:pose_landmark_full.tflite",

View File

@ -90,10 +90,12 @@ public class PoseTracking extends ImageSolutionBase {
} catch (MediaPipeException e) { } catch (MediaPipeException e) {
reportError("Error occurs while getting MediaPipe pose tracking results.", e); reportError("Error occurs while getting MediaPipe pose tracking results.", e);
} }
int imageIndex = options.landmarkVisibility() ? OUTPUT_IMAGE_INDEX : INPUT_IMAGE_INDEX;
return poseTrackingResultBuilder return poseTrackingResultBuilder
.setImagePacket(packets.get(OUTPUT_IMAGE_INDEX)) .setImagePacket(packets.get(imageIndex))
.setTimestamp( .setTimestamp(
staticImageMode ? Long.MIN_VALUE : packets.get(OUTPUT_IMAGE_INDEX).getTimestamp()) staticImageMode ? Long.MIN_VALUE : packets.get(imageIndex).getTimestamp())
.build(); .build();
}); });

View File

@ -39,6 +39,8 @@ public abstract class PoseTrackingOptions {
public abstract float minDetectionConfidence(); public abstract float minDetectionConfidence();
public abstract boolean landmarkVisibility();
public static Builder builder() { public static Builder builder() {
return new AutoValue_PoseTrackingOptions.Builder().withDefaultValues(); return new AutoValue_PoseTrackingOptions.Builder().withDefaultValues();
} }
@ -56,6 +58,8 @@ public abstract class PoseTrackingOptions {
public abstract Builder setMinDetectionConfidence(float value); public abstract Builder setMinDetectionConfidence(float value);
public abstract Builder setLandmarkVisibility(boolean value);
public abstract PoseTrackingOptions build(); public abstract PoseTrackingOptions build();
} }
} }