diff --git a/WORKSPACE b/WORKSPACE index aacf856c2..9dbbff399 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -16,11 +16,11 @@ bazel_skylib_workspace() load("@bazel_skylib//lib:versions.bzl", "versions") versions.check(minimum_bazel_version = "3.7.2") -# ABSL cpp library lts_2020_09_23 +# ABSL cpp library lts_2021_03_24, patch 2. http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/20200923.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ @@ -29,8 +29,8 @@ http_archive( patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20200923", - sha256 = "b3744a4f7a249d5eaf2309daad597631ce77ea62e0fc6abffbab4b4c3dc0fc08" + strip_prefix = "abseil-cpp-20210324.2", + sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" ) http_archive( @@ -333,6 +333,7 @@ maven_install( "androidx.concurrent:concurrent-futures:1.0.0-alpha03", "androidx.lifecycle:lifecycle-common:2.3.1", "androidx.activity:activity:1.2.2", + "androidx.exifinterface:exifinterface:1.3.3", "androidx.fragment:fragment:1.3.4", "androidx.annotation:annotation:aar:1.1.0", "androidx.appcompat:appcompat:aar:1.1.0-rc01", @@ -349,8 +350,8 @@ maven_install( "com.google.auto.value:auto-value:1.8.1", "com.google.auto.value:auto-value-annotations:1.8.1", "com.google.code.findbugs:jsr305:latest.release", - "com.google.flogger:flogger-system-backend:latest.release", - "com.google.flogger:flogger:latest.release", + "com.google.flogger:flogger-system-backend:0.6", + "com.google.flogger:flogger:0.6", "com.google.guava:guava:27.0.1-android", "com.google.guava:listenablefuture:1.0", "junit:junit:4.12", @@ -389,6 +390,8 @@ http_archive( patches = [ "@//third_party:org_tensorflow_compatibility_fixes.diff", "@//third_party:org_tensorflow_objc_cxx17.diff", + # Diff is generated with a script, don't update it manually. + "@//third_party:org_tensorflow_custom_ops.diff", ], patch_args = [ "-p1", diff --git a/docs/getting_started/android_solutions.md b/docs/getting_started/android_solutions.md index de7135c18..5b3f537c9 100644 --- a/docs/getting_started/android_solutions.md +++ b/docs/getting_started/android_solutions.md @@ -26,15 +26,17 @@ the following into the project's Gradle dependencies: ``` dependencies { - // MediaPipe solution-core is the foundation of any MediaPipe solutions. + // MediaPipe solution-core is the foundation of any MediaPipe Solutions. implementation 'com.google.mediapipe:solution-core:latest.release' - // Optional: MediaPipe Hands solution. - implementation 'com.google.mediapipe:hands:latest.release' - // Optional: MediaPipe FaceMesh solution. + // Optional: MediaPipe Face Detection Solution. + implementation 'com.google.mediapipe:facedetection:latest.release' + // Optional: MediaPipe Face Mesh Solution. implementation 'com.google.mediapipe:facemesh:latest.release' + // Optional: MediaPipe Hands Solution. 
+ implementation 'com.google.mediapipe:hands:latest.release' // MediaPipe deps - implementation 'com.google.flogger:flogger:latest.release' - implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.flogger:flogger:0.6' + implementation 'com.google.flogger:flogger-system-backend:0.6' implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.protobuf:protobuf-java:3.11.4' // CameraX core library @@ -45,7 +47,7 @@ dependencies { } ``` -See the detailed solutions API usage examples for different use cases in the +See the detailed solution APIs usage examples for different use cases in the solution example apps' [source code](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions). If the prebuilt maven packages are not sufficient, building the MediaPipe diff --git a/docs/getting_started/faq.md b/docs/getting_started/faq.md index 75bf8ad97..c42ef898c 100644 --- a/docs/getting_started/faq.md +++ b/docs/getting_started/faq.md @@ -103,7 +103,7 @@ monotonically increasing timestamps. By convention, realtime calculators and graphs use the recording time or the presentation time as the timestamp for each packet, with each timestamp representing microseconds since `Jan/1/1970:00:00:00`. This allows packets from various sources to be processed -in a gloablly consistent order. +in a globally consistent order. Normally for offline processing, every input packet is processed and processing continues as long as necessary. For online processing, it is often necessary to diff --git a/docs/images/attention_mesh_architecture.png b/docs/images/attention_mesh_architecture.png new file mode 100644 index 000000000..3a38de5c9 Binary files /dev/null and b/docs/images/attention_mesh_architecture.png differ diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 9d08ee482..2c8bf3c18 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -121,12 +121,10 @@ with mp_face_detection.FaceDetection( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = face_detection.process(image) # Draw the face detection annotations on the image. @@ -135,7 +133,8 @@ with mp_face_detection.FaceDetection( if results.detections: for detection in results.detections: mp_drawing.draw_detection(image, detection) - cv2.imshow('MediaPipe Face Detection', image) + # Flip the image horizontally for a selfie-view display. 
+    cv2.imshow('MediaPipe Face Detection', cv2.flip(image, 1))
     if cv2.waitKey(5) & 0xFF == 27:
       break
 cap.release()
@@ -200,7 +199,7 @@ const faceDetection = new FaceDetection({locateFile: (file) => {
   return `https://cdn.jsdelivr.net/npm/@mediapipe/face_detection@0.0/${file}`;
 }});
 faceDetection.setOptions({
-  modelSelection: 0
+  modelSelection: 0, minDetectionConfidence: 0.5
 });
 faceDetection.onResults(onResults);
@@ -216,6 +215,194 @@ camera.start();
 ```

+### Android Solution API
+
+Please first follow general
+[instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api)
+to add MediaPipe Gradle dependencies, then try the Face Detection Solution API
+in the companion
+[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facedetection)
+following
+[these instructions](../getting_started/android_solutions.md#build-solution-example-apps-in-android-studio)
+and learn more in the usage example below.
+
+* [staticImageMode](#static_image_mode)
+* [modelSelection](#model_selection)
+
+#### Camera Input
+
+```java
+// For camera input and result rendering with OpenGL.
+FaceDetectionOptions faceDetectionOptions =
+    FaceDetectionOptions.builder()
+        .setStaticImageMode(false)
+        .setModelSelection(0).build();
+FaceDetection faceDetection = new FaceDetection(this, faceDetectionOptions);
+faceDetection.setErrorListener(
+    (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
+
+// Initializes a new CameraInput instance and connects it to MediaPipe Face Detection Solution.
+CameraInput cameraInput = new CameraInput(this);
+cameraInput.setNewFrameListener(
+    textureFrame -> faceDetection.send(textureFrame));
+
+// Initializes a new GlSurfaceView with a ResultGlRenderer instance
+// that provides the interfaces to run user-defined OpenGL rendering code.
+// See mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java
+// as an example.
+SolutionGlSurfaceView<FaceDetectionResult> glSurfaceView =
+    new SolutionGlSurfaceView<>(
+        this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion());
+glSurfaceView.setSolutionResultRenderer(new FaceDetectionResultGlRenderer());
+glSurfaceView.setRenderInputImage(true);
+faceDetection.setResultListener(
+    faceDetectionResult -> {
+      RelativeKeypoint noseTip =
+          FaceDetection.getFaceKeypoint(faceDetectionResult, 0, FaceKeypoint.NOSE_TIP);
+      Log.i(
+          TAG,
+          String.format(
+              "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]): x=%f, y=%f",
+              noseTip.getX(), noseTip.getY()));
+      // Request GL rendering.
+      glSurfaceView.setRenderData(faceDetectionResult);
+      glSurfaceView.requestRender();
+    });
+
+// The runnable to start camera after the GLSurfaceView is attached.
+glSurfaceView.post(
+    () ->
+        cameraInput.start(
+            this,
+            faceDetection.getGlContext(),
+            CameraInput.CameraFacing.FRONT,
+            glSurfaceView.getWidth(),
+            glSurfaceView.getHeight()));
+```
+
+#### Image Input
+
+```java
+// For reading images from gallery and drawing the output in an ImageView.
+FaceDetectionOptions faceDetectionOptions =
+    FaceDetectionOptions.builder()
+        .setStaticImageMode(true)
+        .setModelSelection(0).build();
+FaceDetection faceDetection = new FaceDetection(this, faceDetectionOptions);
+
+// Connects MediaPipe Face Detection Solution to the user-defined ImageView
+// instance that allows users to have the custom drawing of the output landmarks
+// on it. See mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java
+// as an example.
+FaceDetectionResultImageView imageView = new FaceDetectionResultImageView(this);
+faceDetection.setResultListener(
+    faceDetectionResult -> {
+      int width = faceDetectionResult.inputBitmap().getWidth();
+      int height = faceDetectionResult.inputBitmap().getHeight();
+      RelativeKeypoint noseTip =
+          FaceDetection.getFaceKeypoint(faceDetectionResult, 0, FaceKeypoint.NOSE_TIP);
+      Log.i(
+          TAG,
+          String.format(
+              "MediaPipe Face Detection nose tip coordinates (pixel values): x=%f, y=%f",
+              noseTip.getX() * width, noseTip.getY() * height));
+      // Request canvas drawing.
+      imageView.setFaceDetectionResult(faceDetectionResult);
+      runOnUiThread(() -> imageView.update());
+    });
+faceDetection.setErrorListener(
+    (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
+
+// ActivityResultLauncher to get an image from the gallery as Bitmap.
+ActivityResultLauncher<Intent> imageGetter =
+    registerForActivityResult(
+        new ActivityResultContracts.StartActivityForResult(),
+        result -> {
+          Intent resultIntent = result.getData();
+          if (resultIntent != null && result.getResultCode() == RESULT_OK) {
+            Bitmap bitmap = null;
+            try {
+              bitmap =
+                  MediaStore.Images.Media.getBitmap(
+                      this.getContentResolver(), resultIntent.getData());
+              // Please also rotate the Bitmap based on its orientation.
+            } catch (IOException e) {
+              Log.e(TAG, "Bitmap reading error:" + e);
+            }
+            if (bitmap != null) {
+              faceDetection.send(bitmap);
+            }
+          }
+        });
+Intent gallery = new Intent(
+    Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI);
+imageGetter.launch(gallery);
+```
+
+#### Video Input
+
+```java
+// For video input and result rendering with OpenGL.
+FaceDetectionOptions faceDetectionOptions =
+    FaceDetectionOptions.builder()
+        .setStaticImageMode(false)
+        .setModelSelection(0).build();
+FaceDetection faceDetection = new FaceDetection(this, faceDetectionOptions);
+faceDetection.setErrorListener(
+    (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message));
+
+// Initializes a new VideoInput instance and connects it to MediaPipe Face Detection Solution.
+VideoInput videoInput = new VideoInput(this);
+videoInput.setNewFrameListener(
+    textureFrame -> faceDetection.send(textureFrame));
+
+// Initializes a new GlSurfaceView with a ResultGlRenderer instance
+// that provides the interfaces to run user-defined OpenGL rendering code.
+// See mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java
+// as an example.
+SolutionGlSurfaceView<FaceDetectionResult> glSurfaceView =
+    new SolutionGlSurfaceView<>(
+        this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion());
+glSurfaceView.setSolutionResultRenderer(new FaceDetectionResultGlRenderer());
+glSurfaceView.setRenderInputImage(true);
+
+faceDetection.setResultListener(
+    faceDetectionResult -> {
+      RelativeKeypoint noseTip =
+          FaceDetection.getFaceKeypoint(faceDetectionResult, 0, FaceKeypoint.NOSE_TIP);
+      Log.i(
+          TAG,
+          String.format(
+              "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]): x=%f, y=%f",
+              noseTip.getX(), noseTip.getY()));
+      // Request GL rendering.
+ glSurfaceView.setRenderData(faceDetectionResult); + glSurfaceView.requestRender(); + }); + +ActivityResultLauncher videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + faceDetection.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); +Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Video.Media.INTERNAL_CONTENT_URI); +videoGetter.launch(gallery); +``` + ## Example Apps Please first see general instructions for diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index a94785324..f36296f85 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -111,6 +111,23 @@ You can find more information about the face landmark model in this :------------------------------------------------------------------------: | *Fig 2. Face landmarks: the red box indicates the cropped area as input to the landmark model, the red dots represent the 468 landmarks in 3D, and the green lines connecting landmarks illustrate the contours around the eyes, eyebrows, lips and the entire face.* | +#### Attention Mesh Model + +In addition to the [Face Landmark Model](#face-landmark-model) we provide +another model that applies +[attention](https://en.wikipedia.org/wiki/Attention_(machine_learning)) to +semantically meaningful face regions, and therefore predicting landmarks more +accurately around lips, eyes and irises, at the expense of more compute. It +enables applications like AR makeup and AR puppeteering. + +The attention mesh model can be selected in the Solution APIs via the +[refine_landmarks](#refine_landmarks) option. You can also find more information +about the model in this [paper](https://arxiv.org/abs/2006.10962). + +![attention_mesh_architecture.png](../images/attention_mesh_architecture.png) | +:---------------------------------------------------------------------------: | +*Fig 3. Attention Mesh: Overview of model architecture.* | + ## Face Geometry Module The [Face Landmark Model](#face-landmark-model) performs a single-camera face landmark @@ -145,8 +162,8 @@ be set freely, however for better results it is advised to set them as close to the *real physical camera parameters* as possible. ![face_geometry_metric_3d_space.gif](../images/face_geometry_metric_3d_space.gif) | -:----------------------------------------------------------------------------: | -*Fig 3. A visualization of multiple key elements in the Metric 3D space.* | +:-------------------------------------------------------------------------------: | +*Fig 4. A visualization of multiple key elements in the Metric 3D space.* | #### Canonical Face Model @@ -210,7 +227,7 @@ The effect renderer is implemented as a MediaPipe | ![face_geometry_renderer.gif](../images/face_geometry_renderer.gif) | | :---------------------------------------------------------------------: | -| *Fig 4. An example of face effects rendered by the Face Geometry Effect Renderer.* | +| *Fig 5. An example of face effects rendered by the Face Geometry Effect Renderer.* | ## Solution APIs @@ -234,6 +251,12 @@ unrelated, images. Default to `false`. Maximum number of faces to detect. Default to `1`. 
+#### refine_landmarks + +Whether to further refine the landmark coordinates around the eyes and lips, and +output additional landmarks around the irises by applying the +[Attention Mesh Model](#attention-mesh-model). Default to `false`. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the face detection model for the @@ -271,6 +294,7 @@ Supported configuration options: * [static_image_mode](#static_image_mode) * [max_num_faces](#max_num_faces) +* [refine_landmarks](#refine_landmarks) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -287,6 +311,7 @@ drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) with mp_face_mesh.FaceMesh( static_image_mode=True, max_num_faces=1, + refine_landmarks=True, min_detection_confidence=0.5) as face_mesh: for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) @@ -313,12 +338,21 @@ with mp_face_mesh.FaceMesh( landmark_drawing_spec=None, connection_drawing_spec=mp_drawing_styles .get_default_face_mesh_contours_style()) + mp_drawing.draw_landmarks( + image=annotated_image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_IRISES, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_iris_connections_style()) cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # For webcam input: drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) cap = cv2.VideoCapture(0) with mp_face_mesh.FaceMesh( + max_num_faces=1, + refine_landmarks=True, min_detection_confidence=0.5, min_tracking_confidence=0.5) as face_mesh: while cap.isOpened(): @@ -328,12 +362,10 @@ with mp_face_mesh.FaceMesh( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = face_mesh.process(image) # Draw the face mesh annotations on the image. @@ -355,7 +387,15 @@ with mp_face_mesh.FaceMesh( landmark_drawing_spec=None, connection_drawing_spec=mp_drawing_styles .get_default_face_mesh_contours_style()) - cv2.imshow('MediaPipe FaceMesh', image) + mp_drawing.draw_landmarks( + image=image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACEMESH_IRISES, + landmark_drawing_spec=None, + connection_drawing_spec=mp_drawing_styles + .get_default_face_mesh_iris_connections_style()) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Face Mesh', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -370,6 +410,7 @@ and the following usage example. 
Supported configuration options: * [maxNumFaces](#max_num_faces) +* [refineLandmarks](#refine_landmarks) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -410,8 +451,10 @@ function onResults(results) { {color: '#C0C0C070', lineWidth: 1}); drawConnectors(canvasCtx, landmarks, FACEMESH_RIGHT_EYE, {color: '#FF3030'}); drawConnectors(canvasCtx, landmarks, FACEMESH_RIGHT_EYEBROW, {color: '#FF3030'}); + drawConnectors(canvasCtx, landmarks, FACEMESH_RIGHT_IRIS, {color: '#FF3030'}); drawConnectors(canvasCtx, landmarks, FACEMESH_LEFT_EYE, {color: '#30FF30'}); drawConnectors(canvasCtx, landmarks, FACEMESH_LEFT_EYEBROW, {color: '#30FF30'}); + drawConnectors(canvasCtx, landmarks, FACEMESH_LEFT_IRIS, {color: '#30FF30'}); drawConnectors(canvasCtx, landmarks, FACEMESH_FACE_OVAL, {color: '#E0E0E0'}); drawConnectors(canvasCtx, landmarks, FACEMESH_LIPS, {color: '#E0E0E0'}); } @@ -424,6 +467,7 @@ const faceMesh = new FaceMesh({locateFile: (file) => { }}); faceMesh.setOptions({ maxNumFaces: 1, + refineLandmarks: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); @@ -444,7 +488,7 @@ camera.start(); Please first follow general [instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api) -to add MediaPipe Gradle dependencies, then try the FaceMash solution API in the +to add MediaPipe Gradle dependencies, then try the Face Mesh Solution API in the companion [example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facemesh) following @@ -455,6 +499,7 @@ Supported configuration options: * [staticImageMode](#static_image_mode) * [maxNumFaces](#max_num_faces) +* [refineLandmarks](#refine_landmarks) * runOnGpu: Run the pipeline and the model inference on GPU or CPU. #### Camera Input @@ -463,17 +508,18 @@ Supported configuration options: // For camera input and result rendering with OpenGL. FaceMeshOptions faceMeshOptions = FaceMeshOptions.builder() - .setMode(FaceMeshOptions.STREAMING_MODE) // API soon to become - .setMaxNumFaces(1) // setStaticImageMode(false) + .setStaticImageMode(false) + .setRefineLandmarks(true) + .setMaxNumFaces(1) .setRunOnGpu(true).build(); -FaceMesh facemesh = new FaceMesh(this, faceMeshOptions); -facemesh.setErrorListener( - (message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); +FaceMesh faceMesh = new FaceMesh(this, faceMeshOptions); +faceMesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); -// Initializes a new CameraInput instance and connects it to MediaPipe FaceMesh. +// Initializes a new CameraInput instance and connects it to MediaPipe Face Mesh Solution. CameraInput cameraInput = new CameraInput(this); cameraInput.setNewFrameListener( - textureFrame -> facemesh.send(textureFrame)); + textureFrame -> faceMesh.send(textureFrame)); // Initializes a new GlSurfaceView with a ResultGlRenderer instance // that provides the interfaces to run user-defined OpenGL rendering code. @@ -481,18 +527,18 @@ cameraInput.setNewFrameListener( // as an example. 
SolutionGlSurfaceView glSurfaceView = new SolutionGlSurfaceView<>( - this, facemesh.getGlContext(), facemesh.getGlMajorVersion()); + this, faceMesh.getGlContext(), faceMesh.getGlMajorVersion()); glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); glSurfaceView.setRenderInputImage(true); -facemesh.setResultListener( +faceMesh.setResultListener( faceMeshResult -> { NormalizedLandmark noseLandmark = result.multiFaceLandmarks().get(0).getLandmarkList().get(1); Log.i( TAG, String.format( - "MediaPipe FaceMesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + "MediaPipe Face Mesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", noseLandmark.getX(), noseLandmark.getY())); // Request GL rendering. glSurfaceView.setRenderData(faceMeshResult); @@ -504,7 +550,7 @@ glSurfaceView.post( () -> cameraInput.start( this, - facemesh.getGlContext(), + faceMesh.getGlContext(), CameraInput.CameraFacing.FRONT, glSurfaceView.getWidth(), glSurfaceView.getHeight())); @@ -516,17 +562,18 @@ glSurfaceView.post( // For reading images from gallery and drawing the output in an ImageView. FaceMeshOptions faceMeshOptions = FaceMeshOptions.builder() - .setMode(FaceMeshOptions.STATIC_IMAGE_MODE) // API soon to become - .setMaxNumFaces(1) // setStaticImageMode(true) + .setStaticImageMode(true) + .setRefineLandmarks(true) + .setMaxNumFaces(1) .setRunOnGpu(true).build(); -FaceMesh facemesh = new FaceMesh(this, faceMeshOptions); +FaceMesh faceMesh = new FaceMesh(this, faceMeshOptions); -// Connects MediaPipe FaceMesh to the user-defined ImageView instance that allows -// users to have the custom drawing of the output landmarks on it. +// Connects MediaPipe Face Mesh Solution to the user-defined ImageView instance +// that allows users to have the custom drawing of the output landmarks on it. // See mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultImageView.java // as an example. FaceMeshResultImageView imageView = new FaceMeshResultImageView(this); -facemesh.setResultListener( +faceMesh.setResultListener( faceMeshResult -> { int width = faceMeshResult.inputBitmap().getWidth(); int height = faceMeshResult.inputBitmap().getHeight(); @@ -535,14 +582,14 @@ facemesh.setResultListener( Log.i( TAG, String.format( - "MediaPipe FaceMesh nose coordinates (pixel values): x=%f, y=%f", + "MediaPipe Face Mesh nose coordinates (pixel values): x=%f, y=%f", noseLandmark.getX() * width, noseLandmark.getY() * height)); // Request canvas drawing. imageView.setFaceMeshResult(faceMeshResult); runOnUiThread(() -> imageView.update()); }); -facemesh.setErrorListener( - (message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); +faceMesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); // ActivityResultLauncher to get an image from the gallery as Bitmap. ActivityResultLauncher imageGetter = @@ -556,11 +603,12 @@ ActivityResultLauncher imageGetter = bitmap = MediaStore.Images.Media.getBitmap( this.getContentResolver(), resultIntent.getData()); + // Please also rotate the Bitmap based on its orientation. } catch (IOException e) { Log.e(TAG, "Bitmap reading error:" + e); } if (bitmap != null) { - facemesh.send(bitmap); + faceMesh.send(bitmap); } } }); @@ -575,17 +623,18 @@ imageGetter.launch(gallery); // For video input and result rendering with OpenGL. 
FaceMeshOptions faceMeshOptions = FaceMeshOptions.builder() - .setMode(FaceMeshOptions.STREAMING_MODE) // API soon to become - .setMaxNumFaces(1) // setStaticImageMode(false) + .setStaticImageMode(false) + .setRefineLandmarks(true) + .setMaxNumFaces(1) .setRunOnGpu(true).build(); -FaceMesh facemesh = new FaceMesh(this, faceMeshOptions); -facemesh.setErrorListener( - (message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); +FaceMesh faceMesh = new FaceMesh(this, faceMeshOptions); +faceMesh.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); -// Initializes a new VideoInput instance and connects it to MediaPipe FaceMesh. +// Initializes a new VideoInput instance and connects it to MediaPipe Face Mesh Solution. VideoInput videoInput = new VideoInput(this); videoInput.setNewFrameListener( - textureFrame -> facemesh.send(textureFrame)); + textureFrame -> faceMesh.send(textureFrame)); // Initializes a new GlSurfaceView with a ResultGlRenderer instance // that provides the interfaces to run user-defined OpenGL rendering code. @@ -593,18 +642,18 @@ videoInput.setNewFrameListener( // as an example. SolutionGlSurfaceView glSurfaceView = new SolutionGlSurfaceView<>( - this, facemesh.getGlContext(), facemesh.getGlMajorVersion()); + this, faceMesh.getGlContext(), faceMesh.getGlMajorVersion()); glSurfaceView.setSolutionResultRenderer(new FaceMeshResultGlRenderer()); glSurfaceView.setRenderInputImage(true); -facemesh.setResultListener( +faceMesh.setResultListener( faceMeshResult -> { NormalizedLandmark noseLandmark = result.multiFaceLandmarks().get(0).getLandmarkList().get(1); Log.i( TAG, String.format( - "MediaPipe FaceMesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + "MediaPipe Face Mesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", noseLandmark.getX(), noseLandmark.getY())); // Request GL rendering. glSurfaceView.setRenderData(faceMeshResult); @@ -623,7 +672,7 @@ ActivityResultLauncher videoGetter = videoInput.start( this, resultIntent.getData(), - facemesh.getGlContext(), + faceMesh.getGlContext(), glSurfaceView.getWidth(), glSurfaceView.getHeight())); } diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index c3088d64c..eaadb526a 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -269,12 +269,10 @@ with mp_hands.Hands( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = hands.process(image) # Draw the hand annotations on the image. @@ -288,7 +286,8 @@ with mp_hands.Hands( mp_hands.HAND_CONNECTIONS, mp_drawing_styles.get_default_hand_landmarks_style(), mp_drawing_styles.get_default_hand_connections_style()) - cv2.imshow('MediaPipe Hands', image) + # Flip the image horizontally for a selfie-view display. 
+ cv2.imshow('MediaPipe Hands', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -372,7 +371,7 @@ camera.start(); Please first follow general [instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api) -to add MediaPipe Gradle dependencies, then try the Hands solution API in the +to add MediaPipe Gradle dependencies, then try the Hands Solution API in the companion [example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/hands) following @@ -391,14 +390,14 @@ Supported configuration options: // For camera input and result rendering with OpenGL. HandsOptions handsOptions = HandsOptions.builder() - .setMode(HandsOptions.STREAMING_MODE) // API soon to become - .setMaxNumHands(1) // setStaticImageMode(false) + .setStaticImageMode(false) + .setMaxNumHands(1) .setRunOnGpu(true).build(); Hands hands = new Hands(this, handsOptions); hands.setErrorListener( (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); -// Initializes a new CameraInput instance and connects it to MediaPipe Hands. +// Initializes a new CameraInput instance and connects it to MediaPipe Hands Solution. CameraInput cameraInput = new CameraInput(this); cameraInput.setNewFrameListener( textureFrame -> hands.send(textureFrame)); @@ -444,13 +443,13 @@ glSurfaceView.post( // For reading images from gallery and drawing the output in an ImageView. HandsOptions handsOptions = HandsOptions.builder() - .setMode(HandsOptions.STATIC_IMAGE_MODE) // API soon to become - .setMaxNumHands(1) // setStaticImageMode(true) + .setStaticImageMode(true) + .setMaxNumHands(1) .setRunOnGpu(true).build(); Hands hands = new Hands(this, handsOptions); -// Connects MediaPipe Hands to the user-defined ImageView instance that allows -// users to have the custom drawing of the output landmarks on it. +// Connects MediaPipe Hands Solution to the user-defined ImageView instance that +// allows users to have the custom drawing of the output landmarks on it. // See mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java // as an example. HandsResultImageView imageView = new HandsResultImageView(this); @@ -484,6 +483,7 @@ ActivityResultLauncher imageGetter = bitmap = MediaStore.Images.Media.getBitmap( this.getContentResolver(), resultIntent.getData()); + // Please also rotate the Bitmap based on its orientation. } catch (IOException e) { Log.e(TAG, "Bitmap reading error:" + e); } @@ -503,14 +503,14 @@ imageGetter.launch(gallery); // For video input and result rendering with OpenGL. HandsOptions handsOptions = HandsOptions.builder() - .setMode(HandsOptions.STREAMING_MODE) // API soon to become - .setMaxNumHands(1) // setStaticImageMode(false) + .setStaticImageMode(false) + .setMaxNumHands(1) .setRunOnGpu(true).build(); Hands hands = new Hands(this, handsOptions); hands.setErrorListener( (message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); -// Initializes a new VideoInput instance and connects it to MediaPipe Hands. +// Initializes a new VideoInput instance and connects it to MediaPipe Hands Solution. 
VideoInput videoInput = new VideoInput(this); videoInput.setNewFrameListener( textureFrame -> hands.send(textureFrame)); diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index 0532a33dd..c8c60c284 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -147,6 +147,18 @@ If set to `true`, the solution filters pose landmarks across different input images to reduce jitter, but ignored if [static_image_mode](#static_image_mode) is also set to `true`. Default to `true`. +#### enable_segmentation + +If set to `true`, in addition to the pose, face and hand landmarks the solution +also generates the segmentation mask. Default to `false`. + +#### smooth_segmentation + +If set to `true`, the solution filters segmentation masks across different input +images to reduce jitter. Ignored if [enable_segmentation](#enable_segmentation) +is `false` or [static_image_mode](#static_image_mode) is `true`. Default to +`true`. + #### min_detection_confidence Minimum confidence value (`[0.0, 1.0]`) from the person-detection model for the @@ -207,6 +219,15 @@ the camera. The magnitude of `z` uses roughly the same scale as `x`. A list of 21 hand landmarks on the right hand, in the same representation as [left_hand_landmarks](#left_hand_landmarks). +#### segmentation_mask + +The output segmentation mask, predicted only when +[enable_segmentation](#enable_segmentation) is set to `true`. The mask has the +same width and height as the input image, and contains values in `[0.0, 1.0]` +where `1.0` and `0.0` indicate high certainty of a "human" and "background" +pixel respectively. Please refer to the platform-specific usage examples below +for usage details. + ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to @@ -218,6 +239,8 @@ Supported configuration options: * [static_image_mode](#static_image_mode) * [model_complexity](#model_complexity) * [smooth_landmarks](#smooth_landmarks) +* [enable_segmentation](#enable_segmentation) +* [smooth_segmentation](#smooth_segmentation) * [min_detection_confidence](#min_detection_confidence) * [min_tracking_confidence](#min_tracking_confidence) @@ -232,7 +255,8 @@ mp_holistic = mp.solutions.holistic IMAGE_FILES = [] with mp_holistic.Holistic( static_image_mode=True, - model_complexity=2) as holistic: + model_complexity=2, + enable_segmentation=True) as holistic: for idx, file in enumerate(IMAGE_FILES): image = cv2.imread(file) image_height, image_width, _ = image.shape @@ -245,8 +269,16 @@ with mp_holistic.Holistic( f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' ) - # Draw pose, left and right hands, and face landmarks on the image. + annotated_image = image.copy() + # Draw segmentation on the image. + # To improve segmentation around boundaries, consider applying a joint + # bilateral filter to "results.segmentation_mask" with "image". + condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1 + bg_image = np.zeros(image.shape, dtype=np.uint8) + bg_image[:] = BG_COLOR + annotated_image = np.where(condition, annotated_image, bg_image) + # Draw pose, left and right hands, and face landmarks on the image. mp_drawing.draw_landmarks( annotated_image, results.face_landmarks, @@ -277,12 +309,10 @@ with mp_holistic.Holistic( # If loading a video, use 'break' instead of 'continue'. 
continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = holistic.process(image) # Draw landmark annotation on the image. @@ -301,7 +331,8 @@ with mp_holistic.Holistic( mp_holistic.POSE_CONNECTIONS, landmark_drawing_spec=mp_drawing_styles .get_default_pose_landmarks_style()) - cv2.imshow('MediaPipe Holistic', image) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Holistic', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() @@ -317,6 +348,8 @@ Supported configuration options: * [modelComplexity](#model_complexity) * [smoothLandmarks](#smooth_landmarks) +* [enableSegmentation](#enable_segmentation) +* [smoothSegmentation](#smooth_segmentation) * [minDetectionConfidence](#min_detection_confidence) * [minTrackingConfidence](#min_tracking_confidence) @@ -349,8 +382,20 @@ const canvasCtx = canvasElement.getContext('2d'); function onResults(results) { canvasCtx.save(); canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height); + canvasCtx.drawImage(results.segmentationMask, 0, 0, + canvasElement.width, canvasElement.height); + + // Only overwrite existing pixels. + canvasCtx.globalCompositeOperation = 'source-in'; + canvasCtx.fillStyle = '#00FF00'; + canvasCtx.fillRect(0, 0, canvasElement.width, canvasElement.height); + + // Only overwrite missing pixels. + canvasCtx.globalCompositeOperation = 'destination-atop'; canvasCtx.drawImage( results.image, 0, 0, canvasElement.width, canvasElement.height); + + canvasCtx.globalCompositeOperation = 'source-over'; drawConnectors(canvasCtx, results.poseLandmarks, POSE_CONNECTIONS, {color: '#00FF00', lineWidth: 4}); drawLandmarks(canvasCtx, results.poseLandmarks, @@ -374,6 +419,8 @@ const holistic = new Holistic({locateFile: (file) => { holistic.setOptions({ modelComplexity: 1, smoothLandmarks: true, + enableSegmentation: true, + smoothSegmentation: true, minDetectionConfidence: 0.5, minTrackingConfidence: 0.5 }); diff --git a/docs/solutions/models.md b/docs/solutions/models.md index 2f3001722..1ae14f8f1 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -41,7 +41,10 @@ one over the other. * Face landmark model: [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark.tflite), [TF.js model](https://tfhub.dev/mediapipe/facemesh/1) -* [Model card](https://mediapipe.page.link/facemesh-mc) +* Face landmark model w/ attention (aka Attention Mesh): + [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_landmark/face_landmark_with_attention.tflite) +* [Model card](https://mediapipe.page.link/facemesh-mc), + [Model card (w/ attention)](https://mediapipe.page.link/attentionmesh-mc) ### [Iris](https://google.github.io/mediapipe/solutions/iris) diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index d7dc8f045..25259d678 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -338,11 +338,10 @@ with mp_objectron.Objectron(static_image_mode=False, # If loading a video, use 'break' instead of 'continue'. continue - # Convert the BGR image to RGB. 
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = objectron.process(image) # Draw the box landmarks on the image. @@ -354,7 +353,8 @@ with mp_objectron.Objectron(static_image_mode=False, image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS) mp_drawing.draw_axis(image, detected_object.rotation, detected_object.translation) - cv2.imshow('MediaPipe Objectron', image) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Objectron', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 271199bb5..3c893d83a 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -316,12 +316,10 @@ with mp_pose.Pose( # If loading a video, use 'break' instead of 'continue'. continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) # To improve performance, optionally mark the image as not writeable to # pass by reference. image.flags.writeable = False + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = pose.process(image) # Draw the pose annotation on the image. @@ -332,7 +330,8 @@ with mp_pose.Pose( results.pose_landmarks, mp_pose.POSE_CONNECTIONS, landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style()) - cv2.imshow('MediaPipe Pose', image) + # Flip the image horizontally for a selfie-view display. + cv2.imshow('MediaPipe Pose', cv2.flip(image, 1)) if cv2.waitKey(5) & 0xFF == 27: break cap.release() diff --git a/mediapipe/calculators/core/flow_limiter_calculator.proto b/mediapipe/calculators/core/flow_limiter_calculator.proto index 0f7c925ae..a3a71a294 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.proto +++ b/mediapipe/calculators/core/flow_limiter_calculator.proto @@ -30,7 +30,7 @@ message FlowLimiterCalculatorOptions { optional int32 max_in_flight = 1 [default = 1]; // The maximum number of frames queued waiting for processing. - // The default value limits to 1 frame awaiting processing. + // The default value limits to 0 frames awaiting processing. optional int32 max_in_queue = 2 [default = 0]; // The maximum time in microseconds to wait for a frame to finish processing. diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc index c8f1177d5..a80136be7 100644 --- a/mediapipe/calculators/core/split_vector_calculator.cc +++ b/mediapipe/calculators/core/split_vector_calculator.cc @@ -80,4 +80,7 @@ typedef SplitVectorCalculator SplitClassificationListVectorCalculator; REGISTER_CALCULATOR(SplitClassificationListVectorCalculator); +typedef SplitVectorCalculator SplitUint64tVectorCalculator; +REGISTER_CALCULATOR(SplitUint64tVectorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 07f7d5f46..8c9305ffb 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -480,8 +480,7 @@ RectSpec ImageCroppingCalculator::GetCropSpecs(const CalculatorContext* cc, if (cc->Inputs().HasTag(kRectTag)) { const auto& rect = cc->Inputs().Tag(kRectTag).Get(); // Only use the rect if it is valid. 
- if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && - rect.y_center() >= 0) { + if (rect.width() > 0 && rect.height() > 0) { x_center = rect.x_center(); y_center = rect.y_center(); crop_width = rect.width(); diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index ee1bcdf96..76cc845e2 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -337,12 +337,15 @@ absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) { !cc->Inputs().Tag("FLIP_VERTICALLY").IsEmpty()) { flip_vertically_ = cc->Inputs().Tag("FLIP_VERTICALLY").Get(); } - if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS") && - !cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) { - const auto& image_size = - cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get>(); - output_width_ = image_size.first; - output_height_ = image_size.second; + if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) { + if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) { + return absl::OkStatus(); + } else { + const auto& image_size = + cc->Inputs().Tag("OUTPUT_DIMENSIONS").Get>(); + output_width_ = image_size.first; + output_height_ = image_size.second; + } } if (use_gpu_) { @@ -506,6 +509,14 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) { ComputeOutputDimensions(input_width, input_height, &output_width, &output_height); + if (scale_mode_ == mediapipe::ScaleMode_Mode_FILL_AND_CROP) { + const float scale = + std::min(static_cast(output_width_) / input_width, + static_cast(output_height_) / input_height); + output_width = std::round(input_width * scale); + output_height = std::round(input_height * scale); + } + if (cc->Outputs().HasTag("LETTERBOX_PADDING")) { auto padding = absl::make_unique>(); ComputeOutputLetterboxPadding(input_width, input_height, output_width, diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index 87a661be6..04c3b2cf6 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -53,7 +53,7 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; // The alpha channel can be set to a single value, or come from an image mask. // If the input image has an alpha channel, it will be updated. // If the input image doesn't have an alpha channel, one will be added. -// Adding alpha channel to a Grayscale (single channel) input is not suported. +// Adding alpha channel to a Grayscale (single channel) input is not supported. 
 //
 // Inputs:
 //   One of the following two IMAGE tags:
diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD
index eb8950510..92ed084c9 100644
--- a/mediapipe/calculators/util/BUILD
+++ b/mediapipe/calculators/util/BUILD
@@ -1384,6 +1384,32 @@ cc_library(
     alwayslink = 1,
 )
 
+mediapipe_proto_library(
+    name = "landmarks_refinement_calculator_proto",
+    srcs = ["landmarks_refinement_calculator.proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
+)
+
+cc_library(
+    name = "landmarks_refinement_calculator",
+    srcs = ["landmarks_refinement_calculator.cc"],
+    hdrs = ["landmarks_refinement_calculator.h"],
+    deps = [
+        ":landmarks_refinement_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+        "//mediapipe/framework/port:core_proto",
+        "//mediapipe/framework/port:ret_check",
+        "@com_google_absl//absl/memory",
+    ],
+    alwayslink = 1,
+)
+
 cc_test(
     name = "refine_landmarks_from_heatmap_calculator_test",
     srcs = ["refine_landmarks_from_heatmap_calculator_test.cc"],
diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.cc b/mediapipe/calculators/util/landmarks_refinement_calculator.cc
new file mode 100644
index 000000000..8f734ac88
--- /dev/null
+++ b/mediapipe/calculators/util/landmarks_refinement_calculator.cc
@@ -0,0 +1,197 @@
+// Copyright 2021 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/calculators/util/landmarks_refinement_calculator.h"
+
+#include <memory>
+#include <set>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "mediapipe/calculators/util/landmarks_refinement_calculator.pb.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/proto_ns.h"
+#include "mediapipe/framework/port/ret_check.h"
+
+namespace mediapipe {
+
+namespace api2 {
+
+namespace {
+
+absl::StatusOr<int> GetNumberOfRefinedLandmarks(
+    const proto_ns::RepeatedPtrField<
+        LandmarksRefinementCalculatorOptions::Refinement>& refinements) {
+  // Gather all used indexes.
+  std::set<int> idxs;
+  for (int i = 0; i < refinements.size(); ++i) {
+    const auto& refinement = refinements.Get(i);
+    for (int i = 0; i < refinement.indexes_mapping_size(); ++i) {
+      idxs.insert(refinement.indexes_mapping(i));
+    }
+  }
+
+  // Check that indexes start with 0 and there are no gaps between min and max
+  // indexes.
+  RET_CHECK(!idxs.empty())
+      << "There should be at least one landmark in indexes mapping";
+  int idxs_min = *idxs.begin();
+  int idxs_max = *idxs.rbegin();
+  int n_idxs = idxs.size();
+  RET_CHECK_EQ(idxs_min, 0)
+      << "Indexes are expected to start with 0 instead of " << idxs_min;
+  RET_CHECK_EQ(idxs_max, n_idxs - 1)
+      << "Indexes should have no gaps but " << idxs_max - n_idxs + 1
+      << " indexes are missing";
+
+  return n_idxs;
+}
+
+void RefineXY(const proto_ns::RepeatedField<int>& indexes_mapping,
+              const NormalizedLandmarkList& landmarks,
+              NormalizedLandmarkList* refined_landmarks) {
+  for (int i = 0; i < landmarks.landmark_size(); ++i) {
+    const auto& landmark = landmarks.landmark(i);
+    auto* refined_landmark =
+        refined_landmarks->mutable_landmark(indexes_mapping.Get(i));
+    refined_landmark->set_x(landmark.x());
+    refined_landmark->set_y(landmark.y());
+  }
+}
+
+float GetZAverage(const NormalizedLandmarkList& landmarks,
+                  const proto_ns::RepeatedField<int>& indexes) {
+  double z_sum = 0;
+  for (int i = 0; i < indexes.size(); ++i) {
+    z_sum += landmarks.landmark(indexes.Get(i)).z();
+  }
+  return z_sum / indexes.size();
+}
+
+void RefineZ(
+    const proto_ns::RepeatedField<int>& indexes_mapping,
+    const LandmarksRefinementCalculatorOptions::ZRefinement& z_refinement,
+    const NormalizedLandmarkList& landmarks,
+    NormalizedLandmarkList* refined_landmarks) {
+  if (z_refinement.has_none()) {
+    // Do nothing and keep Z that is already in refined landmarks.
+  } else if (z_refinement.has_copy()) {
+    for (int i = 0; i < landmarks.landmark_size(); ++i) {
+      refined_landmarks->mutable_landmark(indexes_mapping.Get(i))
+          ->set_z(landmarks.landmark(i).z());
+    }
+  } else if (z_refinement.has_assign_average()) {
+    const float z_average =
+        GetZAverage(*refined_landmarks,
+                    z_refinement.assign_average().indexes_for_average());
+    for (int i = 0; i < indexes_mapping.size(); ++i) {
+      refined_landmarks->mutable_landmark(indexes_mapping.Get(i))
+          ->set_z(z_average);
+    }
+  } else {
+    CHECK(false) << "Z refinement is either not specified or not supported";
+  }
+}
+
+}  // namespace
+
+class LandmarksRefinementCalculatorImpl
+    : public NodeImpl<LandmarksRefinementCalculator> {
+  absl::Status Open(CalculatorContext* cc) override {
+    options_ = cc->Options<LandmarksRefinementCalculatorOptions>();
+
+    // Validate refinements.
+    for (int i = 0; i < options_.refinement_size(); ++i) {
+      const auto& refinement = options_.refinement(i);
+      RET_CHECK_GT(refinement.indexes_mapping_size(), 0)
+          << "Refinement " << i << " has no indexes mapping";
+      RET_CHECK(refinement.has_z_refinement())
+          << "Refinement " << i << " has no Z refinement specified";
+      RET_CHECK(refinement.z_refinement().has_none() ^
+                refinement.z_refinement().has_copy() ^
+                refinement.z_refinement().has_assign_average())
+          << "Exactly one Z refinement should be specified";
+
+      const auto z_refinement = refinement.z_refinement();
+      if (z_refinement.has_assign_average()) {
+        RET_CHECK_GT(z_refinement.assign_average().indexes_for_average_size(),
+                     0)
+            << "When using assign average Z refinement at least one index for "
+               "averaging should be specified";
+      }
+    }
+
+    // Validate indexes mapping and get total number of refined landmarks.
+    ASSIGN_OR_RETURN(n_refined_landmarks_,
+                     GetNumberOfRefinedLandmarks(options_.refinement()));
+
+    // Validate that number of refinements and landmark streams is the same.
+    RET_CHECK_EQ(kLandmarks(cc).Count(), options_.refinement_size())
+        << "There are " << options_.refinement_size() << " refinements while "
+        << kLandmarks(cc).Count() << " landmark streams";
+
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(CalculatorContext* cc) override {
+    // If any of the refinement landmarks is missing - refinement won't happen.
+    for (const auto& landmarks_stream : kLandmarks(cc)) {
+      if (landmarks_stream.IsEmpty()) {
+        return absl::OkStatus();
+      }
+    }
+
+    // Initialize refined landmarks list.
+    auto refined_landmarks = absl::make_unique<NormalizedLandmarkList>();
+    for (int i = 0; i < n_refined_landmarks_; ++i) {
+      refined_landmarks->add_landmark();
+    }
+
+    // Apply input landmarks to output refined landmarks in provided order.
+    for (int i = 0; i < kLandmarks(cc).Count(); ++i) {
+      const auto& landmarks = kLandmarks(cc)[i].Get();
+      const auto& refinement = options_.refinement(i);
+
+      // Check number of landmarks in mapping and stream are the same.
+      RET_CHECK_EQ(landmarks.landmark_size(), refinement.indexes_mapping_size())
+          << "There are " << landmarks.landmark_size()
+          << " refinement landmarks while mapping has "
+          << refinement.indexes_mapping_size();
+
+      // Refine X and Y.
+      RefineXY(refinement.indexes_mapping(), landmarks,
+               refined_landmarks.get());
+
+      // Refine Z.
+      RefineZ(refinement.indexes_mapping(), refinement.z_refinement(),
+              landmarks, refined_landmarks.get());
+
+      // Visibility and presence are not currently refined and are left as `0`.
+    }
+
+    kRefinedLandmarks(cc).Send(std::move(refined_landmarks));
+    return absl::OkStatus();
+  }
+
+ private:
+  LandmarksRefinementCalculatorOptions options_;
+  int n_refined_landmarks_ = 0;
+};
+
+MEDIAPIPE_NODE_IMPLEMENTATION(LandmarksRefinementCalculatorImpl);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.h b/mediapipe/calculators/util/landmarks_refinement_calculator.h
new file mode 100644
index 000000000..1edadcd5b
--- /dev/null
+++ b/mediapipe/calculators/util/landmarks_refinement_calculator.h
@@ -0,0 +1,85 @@
+// Copyright 2021 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_REFINEMENT_CALCULATOR_H_
+#define MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_REFINEMENT_CALCULATOR_H_
+
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe {
+
+namespace api2 {
+
+// A calculator to refine one set of landmarks with another.
+//
+// Inputs:
+//   LANDMARKS: Multiple NormalizedLandmarkList to use for
+//     refinement. They will be applied to the resulting REFINED_LANDMARKS in
+//     the provided order. Each list should be non-empty and contain the same
+//     amount of landmarks as indexes in mapping. Number of lists should be the
+//     same as number of refinements in options.
+//
+// Outputs:
+//   REFINED_LANDMARKS: A NormalizedLandmarkList with refined landmarks. Number
+//     of produced landmarks is equal to the maximum index mapping number in
+//     calculator options (calculator verifies that there are no gaps in the
+//     mapping).
+//
+// Example config:
+//   node {
+//     calculator: "LandmarksRefinementCalculator"
+//     input_stream: "LANDMARKS:0:mesh_landmarks"
+//     input_stream: "LANDMARKS:1:lips_landmarks"
+//     input_stream: "LANDMARKS:2:left_eye_landmarks"
+//     input_stream: "LANDMARKS:3:right_eye_landmarks"
+//     output_stream: "REFINED_LANDMARKS:landmarks"
+//     options: {
+//       [mediapipe.LandmarksRefinementCalculatorOptions.ext] {
+//         refinement: {
+//           indexes_mapping: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+//           z_refinement: { copy {} }
+//         }
+//         refinement: {
+//           indexes_mapping: [0, 1, 2, 3]
+//           z_refinement: { none {} }
+//         }
+//         refinement: {
+//           indexes_mapping: [4, 5]
+//           z_refinement: { none {} }
+//         }
+//         refinement: {
+//           indexes_mapping: [6, 7]
+//           z_refinement: { none {} }
+//         }
+//       }
+//     }
+//   }
+//
+class LandmarksRefinementCalculator : public NodeIntf {
+ public:
+  static constexpr Input<::mediapipe::NormalizedLandmarkList>::Multiple
+      kLandmarks{"LANDMARKS"};
+  static constexpr Output<::mediapipe::NormalizedLandmarkList>
+      kRefinedLandmarks{"REFINED_LANDMARKS"};
+
+  MEDIAPIPE_NODE_INTERFACE(LandmarksRefinementCalculator, kLandmarks,
+                           kRefinedLandmarks);
+};
+
+}  // namespace api2
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_CALCULATORS_UTIL_LANDMARKS_REFINEMENT_CALCULATOR_H_
diff --git a/mediapipe/calculators/util/landmarks_refinement_calculator.proto b/mediapipe/calculators/util/landmarks_refinement_calculator.proto
new file mode 100644
index 000000000..e5234e713
--- /dev/null
+++ b/mediapipe/calculators/util/landmarks_refinement_calculator.proto
@@ -0,0 +1,71 @@
+// Copyright 2021 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message LandmarksRefinementCalculatorOptions {
+  extend CalculatorOptions {
+    optional LandmarksRefinementCalculatorOptions ext = 381914658;
+  }
+
+  // Do nothing and keep those Z that are already present in the resulting set
+  // of landmarks.
+  message ZRefinementNone {}
+
+  // Simply copy Z values from the given set of landmarks to the resulting set
+  // of landmarks.
+  message ZRefinementCopy {}
+
+  // Calculate average of the specified set of landmarks in the resulting set
+  // and use it as Z for all given landmarks when assigning their values to the
+  // resulting set of landmarks.
+  message ZRefinementAssignAverage {
+    // Indexes of the resulting landmarks to use for average. Should be non
+    // empty.
+    repeated int32 indexes_for_average = 1;
+  }
+
+  // Specifies the set of instructions on assigning z value from the given set
+  // of landmarks to the resulting set of landmarks.
+  message ZRefinement {
+    // Exactly one Z refinement option should be specified.
+ oneof z_refinement_options { + ZRefinementNone none = 1; + ZRefinementCopy copy = 2; + ZRefinementAssignAverage assign_average = 3; + } + } + + // Specifies the set of instructions of assigning values to the resulting set + // of landmarks. + message Refinement { + // Maps indexes of the given set of landmarks to indexes of the resulting + // set of landmarks. Should be non empty and contain the same amount of + // indexes as landmarks in the corresponding input stream. + repeated int32 indexes_mapping = 1; + + // Z refinement instructions. + optional ZRefinement z_refinement = 2; + } + + // Refinement instructions for every landmarks input stream. Applied in the + // same order as defined. Should be the same amount of refinements as landmark + // input streams in the calculator. Union of index mappings should start with + // 0 and cover a contineous range. + repeated Refinement refinement = 1; +} diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index f2cec3ae3..263ef85c6 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -86,11 +86,11 @@ inline void GetMinMaxZ(const LandmarkListType& landmarks, float* z_min, } template -bool IsLandmarkVisibileAndPresent(const LandmarkType& landmark, - bool utilize_visibility, - float visibility_threshold, - bool utilize_presence, - float presence_threshold) { +bool IsLandmarkVisibleAndPresent(const LandmarkType& landmark, + bool utilize_visibility, + float visibility_threshold, + bool utilize_presence, + float presence_threshold) { if (utilize_visibility && landmark.has_visibility() && landmark.visibility() < visibility_threshold) { return false; @@ -153,12 +153,16 @@ void AddConnectionsWithDepth(const LandmarkListType& landmarks, const Color& max_depth_line_color, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { + if (landmark_connections[i] >= landmarks.landmark_size() || + landmark_connections[i + 1] >= landmarks.landmark_size()) { + continue; + } const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( ld0, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold) || - !IsLandmarkVisibileAndPresent( + !IsLandmarkVisibleAndPresent( ld1, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold)) { continue; @@ -196,12 +200,16 @@ void AddConnections(const LandmarkListType& landmarks, const Color& connection_color, float thickness, bool normalized, RenderData* render_data) { for (int i = 0; i < landmark_connections.size(); i += 2) { + if (landmark_connections[i] >= landmarks.landmark_size() || + landmark_connections[i + 1] >= landmarks.landmark_size()) { + continue; + } const auto& ld0 = landmarks.landmark(landmark_connections[i]); const auto& ld1 = landmarks.landmark(landmark_connections[i + 1]); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( ld0, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold) || - !IsLandmarkVisibileAndPresent( + !IsLandmarkVisibleAndPresent( ld1, utilize_visibility, visibility_threshold, utilize_presence, presence_threshold)) { continue; @@ -317,7 +325,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < 
landmarks.landmark_size(); ++i) { const Landmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( landmark, options_.utilize_visibility(), options_.visibility_threshold(), options_.utilize_presence(), options_.presence_threshold())) { @@ -363,7 +371,7 @@ absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < landmarks.landmark_size(); ++i) { const NormalizedLandmark& landmark = landmarks.landmark(i); - if (!IsLandmarkVisibileAndPresent( + if (!IsLandmarkVisibleAndPresent( landmark, options_.utilize_visibility(), options_.visibility_threshold(), options_.utilize_presence(), options_.presence_threshold())) { diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index e0a759bdb..15bb26826 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -36,7 +36,7 @@ inline float NormalizeRadians(float angle) { } // namespace // Performs geometric transformation to the input Rect or NormalizedRect, -// correpsonding to input stream RECT or NORM_RECT respectively. When the input +// corresponding to input stream RECT or NORM_RECT respectively. When the input // is NORM_RECT, an addition input stream IMAGE_SIZE is required, which is a // std::pair representing the image width and height. // diff --git a/mediapipe/examples/android/solutions/create_win_symlinks.bat b/mediapipe/examples/android/solutions/create_win_symlinks.bat index ea641b6e9..57bafeb2b 100644 --- a/mediapipe/examples/android/solutions/create_win_symlinks.bat +++ b/mediapipe/examples/android/solutions/create_win_symlinks.bat @@ -12,5 +12,12 @@ cd /d %~dp0 cd facemesh\src\main rm res mklink /d res ..\..\..\res + +@rem for face detection example app. +cd /d %~dp0 +cd facedetection\src\main +rm res +mklink /d res ..\..\..\res + dir pause diff --git a/mediapipe/examples/android/solutions/facedetection/build.gradle b/mediapipe/examples/android/solutions/facedetection/build.gradle new file mode 100644 index 000000000..a3264a1b8 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/build.gradle @@ -0,0 +1,51 @@ +plugins { + id 'com.android.application' +} + +android { + compileSdkVersion 30 + buildToolsVersion "30.0.3" + + defaultConfig { + applicationId "com.google.mediapipe.apps.facedetection" + minSdkVersion 21 + targetSdkVersion 30 + versionCode 1 + versionName "1.0" + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar', '*.aar']) + implementation 'androidx.appcompat:appcompat:1.3.0' + implementation 'com.google.android.material:material:1.3.0' + implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation 'androidx.exifinterface:exifinterface:1.3.3' + testImplementation 'junit:junit:4.+' + androidTestImplementation 'androidx.test.ext:junit:1.1.2' + androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' + // MediaPipe Face Detection Solution components. 
+ implementation 'com.google.mediapipe:solution-core:latest.release' + implementation 'com.google.mediapipe:facedetection:latest.release' + // MediaPipe deps + implementation 'com.google.flogger:flogger:0.6' + implementation 'com.google.flogger:flogger-system-backend:0.6' + implementation 'com.google.guava:guava:27.0.1-android' + implementation 'com.google.protobuf:protobuf-java:3.11.4' + // CameraX core library + def camerax_version = "1.0.0-beta10" + implementation "androidx.camera:camera-core:$camerax_version" + implementation "androidx.camera:camera-camera2:$camerax_version" + implementation "androidx.camera:camera-lifecycle:$camerax_version" +} diff --git a/mediapipe/examples/android/solutions/facedetection/proguard-rules.pro b/mediapipe/examples/android/solutions/facedetection/proguard-rules.pro new file mode 100644 index 000000000..f1b424510 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/AndroidManifest.xml b/mediapipe/examples/android/solutions/facedetection/src/main/AndroidManifest.xml new file mode 100644 index 000000000..ffb743d2d --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/AndroidManifest.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/BUILD b/mediapipe/examples/android/solutions/facedetection/src/main/BUILD new file mode 100644 index 000000000..5044b55ed --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/BUILD @@ -0,0 +1,46 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
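+# Android example app target for the MediaPipe Face Detection solution,
+# mirroring the structure of the facemesh and hands example BUILD files.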
+ +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +android_binary( + name = "facedetection", + srcs = glob(["**/*.java"]), + custom_package = "com.google.mediapipe.examples.facedetection", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.examples.facedetection", + }, + multidex = "native", + resource_files = ["//mediapipe/examples/android/solutions:resource_files"], + deps = [ + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/solutioncore:camera_input", + "//mediapipe/java/com/google/mediapipe/solutioncore:mediapipe_jni_lib", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_rendering", + "//mediapipe/java/com/google/mediapipe/solutioncore:video_input", + "//mediapipe/java/com/google/mediapipe/solutions/facedetection", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java new file mode 100644 index 000000000..df1847178 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultGlRenderer.java @@ -0,0 +1,146 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facedetection; + +import android.opengl.GLES20; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; +import com.google.mediapipe.solutions.facedetection.FaceKeypoint; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + +/** A custom implementation of {@link ResultGlRenderer} to render {@link FaceDetectionResult}. 
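+ * <p>Keypoints are drawn as GL points; the relative bounding box is drawn as GL line segments.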
*/ +public class FaceDetectionResultGlRenderer implements ResultGlRenderer { + private static final String TAG = "FaceDetectionResultGlRenderer"; + + private static final float[] KEYPOINT_COLOR = new float[] {1f, 0f, 0f, 1f}; + private static final float KEYPOINT_SIZE = 16f; + private static final float[] BBOX_COLOR = new float[] {0f, 1f, 0f, 1f}; + private static final int BBOX_THICKNESS = 8; + private static final String VERTEX_SHADER = + "uniform mat4 uProjectionMatrix;\n" + + "uniform float uPointSize;\n" + + "attribute vec4 vPosition;\n" + + "void main() {\n" + + " gl_Position = uProjectionMatrix * vPosition;\n" + + " gl_PointSize = uPointSize;" + + "}"; + private static final String FRAGMENT_SHADER = + "precision mediump float;\n" + + "uniform vec4 uColor;\n" + + "void main() {\n" + + " gl_FragColor = uColor;\n" + + "}"; + private int program; + private int positionHandle; + private int pointSizeHandle; + private int projectionMatrixHandle; + private int colorHandle; + + private int loadShader(int type, String shaderCode) { + int shader = GLES20.glCreateShader(type); + GLES20.glShaderSource(shader, shaderCode); + GLES20.glCompileShader(shader); + return shader; + } + + @Override + public void setupRendering() { + program = GLES20.glCreateProgram(); + int vertexShader = loadShader(GLES20.GL_VERTEX_SHADER, VERTEX_SHADER); + int fragmentShader = loadShader(GLES20.GL_FRAGMENT_SHADER, FRAGMENT_SHADER); + GLES20.glAttachShader(program, vertexShader); + GLES20.glAttachShader(program, fragmentShader); + GLES20.glLinkProgram(program); + positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); + pointSizeHandle = GLES20.glGetUniformLocation(program, "uPointSize"); + projectionMatrixHandle = GLES20.glGetUniformLocation(program, "uProjectionMatrix"); + colorHandle = GLES20.glGetUniformLocation(program, "uColor"); + } + + @Override + public void renderResult(FaceDetectionResult result, float[] projectionMatrix) { + if (result == null) { + return; + } + GLES20.glUseProgram(program); + GLES20.glUniformMatrix4fv(projectionMatrixHandle, 1, false, projectionMatrix, 0); + GLES20.glUniform1f(pointSizeHandle, KEYPOINT_SIZE); + int numDetectedFaces = result.multiFaceDetections().size(); + for (int i = 0; i < numDetectedFaces; ++i) { + drawDetection(result.multiFaceDetections().get(i)); + } + } + + /** + * Deletes the shader program. + * + *
<p>
This is only necessary if one wants to release the program while keeping the context around. + */ + public void release() { + GLES20.glDeleteProgram(program); + } + + private void drawDetection(Detection detection) { + if (!detection.hasLocationData()) { + return; + } + // Draw keypoints. + float[] points = new float[FaceKeypoint.NUM_KEY_POINTS * 2]; + for (int i = 0; i < FaceKeypoint.NUM_KEY_POINTS; ++i) { + points[2 * i] = detection.getLocationData().getRelativeKeypoints(i).getX(); + points[2 * i + 1] = detection.getLocationData().getRelativeKeypoints(i).getY(); + } + GLES20.glUniform4fv(colorHandle, 1, KEYPOINT_COLOR, 0); + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(points.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(points); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_POINTS, 0, FaceKeypoint.NUM_KEY_POINTS); + if (!detection.getLocationData().hasRelativeBoundingBox()) { + return; + } + // Draw bounding box. + float left = detection.getLocationData().getRelativeBoundingBox().getXmin(); + float top = detection.getLocationData().getRelativeBoundingBox().getYmin(); + float right = left + detection.getLocationData().getRelativeBoundingBox().getWidth(); + float bottom = top + detection.getLocationData().getRelativeBoundingBox().getHeight(); + drawLine(top, left, top, right); + drawLine(bottom, left, bottom, right); + drawLine(top, left, bottom, left); + drawLine(top, right, bottom, right); + } + + private void drawLine(float y1, float x1, float y2, float x2) { + GLES20.glUniform4fv(colorHandle, 1, BBOX_COLOR, 0); + GLES20.glLineWidth(BBOX_THICKNESS); + float[] vertex = {x1, y1, x2, y2}; + FloatBuffer vertexBuffer = + ByteBuffer.allocateDirect(vertex.length * 4) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer() + .put(vertex); + vertexBuffer.position(0); + GLES20.glEnableVertexAttribArray(positionHandle); + GLES20.glVertexAttribPointer(positionHandle, 2, GLES20.GL_FLOAT, false, 0, vertexBuffer); + GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); + } +} diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java new file mode 100644 index 000000000..4a2895b0e --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/FaceDetectionResultImageView.java @@ -0,0 +1,108 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.examples.facedetection; + +import static java.lang.Math.min; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; +import com.google.mediapipe.solutions.facedetection.FaceKeypoint; + +/** An ImageView implementation for displaying {@link FaceDetectionResult}. */ +public class FaceDetectionResultImageView extends AppCompatImageView { + private static final String TAG = "FaceDetectionResultImageView"; + + private static final int KEYPOINT_COLOR = Color.RED; + private static final int KEYPOINT_RADIUS = 15; + private static final int BBOX_COLOR = Color.GREEN; + private static final int BBOX_THICKNESS = 10; + private Bitmap latest; + + public FaceDetectionResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link FaceDetectionResult} to render. + * + * @param result a {@link FaceDetectionResult} object that contains the solution outputs and the + * input {@link Bitmap}. + */ + public void setFaceDetectionResult(FaceDetectionResult result) { + if (result == null) { + return; + } + Bitmap bmInput = result.inputBitmap(); + int width = bmInput.getWidth(); + int height = bmInput.getHeight(); + latest = Bitmap.createBitmap(width, height, bmInput.getConfig()); + Canvas canvas = new Canvas(latest); + + canvas.drawBitmap(bmInput, new Matrix(), null); + int numDetectedFaces = result.multiFaceDetections().size(); + for (int i = 0; i < numDetectedFaces; ++i) { + drawDetectionOnCanvas(result.multiFaceDetections().get(i), canvas, width, height); + } + } + + /** Updates the image view with the latest {@link FaceDetectionResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawDetectionOnCanvas(Detection detection, Canvas canvas, int width, int height) { + if (!detection.hasLocationData()) { + return; + } + // Draw keypoints. + Paint keypointPaint = new Paint(); + keypointPaint.setColor(KEYPOINT_COLOR); + for (int i = 0; i < FaceKeypoint.NUM_KEY_POINTS; ++i) { + int xPixel = + min( + (int) (detection.getLocationData().getRelativeKeypoints(i).getX() * width), + width - 1); + int yPixel = + min( + (int) (detection.getLocationData().getRelativeKeypoints(i).getY() * height), + height - 1); + canvas.drawCircle(xPixel, yPixel, KEYPOINT_RADIUS, keypointPaint); + } + if (!detection.getLocationData().hasRelativeBoundingBox()) { + return; + } + // Draw bounding box. 
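+    // The bounding box is reported in normalized [0, 1] coordinates, so scale it by the bitmap
+    // width and height to obtain pixel values.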
+ Paint bboxPaint = new Paint(); + bboxPaint.setColor(BBOX_COLOR); + bboxPaint.setStyle(Paint.Style.STROKE); + bboxPaint.setStrokeWidth(BBOX_THICKNESS); + float left = detection.getLocationData().getRelativeBoundingBox().getXmin() * width; + float top = detection.getLocationData().getRelativeBoundingBox().getYmin() * height; + float right = left + detection.getLocationData().getRelativeBoundingBox().getWidth() * width; + float bottom = top + detection.getLocationData().getRelativeBoundingBox().getHeight() * height; + canvas.drawRect(left, top, right, bottom, bboxPaint); + } +} diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/MainActivity.java b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/MainActivity.java new file mode 100644 index 000000000..7e1ad28d5 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/java/com/google/mediapipe/examples/facedetection/MainActivity.java @@ -0,0 +1,341 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.examples.facedetection; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint; +import com.google.mediapipe.solutioncore.CameraInput; +import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; +import com.google.mediapipe.solutioncore.VideoInput; +import com.google.mediapipe.solutions.facedetection.FaceDetection; +import com.google.mediapipe.solutions.facedetection.FaceDetectionOptions; +import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; +import com.google.mediapipe.solutions.facedetection.FaceKeypoint; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Face Detection app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + + private FaceDetection faceDetection; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + private InputSource inputSource = InputSource.UNKNOWN; + + // Image demo UI and image loader components. + private ActivityResultLauncher imageGetter; + private FaceDetectionResultImageView imageView; + // Video demo UI and video loader components. + private VideoInput videoInput; + private ActivityResultLauncher videoGetter; + // Live camera demo UI and camera components. 
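+  // CameraInput feeds camera frames into the solution; SolutionGlSurfaceView renders both the
+  // input frames and the detection results with OpenGL.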
+ private CameraInput cameraInput; + + private SolutionGlSurfaceView glSurfaceView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + setupStaticImageDemoUiComponents(); + setupVideoDemoUiComponents(); + setupLiveDemoUiComponents(); + } + + @Override + protected void onResume() { + super.onResume(); + if (inputSource == InputSource.CAMERA) { + // Restarts the camera and the opengl surface rendering. + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); + glSurfaceView.post(this::startCamera); + glSurfaceView.setVisibility(View.VISIBLE); + } else if (inputSource == InputSource.VIDEO) { + videoInput.resume(); + } + } + + @Override + protected void onPause() { + super.onPause(); + if (inputSource == InputSource.CAMERA) { + glSurfaceView.setVisibility(View.GONE); + cameraInput.close(); + } else if (inputSource == InputSource.VIDEO) { + videoInput.pause(); + } + } + + /** Sets up the UI components for the static image demo. */ + private void setupStaticImageDemoUiComponents() { + // The Intent to access gallery and read images as bitmap. + imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + try { + bitmap = + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData()); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + int orientation = + new ExifInterface(imageData) + .getAttributeInt( + ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation != ExifInterface.ORIENTATION_NORMAL) { + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + bitmap = + Bitmap.createBitmap( + bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true); + } + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + faceDetection.send(bitmap); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + stopCurrentPipeline(); + setupStaticImageModePipeline(); + } + // Reads images from gallery. + Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Images.Media.INTERNAL_CONTENT_URI); + imageGetter.launch(gallery); + }); + imageView = new FaceDetectionResultImageView(this); + } + + /** Sets up core workflow for static image mode. */ + private void setupStaticImageModePipeline() { + this.inputSource = InputSource.IMAGE; + // Initializes a new MediaPipe Face Detection solution instance in the static image mode. + faceDetection = + new FaceDetection( + this, + FaceDetectionOptions.builder() + .setStaticImageMode(true) + .setModelSelection(0) + .setMinDetectionConfidence(0.5f) + .build()); + + // Connects MediaPipe Face Detection solution to the user-defined FaceDetectionResultImageView. 
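+    // The result listener also triggers imageView.update() via runOnUiThread so that the
+    // rendered bitmap is swapped in on the UI thread.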
+ faceDetection.setResultListener( + faceDetectionResult -> { + logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ true); + imageView.setFaceDetectionResult(faceDetectionResult); + runOnUiThread(() -> imageView.update()); + }); + faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + /** Sets up the UI components for the video demo. */ + private void setupVideoDemoUiComponents() { + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + glSurfaceView.post( + () -> + videoInput.start( + this, + resultIntent.getData(), + faceDetection.getGlContext(), + glSurfaceView.getWidth(), + glSurfaceView.getHeight())); + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.VIDEO); + // Reads video from gallery. + Intent gallery = + new Intent(Intent.ACTION_PICK, MediaStore.Video.Media.INTERNAL_CONTENT_URI); + videoGetter.launch(gallery); + }); + } + + /** Sets up the UI components for the live demo with camera input. */ + private void setupLiveDemoUiComponents() { + Button startCameraButton = findViewById(R.id.button_start_camera); + startCameraButton.setOnClickListener( + v -> { + if (inputSource == InputSource.CAMERA) { + return; + } + stopCurrentPipeline(); + setupStreamingModePipeline(InputSource.CAMERA); + }); + } + + /** Sets up core workflow for streaming mode. */ + private void setupStreamingModePipeline(InputSource inputSource) { + this.inputSource = inputSource; + // Initializes a new MediaPipe Face Detection solution instance in the streaming mode. + faceDetection = + new FaceDetection( + this, + FaceDetectionOptions.builder().setStaticImageMode(false).setModelSelection(0).build()); + faceDetection.setErrorListener( + (message, e) -> Log.e(TAG, "MediaPipe Face Detection error:" + message)); + + if (inputSource == InputSource.CAMERA) { + cameraInput = new CameraInput(this); + cameraInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); + } else if (inputSource == InputSource.VIDEO) { + videoInput = new VideoInput(this); + videoInput.setNewFrameListener(textureFrame -> faceDetection.send(textureFrame)); + } + + // Initializes a new Gl surface view with a user-defined FaceDetectionResultGlRenderer. + glSurfaceView = + new SolutionGlSurfaceView<>( + this, faceDetection.getGlContext(), faceDetection.getGlMajorVersion()); + glSurfaceView.setSolutionResultRenderer(new FaceDetectionResultGlRenderer()); + glSurfaceView.setRenderInputImage(true); + faceDetection.setResultListener( + faceDetectionResult -> { + logNoseTipKeypoint(faceDetectionResult, /*faceIndex=*/ 0, /*showPixelValues=*/ false); + glSurfaceView.setRenderData(faceDetectionResult); + glSurfaceView.requestRender(); + }); + + // The runnable to start camera after the gl surface view is attached. + // For video input source, videoInput.start() will be called when the video uri is available. 
+ if (inputSource == InputSource.CAMERA) { + glSurfaceView.post(this::startCamera); + } + + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + imageView.setVisibility(View.GONE); + frameLayout.removeAllViewsInLayout(); + frameLayout.addView(glSurfaceView); + glSurfaceView.setVisibility(View.VISIBLE); + frameLayout.requestLayout(); + } + + private void startCamera() { + cameraInput.start( + this, + faceDetection.getGlContext(), + CameraInput.CameraFacing.FRONT, + glSurfaceView.getWidth(), + glSurfaceView.getHeight()); + } + + private void stopCurrentPipeline() { + if (cameraInput != null) { + cameraInput.setNewFrameListener(null); + cameraInput.close(); + } + if (videoInput != null) { + videoInput.setNewFrameListener(null); + videoInput.close(); + } + if (glSurfaceView != null) { + glSurfaceView.setVisibility(View.GONE); + } + if (faceDetection != null) { + faceDetection.close(); + } + } + + private void logNoseTipKeypoint( + FaceDetectionResult result, int faceIndex, boolean showPixelValues) { + RelativeKeypoint noseTip = + FaceDetection.getFaceKeypoint(result, faceIndex, FaceKeypoint.NOSE_TIP); + // For Bitmaps, show the pixel values. For texture inputs, show the normalized coordinates. + if (showPixelValues) { + int width = result.inputBitmap().getWidth(); + int height = result.inputBitmap().getHeight(); + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip coordinates (pixel values): x=%f, y=%f", + noseTip.getX() * width, noseTip.getY() * height)); + } else { + Log.i( + TAG, + String.format( + "MediaPipe Face Detection nose tip normalized coordinates (value range: [0, 1]):" + + " x=%f, y=%f", + noseTip.getX(), noseTip.getY())); + } + } +} diff --git a/mediapipe/examples/android/solutions/facedetection/src/main/res b/mediapipe/examples/android/solutions/facedetection/src/main/res new file mode 120000 index 000000000..fc8850136 --- /dev/null +++ b/mediapipe/examples/android/solutions/facedetection/src/main/res @@ -0,0 +1 @@ +../../../res \ No newline at end of file diff --git a/mediapipe/examples/android/solutions/facemesh/build.gradle b/mediapipe/examples/android/solutions/facemesh/build.gradle index 74aedf095..8e6f39956 100644 --- a/mediapipe/examples/android/solutions/facemesh/build.gradle +++ b/mediapipe/examples/android/solutions/facemesh/build.gradle @@ -31,15 +31,16 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.3.0' implementation 'com.google.android.material:material:1.3.0' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation 'androidx.exifinterface:exifinterface:1.3.3' testImplementation 'junit:junit:4.+' androidTestImplementation 'androidx.test.ext:junit:1.1.2' androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' - // MediaPipe hands solution API and solution-core. + // MediaPipe Face Mesh Solution components. 
implementation 'com.google.mediapipe:solution-core:latest.release' implementation 'com.google.mediapipe:facemesh:latest.release' // MediaPipe deps - implementation 'com.google.flogger:flogger:latest.release' - implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.flogger:flogger:0.6' + implementation 'com.google.flogger:flogger-system-backend:0.6' implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.protobuf:protobuf-java:3.11.4' // CameraX core library diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/BUILD b/mediapipe/examples/android/solutions/facemesh/src/main/BUILD index 591102c3e..515f03b6b 100644 --- a/mediapipe/examples/android/solutions/facemesh/src/main/BUILD +++ b/mediapipe/examples/android/solutions/facemesh/src/main/BUILD @@ -38,6 +38,7 @@ android_binary( "//third_party:opencv", "@maven//:androidx_activity_activity", "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", "@maven//:androidx_fragment_fragment", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java index fd6c533d3..1b7eca9d6 100644 --- a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/FaceMeshResultGlRenderer.java @@ -15,11 +15,10 @@ package com.google.mediapipe.examples.facemesh; import android.opengl.GLES20; -import android.opengl.Matrix; import com.google.common.collect.ImmutableSet; import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.solutioncore.ResultGlBoundary; import com.google.mediapipe.solutioncore.ResultGlRenderer; +import com.google.mediapipe.solutions.facemesh.FaceMesh; import com.google.mediapipe.solutions.facemesh.FaceMeshConnections; import com.google.mediapipe.solutions.facemesh.FaceMeshResult; import java.nio.ByteBuffer; @@ -27,7 +26,7 @@ import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.util.List; -/** A custom implementation of {@link ResultGlRenderer} to render MediaPope FaceMesh results. */ +/** A custom implementation of {@link ResultGlRenderer} to render {@link FaceMeshResult}. */ public class FaceMeshResultGlRenderer implements ResultGlRenderer { private static final String TAG = "FaceMeshResultGlRenderer"; @@ -46,10 +45,10 @@ public class FaceMeshResultGlRenderer implements ResultGlRendererThis is only necessary if one wants to release the program while keeping the context around. 
*/ @@ -159,13 +159,9 @@ public class FaceMeshResultGlRenderer implements ResultGlRenderer faceLandmarkList, diff --git a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java index 27c89a93e..27e10ad1c 100644 --- a/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java +++ b/mediapipe/examples/android/solutions/facemesh/src/main/java/com/google/mediapipe/examples/facemesh/MainActivity.java @@ -16,6 +16,7 @@ package com.google.mediapipe.examples.facemesh; import android.content.Intent; import android.graphics.Bitmap; +import android.graphics.Matrix; import android.os.Bundle; import android.provider.MediaStore; import androidx.appcompat.app.AppCompatActivity; @@ -25,6 +26,8 @@ import android.widget.Button; import android.widget.FrameLayout; import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; import com.google.mediapipe.solutioncore.CameraInput; import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; @@ -33,8 +36,9 @@ import com.google.mediapipe.solutions.facemesh.FaceMesh; import com.google.mediapipe.solutions.facemesh.FaceMeshOptions; import com.google.mediapipe.solutions.facemesh.FaceMeshResult; import java.io.IOException; +import java.io.InputStream; -/** Main activity of MediaPipe FaceMesh app. */ +/** Main activity of MediaPipe Face Mesh app. */ public class MainActivity extends AppCompatActivity { private static final String TAG = "MainActivity"; @@ -57,12 +61,14 @@ public class MainActivity extends AppCompatActivity { private ActivityResultLauncher videoGetter; // Live camera demo UI and camera components. private CameraInput cameraInput; + private SolutionGlSurfaceView glSurfaceView; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); + // TODO: Add a toggle to switch between the original face mesh and attention mesh. 
setupStaticImageDemoUiComponents(); setupVideoDemoUiComponents(); setupLiveDemoUiComponents(); @@ -111,6 +117,35 @@ public class MainActivity extends AppCompatActivity { } catch (IOException e) { Log.e(TAG, "Bitmap reading error:" + e); } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + int orientation = + new ExifInterface(imageData) + .getAttributeInt( + ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation != ExifInterface.ORIENTATION_NORMAL) { + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + bitmap = + Bitmap.createBitmap( + bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true); + } + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } if (bitmap != null) { facemesh.send(bitmap); } @@ -132,26 +167,27 @@ public class MainActivity extends AppCompatActivity { imageView = new FaceMeshResultImageView(this); } - /** The core MediaPipe FaceMesh setup workflow for its static image mode. */ + /** Sets up core workflow for static image mode. */ private void setupStaticImageModePipeline() { this.inputSource = InputSource.IMAGE; - // Initializes a new MediaPipe FaceMesh instance in the static image mode. + // Initializes a new MediaPipe Face Mesh solution instance in the static image mode. facemesh = new FaceMesh( this, FaceMeshOptions.builder() - .setMode(FaceMeshOptions.STATIC_IMAGE_MODE) + .setStaticImageMode(true) + .setRefineLandmarks(true) .setRunOnGpu(RUN_ON_GPU) .build()); - // Connects MediaPipe FaceMesh to the user-defined FaceMeshResultImageView. + // Connects MediaPipe Face Mesh solution to the user-defined FaceMeshResultImageView. facemesh.setResultListener( faceMeshResult -> { logNoseLandmark(faceMeshResult, /*showPixelValues=*/ true); imageView.setFaceMeshResult(faceMeshResult); runOnUiThread(() -> imageView.update()); }); - facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); // Updates the preview layout. FrameLayout frameLayout = findViewById(R.id.preview_display_layout); @@ -207,25 +243,24 @@ public class MainActivity extends AppCompatActivity { }); } - /** The core MediaPipe FaceMesh setup workflow for its streaming mode. */ + /** Sets up core workflow for streaming mode. */ private void setupStreamingModePipeline(InputSource inputSource) { this.inputSource = inputSource; - // Initializes a new MediaPipe FaceMesh instance in the streaming mode. + // Initializes a new MediaPipe Face Mesh solution instance in the streaming mode. facemesh = new FaceMesh( this, FaceMeshOptions.builder() - .setMode(FaceMeshOptions.STREAMING_MODE) + .setStaticImageMode(false) + .setRefineLandmarks(true) .setRunOnGpu(RUN_ON_GPU) .build()); - facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe FaceMesh error:" + message)); + facemesh.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Face Mesh error:" + message)); if (inputSource == InputSource.CAMERA) { - // Initializes a new CameraInput instance and connects it to MediaPipe FaceMesh. 
cameraInput = new CameraInput(this); cameraInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); } else if (inputSource == InputSource.VIDEO) { - // Initializes a new VideoInput instance and connects it to MediaPipe FaceMesh. videoInput = new VideoInput(this); videoInput.setNewFrameListener(textureFrame -> facemesh.send(textureFrame)); } @@ -295,13 +330,13 @@ public class MainActivity extends AppCompatActivity { Log.i( TAG, String.format( - "MediaPipe FaceMesh nose coordinates (pixel values): x=%f, y=%f", + "MediaPipe Face Mesh nose coordinates (pixel values): x=%f, y=%f", noseLandmark.getX() * width, noseLandmark.getY() * height)); } else { Log.i( TAG, String.format( - "MediaPipe FaceMesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", + "MediaPipe Face Mesh nose normalized coordinates (value range: [0, 1]): x=%f, y=%f", noseLandmark.getX(), noseLandmark.getY())); } } diff --git a/mediapipe/examples/android/solutions/hands/build.gradle b/mediapipe/examples/android/solutions/hands/build.gradle index 27629fd5d..c2cfc123c 100644 --- a/mediapipe/examples/android/solutions/hands/build.gradle +++ b/mediapipe/examples/android/solutions/hands/build.gradle @@ -31,15 +31,16 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.3.0' implementation 'com.google.android.material:material:1.3.0' implementation 'androidx.constraintlayout:constraintlayout:2.0.4' + implementation 'androidx.exifinterface:exifinterface:1.3.3' testImplementation 'junit:junit:4.+' androidTestImplementation 'androidx.test.ext:junit:1.1.2' androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0' - // MediaPipe hands solution API and solution-core. + // MediaPipe Hands Solution components. implementation 'com.google.mediapipe:solution-core:latest.release' implementation 'com.google.mediapipe:hands:latest.release' // MediaPipe deps - implementation 'com.google.flogger:flogger:latest.release' - implementation 'com.google.flogger:flogger-system-backend:latest.release' + implementation 'com.google.flogger:flogger:0.6' + implementation 'com.google.flogger:flogger-system-backend:0.6' implementation 'com.google.guava:guava:27.0.1-android' implementation 'com.google.protobuf:protobuf-java:3.11.4' // CameraX core library diff --git a/mediapipe/examples/android/solutions/hands/src/main/BUILD b/mediapipe/examples/android/solutions/hands/src/main/BUILD index 0d71e4a95..d3c304b57 100644 --- a/mediapipe/examples/android/solutions/hands/src/main/BUILD +++ b/mediapipe/examples/android/solutions/hands/src/main/BUILD @@ -38,6 +38,7 @@ android_binary( "//third_party:opencv", "@maven//:androidx_activity_activity", "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", "@maven//:androidx_fragment_fragment", "@maven//:com_google_guava_guava", ], diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java index 720ae5509..7c1884a43 100644 --- a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultGlRenderer.java @@ -15,9 +15,7 @@ package com.google.mediapipe.examples.hands; import android.opengl.GLES20; -import android.opengl.Matrix; import 
com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.solutioncore.ResultGlBoundary; import com.google.mediapipe.solutioncore.ResultGlRenderer; import com.google.mediapipe.solutions.hands.Hands; import com.google.mediapipe.solutions.hands.HandsResult; @@ -26,16 +24,16 @@ import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.util.List; -/** A custom implementation of {@link ResultGlRenderer} to render MediaPope Hands results. */ +/** A custom implementation of {@link ResultGlRenderer} to render {@link HandsResult}. */ public class HandsResultGlRenderer implements ResultGlRenderer { private static final String TAG = "HandsResultGlRenderer"; private static final float CONNECTION_THICKNESS = 20.0f; private static final String VERTEX_SHADER = - "uniform mat4 uTransformMatrix;\n" + "uniform mat4 uProjectionMatrix;\n" + "attribute vec4 vPosition;\n" + "void main() {\n" - + " gl_Position = uTransformMatrix * vPosition;\n" + + " gl_Position = uProjectionMatrix * vPosition;\n" + "}"; private static final String FRAGMENT_SHADER = "precision mediump float;\n" @@ -44,8 +42,7 @@ public class HandsResultGlRenderer implements ResultGlRenderer { + "}"; private int program; private int positionHandle; - private int transformMatrixHandle; - private final float[] transformMatrix = new float[16]; + private int projectionMatrixHandle; private int loadShader(int type, String shaderCode) { int shader = GLES20.glCreateShader(type); @@ -63,27 +60,16 @@ public class HandsResultGlRenderer implements ResultGlRenderer { GLES20.glAttachShader(program, fragmentShader); GLES20.glLinkProgram(program); positionHandle = GLES20.glGetAttribLocation(program, "vPosition"); - transformMatrixHandle = GLES20.glGetUniformLocation(program, "uTransformMatrix"); + projectionMatrixHandle = GLES20.glGetUniformLocation(program, "uProjectionMatrix"); } @Override - public void renderResult(HandsResult result, ResultGlBoundary boundary) { + public void renderResult(HandsResult result, float[] projectionMatrix) { if (result == null) { return; } GLES20.glUseProgram(program); - // Sets the transform matrix to align the result rendering with the scaled output texture. - // Also flips the rendering vertically since OpenGL assumes the coordinate origin is at the - // bottom-left corner, whereas MediaPipe landmark data assumes the coordinate origin is at the - // top-left corner. - Matrix.setIdentityM(transformMatrix, 0); - Matrix.scaleM( - transformMatrix, - 0, - 2 / (boundary.right() - boundary.left()), - -2 / (boundary.top() - boundary.bottom()), - 1.0f); - GLES20.glUniformMatrix4fv(transformMatrixHandle, 1, false, transformMatrix, 0); + GLES20.glUniformMatrix4fv(projectionMatrixHandle, 1, false, projectionMatrix, 0); GLES20.glLineWidth(CONNECTION_THICKNESS); int numHands = result.multiHandLandmarks().size(); @@ -93,7 +79,7 @@ public class HandsResultGlRenderer implements ResultGlRenderer { } /** - * Calls this to delete the shader program. + * Deletes the shader program. * *
<p>
This is only necessary if one wants to release the program while keeping the context around. */ @@ -101,16 +87,11 @@ public class HandsResultGlRenderer implements ResultGlRenderer { GLES20.glDeleteProgram(program); } - // TODO: Better hand landmark and hand connection drawing. private void drawLandmarks(List handLandmarkList) { for (Hands.Connection c : Hands.HAND_CONNECTIONS) { - float[] vertex = new float[4]; NormalizedLandmark start = handLandmarkList.get(c.start()); - vertex[0] = normalizedLandmarkValue(start.getX()); - vertex[1] = normalizedLandmarkValue(start.getY()); NormalizedLandmark end = handLandmarkList.get(c.end()); - vertex[2] = normalizedLandmarkValue(end.getX()); - vertex[3] = normalizedLandmarkValue(end.getY()); + float[] vertex = {start.getX(), start.getY(), end.getX(), end.getY()}; FloatBuffer vertexBuffer = ByteBuffer.allocateDirect(vertex.length * 4) .order(ByteOrder.nativeOrder()) @@ -122,10 +103,4 @@ public class HandsResultGlRenderer implements ResultGlRenderer { GLES20.glDrawArrays(GLES20.GL_LINES, 0, 2); } } - - // Normalizes the value from the landmark value range:[0, 1] to the standard OpenGL coordinate - // value range: [-1, 1]. - private float normalizedLandmarkValue(float value) { - return value * 2 - 1; - } } diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java index d4052d4e9..3d3c2a3c1 100644 --- a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/HandsResultImageView.java @@ -27,7 +27,7 @@ import com.google.mediapipe.solutions.hands.Hands; import com.google.mediapipe.solutions.hands.HandsResult; import java.util.List; -/** An ImageView implementation for displaying MediaPipe Hands results. */ +/** An ImageView implementation for displaying {@link HandsResult}. */ public class HandsResultImageView extends AppCompatImageView { private static final String TAG = "HandsResultImageView"; @@ -66,7 +66,7 @@ public class HandsResultImageView extends AppCompatImageView { } } - /** Updates the image view with the latest hands result. */ + /** Updates the image view with the latest {@link HandsResult}. */ public void update() { postInvalidate(); if (latest != null) { @@ -74,7 +74,6 @@ public class HandsResultImageView extends AppCompatImageView { } } - // TODO: Better hand landmark and hand connection drawing. private void drawLandmarksOnCanvas( List handLandmarkList, Canvas canvas, int width, int height) { // Draw connections. 
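Editor note: the EXIF-orientation handling added to the face detection and face mesh activities above is added again, verbatim, to the hands activity below. A minimal sketch of how the three copies could share one helper; the `BitmapExifUtil` class and `rotateByExif` method are hypothetical names, not part of this change:

```
import android.graphics.Bitmap;
import android.graphics.Matrix;
import androidx.exifinterface.media.ExifInterface;
import java.io.IOException;
import java.io.InputStream;

/** Hypothetical shared helper; the three example activities currently inline this logic. */
final class BitmapExifUtil {
  private BitmapExifUtil() {}

  static Bitmap rotateByExif(Bitmap bitmap, InputStream imageData) throws IOException {
    int orientation =
        new ExifInterface(imageData)
            .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
    if (orientation == ExifInterface.ORIENTATION_NORMAL) {
      return bitmap;
    }
    Matrix matrix = new Matrix();
    switch (orientation) {
      case ExifInterface.ORIENTATION_ROTATE_90:
        matrix.postRotate(90);
        break;
      case ExifInterface.ORIENTATION_ROTATE_180:
        matrix.postRotate(180);
        break;
      case ExifInterface.ORIENTATION_ROTATE_270:
        matrix.postRotate(270);
        break;
      default:
        // Same as the inlined code: other orientations fall through with a 0-degree rotation.
        matrix.postRotate(0);
    }
    return Bitmap.createBitmap(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
  }
}
```

Each MainActivity would then replace the duplicated block with something like `bitmap = BitmapExifUtil.rotateByExif(bitmap, imageData);` inside its existing try/catch.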
diff --git a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java index 379219942..8bcd82744 100644 --- a/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java +++ b/mediapipe/examples/android/solutions/hands/src/main/java/com/google/mediapipe/examples/hands/MainActivity.java @@ -16,6 +16,7 @@ package com.google.mediapipe.examples.hands; import android.content.Intent; import android.graphics.Bitmap; +import android.graphics.Matrix; import android.os.Bundle; import android.provider.MediaStore; import androidx.appcompat.app.AppCompatActivity; @@ -25,6 +26,8 @@ import android.widget.Button; import android.widget.FrameLayout; import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; import com.google.mediapipe.solutioncore.CameraInput; import com.google.mediapipe.solutioncore.SolutionGlSurfaceView; @@ -34,6 +37,7 @@ import com.google.mediapipe.solutions.hands.Hands; import com.google.mediapipe.solutions.hands.HandsOptions; import com.google.mediapipe.solutions.hands.HandsResult; import java.io.IOException; +import java.io.InputStream; /** Main activity of MediaPipe Hands app. */ public class MainActivity extends AppCompatActivity { @@ -59,6 +63,7 @@ public class MainActivity extends AppCompatActivity { private ActivityResultLauncher videoGetter; // Live camera demo UI and camera components. private CameraInput cameraInput; + private SolutionGlSurfaceView glSurfaceView; @Override @@ -113,6 +118,35 @@ public class MainActivity extends AppCompatActivity { } catch (IOException e) { Log.e(TAG, "Bitmap reading error:" + e); } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + int orientation = + new ExifInterface(imageData) + .getAttributeInt( + ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + if (orientation != ExifInterface.ORIENTATION_NORMAL) { + Matrix matrix = new Matrix(); + switch (orientation) { + case ExifInterface.ORIENTATION_ROTATE_90: + matrix.postRotate(90); + break; + case ExifInterface.ORIENTATION_ROTATE_180: + matrix.postRotate(180); + break; + case ExifInterface.ORIENTATION_ROTATE_270: + matrix.postRotate(270); + break; + default: + matrix.postRotate(0); + } + bitmap = + Bitmap.createBitmap( + bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true); + } + } catch (IOException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } if (bitmap != null) { hands.send(bitmap); } @@ -134,20 +168,20 @@ public class MainActivity extends AppCompatActivity { imageView = new HandsResultImageView(this); } - /** The core MediaPipe Hands setup workflow for its static image mode. */ + /** Sets up core workflow for static image mode. */ private void setupStaticImageModePipeline() { this.inputSource = InputSource.IMAGE; - // Initializes a new MediaPipe Hands instance in the static image mode. + // Initializes a new MediaPipe Hands solution instance in the static image mode. 
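+    // setStaticImageMode(true) makes the solution treat every input as an independent image
+    // (detection runs on each image) rather than as frames of a video stream.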
hands = new Hands( this, HandsOptions.builder() - .setMode(HandsOptions.STATIC_IMAGE_MODE) + .setStaticImageMode(true) .setMaxNumHands(1) .setRunOnGpu(RUN_ON_GPU) .build()); - // Connects MediaPipe Hands to the user-defined HandsResultImageView. + // Connects MediaPipe Hands solution to the user-defined HandsResultImageView. hands.setResultListener( handsResult -> { logWristLandmark(handsResult, /*showPixelValues=*/ true); @@ -210,26 +244,24 @@ public class MainActivity extends AppCompatActivity { }); } - /** The core MediaPipe Hands setup workflow for its streaming mode. */ + /** Sets up core workflow for streaming mode. */ private void setupStreamingModePipeline(InputSource inputSource) { this.inputSource = inputSource; - // Initializes a new MediaPipe Hands instance in the streaming mode. + // Initializes a new MediaPipe Hands solution instance in the streaming mode. hands = new Hands( this, HandsOptions.builder() - .setMode(HandsOptions.STREAMING_MODE) + .setStaticImageMode(false) .setMaxNumHands(1) .setRunOnGpu(RUN_ON_GPU) .build()); hands.setErrorListener((message, e) -> Log.e(TAG, "MediaPipe Hands error:" + message)); if (inputSource == InputSource.CAMERA) { - // Initializes a new CameraInput instance and connects it to MediaPipe Hands. cameraInput = new CameraInput(this); cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame)); } else if (inputSource == InputSource.VIDEO) { - // Initializes a new VideoInput instance and connects it to MediaPipe Hands. videoInput = new VideoInput(this); videoInput.setNewFrameListener(textureFrame -> hands.send(textureFrame)); } diff --git a/mediapipe/examples/android/solutions/settings.gradle b/mediapipe/examples/android/solutions/settings.gradle index adc81ab91..c050ba4bf 100644 --- a/mediapipe/examples/android/solutions/settings.gradle +++ b/mediapipe/examples/android/solutions/settings.gradle @@ -1,3 +1,4 @@ rootProject.name = "mediapipe-solutions-examples" -include ':hands' +include ':facedetection' include ':facemesh' +include ':hands' diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD index 378132c15..edef0b860 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facemeshgpu/BUILD @@ -37,7 +37,7 @@ android_binary( srcs = glob(["*.java"]), assets = [ "//mediapipe/graphs/face_mesh:face_mesh_mobile_gpu.binarypb", - "//mediapipe/modules/face_landmark:face_landmark.tflite", + "//mediapipe/modules/face_landmark:face_landmark_with_attention.tflite", "//mediapipe/modules/face_detection:face_detection_short_range.tflite", ], assets_dir = "", diff --git a/mediapipe/examples/ios/facemeshgpu/BUILD b/mediapipe/examples/ios/facemeshgpu/BUILD index a8d5ef668..d10d531ca 100644 --- a/mediapipe/examples/ios/facemeshgpu/BUILD +++ b/mediapipe/examples/ios/facemeshgpu/BUILD @@ -63,7 +63,7 @@ objc_library( data = [ "//mediapipe/graphs/face_mesh:face_mesh_mobile_gpu.binarypb", "//mediapipe/modules/face_detection:face_detection_short_range.tflite", - "//mediapipe/modules/face_landmark:face_landmark.tflite", + "//mediapipe/modules/face_landmark:face_landmark_with_attention.tflite", ], deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", diff --git a/mediapipe/framework/formats/image_multi_pool.cc b/mediapipe/framework/formats/image_multi_pool.cc index b79c81db8..655064d36 100644 --- 
a/mediapipe/framework/formats/image_multi_pool.cc +++ b/mediapipe/framework/formats/image_multi_pool.cc @@ -23,6 +23,7 @@ #if !MEDIAPIPE_DISABLE_GPU #ifdef __APPLE__ #include "mediapipe/objc/CFHolder.h" +#include "mediapipe/objc/util.h" #endif // __APPLE__ #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 86c9ff30d..324948778 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -89,6 +89,18 @@ cc_library( ], ) +cc_library( + name = "commandlineflags", + hdrs = [ + "commandlineflags.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//third_party:glog", + "@com_google_absl//absl/flags:flag", + ], +) + cc_library( name = "core_proto", hdrs = [ diff --git a/mediapipe/framework/port/commandlineflags.h b/mediapipe/framework/port/commandlineflags.h new file mode 100644 index 000000000..a3d17c71e --- /dev/null +++ b/mediapipe/framework/port/commandlineflags.h @@ -0,0 +1,30 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ +#define MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ + +#include "gflags/gflags.h" +namespace absl { +template +T GetFlag(const T& f) { + return f; +} +template +void SetFlag(T* f, const U& u) { + *f = u; +} +} // namespace absl + +#endif // MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index db92cfd38..479d2a184 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -202,6 +202,7 @@ cc_library( "//mediapipe/framework:packet_type", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -291,7 +292,9 @@ mediapipe_cc_test( data = [":node_chain_subgraph.proto"], requires_full_emulation = False, deps = [ + ":options_field_util", ":options_registry", + ":options_syntax_util", ":options_util", "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", @@ -305,8 +308,8 @@ mediapipe_cc_test( "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib", - "//mediapipe/framework/tool:options_syntax_util", "//mediapipe/util:header_util", + "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc index 043015b2a..0fdbc47ab 100644 --- a/mediapipe/framework/tool/options_field_util.cc +++ b/mediapipe/framework/tool/options_field_util.cc @@ -8,11 +8,13 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" #include 
"mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/any_proto.h" #include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/tool/name_util.h" #include "mediapipe/framework/tool/proto_util_lite.h" @@ -31,6 +33,9 @@ using ::mediapipe::proto_ns::io::StringOutputStream; // Utility functions for OptionsFieldUtil. namespace { +// The type name for the proto3 "Any" type. +constexpr absl::string_view kGoogleProtobufAny = "google.protobuf.Any"; + // Converts a FieldDescriptor::Type to the corresponding FieldType. FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) { return static_cast(type); @@ -81,7 +86,7 @@ absl::Status WriteValue(const FieldData& value, FieldType field_type, return absl::UnimplementedError( absl::StrCat("Cannot write type: ", field_type)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Serializes a packet value. @@ -167,6 +172,7 @@ absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, // Deserializes a packet from a protobuf field. absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field, FieldData* result) { + RET_CHECK_NE(field, nullptr); FieldType field_type = AsFieldType(field->type()); std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE) ? field->message_type()->full_name() @@ -174,47 +180,137 @@ absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field, return ReadValue(bytes, field_type, message_type, result); } -// Converts a chain of fields and indexes into field-numbers and indexes. -ProtoUtilLite::ProtoPath AsProtoPath(const FieldPath& field_path) { - ProtoUtilLite::ProtoPath result; - for (auto field : field_path) { - result.push_back({field.first->number(), field.second}); +// Reads all values from a repeated field. +absl::Status GetFieldValues(const FieldData& message_data, + const FieldDescriptor& field, + std::vector* result) { + const std::string& message_bytes = message_data.message_value().value(); + FieldType field_type = AsFieldType(field.type()); + ProtoUtilLite proto_util; + ProtoUtilLite::ProtoPath proto_path = {{field.number(), 0}}; + int count; + MP_RETURN_IF_ERROR( + proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count)); + std::vector field_values; + MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, count, + field_type, &field_values)); + for (int i = 0; i < count; ++i) { + FieldData r; + MP_RETURN_IF_ERROR(ReadField(field_values[i], &field, &r)); + result->push_back(std::move(r)); } + return absl::OkStatus(); +} + +// Reads one value from a field. +absl::Status GetFieldValue(const FieldData& message_data, + const FieldPathEntry& entry, FieldData* result) { + RET_CHECK_NE(entry.field, nullptr); + const std::string& message_bytes = message_data.message_value().value(); + FieldType field_type = AsFieldType(entry.field->type()); + ProtoUtilLite proto_util; + ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}}; + std::vector field_values; + MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1, + field_type, &field_values)); + MP_RETURN_IF_ERROR(ReadField(field_values[0], entry.field, result)); + return absl::OkStatus(); +} + +// Writes one value to a field. 
+absl::Status SetFieldValue(const FieldPathEntry& entry, const FieldData& value, + FieldData* result) { + std::vector field_values; + ProtoUtilLite proto_util; + FieldType field_type = AsFieldType(entry.field->type()); + ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}}; + std::string* message_bytes = result->mutable_message_value()->mutable_value(); + int field_count; + MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path, + field_type, &field_count)); + if (entry.index > field_count) { + return absl::OutOfRangeError( + absl::StrCat("Option field index out of range: ", entry.index)); + } + int replace_length = entry.index < field_count ? 1 : 0; + std::string field_value; + MP_RETURN_IF_ERROR(WriteField(value, entry.field, &field_value)); + MP_RETURN_IF_ERROR(proto_util.ReplaceFieldRange( + message_bytes, proto_path, replace_length, field_type, {field_value})); + return absl::OkStatus(); +} + +// Returns true for a field of type "google.protobuf.Any". +bool IsProtobufAny(const FieldDescriptor* field) { + return AsFieldType(field->type()) == FieldType::TYPE_MESSAGE && + field->message_type()->full_name() == kGoogleProtobufAny; +} + +// Returns the message FieldData from a serialized protobuf.Any. +FieldData ParseProtobufAny(const FieldData& data) { + protobuf::Any any; + any.ParseFromString(data.message_value().value()); + FieldData result; + result.mutable_message_value()->set_value(std::string(any.value())); + result.mutable_message_value()->set_type_url(any.type_url()); return result; } -// Returns the options protobuf for a subgraph. -// TODO: Ensure that this works with multiple options protobufs. -absl::Status GetOptionsMessage( - const proto_ns::RepeatedPtrField& options_any, - const proto_ns::MessageLite& options_ext, FieldData* result) { - // Read the "graph_options" or "node_options" field. - for (const auto& options : options_any) { - if (options.type_url().empty()) { - continue; - } - result->mutable_message_value()->set_type_url(options.type_url()); - result->mutable_message_value()->set_value(std::string(options.value())); - return mediapipe::OkStatus(); - } +// Returns the serialized protobuf.Any containing a message FieldData. +FieldData SerializeProtobufAny(const FieldData& data) { + protobuf::Any any; + any.set_value(data.message_value().value()); + any.set_type_url(data.message_value().type_url()); + FieldData result; + result.mutable_message_value()->set_value(any.SerializeAsString()); + result.mutable_message_value()->set_type_url(TypeUrl(kGoogleProtobufAny)); + return result; +} - // Read the "options" field. - FieldData message_data; - *message_data.mutable_message_value()->mutable_value() = - options_ext.SerializeAsString(); - message_data.mutable_message_value()->set_type_url(options_ext.GetTypeName()); - std::vector ext_fields; - OptionsRegistry::FindAllExtensions(options_ext.GetTypeName(), &ext_fields); - for (auto ext_field : ext_fields) { - absl::Status status = GetField({{ext_field, 0}}, message_data, result); - if (!status.ok()) { - return status; - } - if (result->has_message_value()) { - return status; +// Returns the field index of an extension type in a repeated field. 
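ParseProtobufAny() and SerializeProtobufAny() above unwrap and re-wrap the google.protobuf.Any wire format by hand, because this code only ever sees serialized FieldData; FindExtensionIndex(), next, uses them to scan a repeated Any field for a given type. With typed messages the same round trip is just PackFrom()/UnpackTo(), as in this minimal sketch (the options type is the one used by options_util_test.cc):

```
#include "google/protobuf/any.pb.h"
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"

void AnyRoundTrip() {
  mediapipe::NightLightCalculatorOptions options;
  options.mutable_sub_options()->add_num_lights(33);

  // Pack: stores the serialized message plus a type_url such as
  // "type.googleapis.com/mediapipe.NightLightCalculatorOptions".
  google::protobuf::Any any;
  any.PackFrom(options);

  // Unpack: check the type_url before deserializing.
  mediapipe::NightLightCalculatorOptions restored;
  if (any.Is<mediapipe::NightLightCalculatorOptions>()) {
    any.UnpackTo(&restored);
  }
}
```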
+StatusOr FindExtensionIndex(const FieldData& message_data, + FieldPathEntry* entry) { + if (entry->field == nullptr || !IsProtobufAny(entry->field)) { + return -1; + } + std::string& extension_type = entry->extension_type; + std::vector field_values; + RET_CHECK_NE(entry->field, nullptr); + MP_RETURN_IF_ERROR( + GetFieldValues(message_data, *entry->field, &field_values)); + for (int i = 0; i < field_values.size(); ++i) { + FieldData extension = ParseProtobufAny(field_values[i]); + if (extension_type == "*" || + ParseTypeUrl(extension.message_value().type_url()) == extension_type) { + return i; } } - return mediapipe::OkStatus(); + return -1; +} + +// Returns true if the value of a field is available. +bool HasField(const FieldPath& field_path, const FieldData& message_data) { + FieldData value; + return GetField(field_path, message_data, &value).ok() && + value.value_case() != mediapipe::FieldData::VALUE_NOT_SET; +} + +// Returns the extension field containing the specified extension-type. +const FieldDescriptor* FindExtensionField(const FieldData& message_data, + absl::string_view extension_type) { + std::string message_type = + ParseTypeUrl(message_data.message_value().type_url()); + std::vector extensions; + OptionsRegistry::FindAllExtensions(message_type, &extensions); + for (const FieldDescriptor* extension : extensions) { + if (extension->message_type()->full_name() == extension_type) { + return extension; + } + if (extension_type == "*" && HasField({{extension, 0}}, message_data)) { + return extension; + } + } + return nullptr; } // Sets a protobuf in a repeated protobuf::Any field. @@ -234,6 +330,20 @@ void SetOptionsMessage( *options_any->mutable_value() = node_options.message_value().value(); } +// Returns the count of values in a repeated field. +int FieldCount(const FieldData& message_data, const FieldDescriptor* field) { + const std::string& message_bytes = message_data.message_value().value(); + FieldType field_type = AsFieldType(field->type()); + ProtoUtilLite proto_util; + ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}}; + int count; + if (proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count) + .ok()) { + return count; + } + return 0; +} + } // anonymous namespace // Deserializes a packet containing a MessageLite value. @@ -247,8 +357,8 @@ absl::Status ReadMessage(const std::string& value, const std::string& type_name, } // Merge two options FieldData values. -absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over, - FieldData* result) { +absl::Status MergeMessages(const FieldData& base, const FieldData& over, + FieldData* result) { absl::Status status; if (over.value_case() == FieldData::VALUE_NOT_SET) { *result = base; @@ -278,28 +388,148 @@ absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over, return status; } +// Returns either the extension field or the repeated protobuf.Any field index +// holding the specified extension-type. +absl::Status FindExtension(const FieldData& message_data, + FieldPathEntry* entry) { + if (entry->extension_type.empty()) { + return absl::OkStatus(); + } + + // For repeated protobuf::Any, find the index for the extension_type. + ASSIGN_OR_RETURN(int index, FindExtensionIndex(message_data, entry)); + if (index != -1) { + entry->index = index; + return absl::OkStatus(); + } + + // Returns the extension field containing the specified extension-type. 
+ std::string& extension_type = entry->extension_type; + const FieldDescriptor* field = + FindExtensionField(message_data, extension_type); + if (field != nullptr) { + entry->field = field; + entry->index = 0; + return absl::OkStatus(); + } + return absl::NotFoundError( + absl::StrCat("Option extension not found: ", extension_type)); +} + +// Return the FieldPath referencing an extension message. +FieldPath GetExtensionPath(const std::string& parent_type, + const std::string& extension_type, + const std::string& field_name, + bool is_protobuf_any) { + FieldPath result; + const tool::Descriptor* parent_descriptor = + tool::OptionsRegistry::GetProtobufDescriptor(parent_type); + FieldPathEntry field_entry; + field_entry.field = parent_descriptor->FindFieldByName(field_name); + if (is_protobuf_any) { + field_entry.extension_type = extension_type; + result = {std::move(field_entry)}; + } else { + field_entry.index = 0; + FieldPathEntry extension_entry; + extension_entry.extension_type = extension_type; + result = {std::move(field_entry), std::move(extension_entry)}; + } + return result; +} + +// Returns the requested options protobuf for a graph node. +absl::Status GetNodeOptions(const FieldData& message_data, + const std::string& extension_type, + FieldData* result) { + constexpr char kOptionsName[] = "options"; + constexpr char kNodeOptionsName[] = "node_options"; + std::string parent_type = options_field_util::ParseTypeUrl( + std::string(message_data.message_value().type_url())); + FieldPath path; + Status status; + path = GetExtensionPath(parent_type, extension_type, kOptionsName, false); + status = GetField(path, message_data, result); + if (status.ok()) { + return status; + } + path = GetExtensionPath(parent_type, extension_type, kNodeOptionsName, true); + status = GetField(path, message_data, result); + return status; +} + +// Returns the requested options protobuf for a graph. +absl::Status GetGraphOptions(const FieldData& message_data, + const std::string& extension_type, + FieldData* result) { + constexpr char kOptionsName[] = "options"; + constexpr char kGraphOptionsName[] = "graph_options"; + std::string parent_type = options_field_util::ParseTypeUrl( + std::string(message_data.message_value().type_url())); + FieldPath path; + Status status; + path = GetExtensionPath(parent_type, extension_type, kOptionsName, false); + status = GetField(path, message_data, result); + if (status.ok()) { + return status; + } + path = GetExtensionPath(parent_type, extension_type, kGraphOptionsName, true); + status = GetField(path, message_data, result); + return status; +} + +// Reads a FieldData value from a protobuf field. +absl::Status GetField(const FieldPath& field_path, + const FieldData& message_data, FieldData* result) { + if (field_path.empty()) { + *result->mutable_message_value() = message_data.message_value(); + return absl::OkStatus(); + } + FieldPathEntry head = field_path.front(); + FieldPath tail = field_path; + tail.erase(tail.begin()); + if (!head.extension_type.empty()) { + MP_RETURN_IF_ERROR(FindExtension(message_data, &head)); + } + if (tail.empty() && FieldCount(message_data, head.field) == 0) { + return absl::OkStatus(); + } + MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, result)); + if (IsProtobufAny(head.field)) { + *result = ParseProtobufAny(*result); + } + if (!tail.empty()) { + FieldData child = *result; + MP_RETURN_IF_ERROR(GetField(tail, child, result)); + } + return absl::OkStatus(); +} + // Writes a FieldData value into protobuf field. 
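GetNodeOptions() and GetGraphOptions() above exist because options can be stored in either of two places: the proto2 extension field `options`, or the repeated `google.protobuf.Any` field `node_options` (`graph_options` for graphs). Each helper tries the `options` path first and falls back to the Any-based path; SetField(), which follows, writes values back along the same field paths. With typed protos the dual lookup looks roughly like the sketch below (illustrative only; the real code operates on serialized FieldData, and the options type is just the test message):

```
#include "google/protobuf/any.pb.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"

// Finds NightLightCalculatorOptions on a node, whether stored in the proto2
// `options` extension or packed into the `node_options` Any field.
// Sketch only; error handling omitted.
mediapipe::NightLightCalculatorOptions GetNightLightOptions(
    const mediapipe::CalculatorGraphConfig::Node& node) {
  mediapipe::NightLightCalculatorOptions result;
  if (node.options().HasExtension(
          mediapipe::NightLightCalculatorOptions::ext)) {
    return node.options().GetExtension(
        mediapipe::NightLightCalculatorOptions::ext);
  }
  for (const google::protobuf::Any& any : node.node_options()) {
    if (any.Is<mediapipe::NightLightCalculatorOptions>()) {
      any.UnpackTo(&result);
      break;
    }
  }
  return result;
}
```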
absl::Status SetField(const FieldPath& field_path, const FieldData& value, FieldData* message_data) { if (field_path.empty()) { *message_data->mutable_message_value() = value.message_value(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - ProtoUtilLite proto_util; - const FieldDescriptor* field = field_path.back().first; - FieldType field_type = AsFieldType(field->type()); - std::string field_value; - MP_RETURN_IF_ERROR(WriteField(value, field, &field_value)); - ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path); - std::string* message_bytes = - message_data->mutable_message_value()->mutable_value(); - int field_count; - MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path, - field_type, &field_count)); - MP_RETURN_IF_ERROR( - proto_util.ReplaceFieldRange(message_bytes, AsProtoPath(field_path), - field_count, field_type, {field_value})); - return mediapipe::OkStatus(); + FieldPathEntry head = field_path.front(); + FieldPath tail = field_path; + tail.erase(tail.begin()); + if (!head.extension_type.empty()) { + MP_RETURN_IF_ERROR(FindExtension(*message_data, &head)); + } + if (tail.empty()) { + MP_RETURN_IF_ERROR(SetFieldValue(head, value, message_data)); + } else { + FieldData child; + MP_RETURN_IF_ERROR(GetFieldValue(*message_data, head, &child)); + MP_RETURN_IF_ERROR(SetField(tail, value, &child)); + if (IsProtobufAny(head.field)) { + child = SerializeProtobufAny(child); + } + MP_RETURN_IF_ERROR(SetFieldValue(head, child, message_data)); + } + return absl::OkStatus(); } // Merges a packet value into nested protobuf Message. @@ -308,7 +538,7 @@ absl::Status MergeField(const FieldPath& field_path, const FieldData& value, absl::Status status; FieldType field_type = field_path.empty() ? FieldType::TYPE_MESSAGE - : AsFieldType(field_path.back().first->type()); + : AsFieldType(field_path.back().field->type()); std::string message_type = (value.has_message_value()) ? ParseTypeUrl(std::string(value.message_value().type_url())) @@ -317,49 +547,12 @@ absl::Status MergeField(const FieldPath& field_path, const FieldData& value, if (field_type == FieldType::TYPE_MESSAGE) { FieldData b; status.Update(GetField(field_path, *message_data, &b)); - status.Update(MergeOptionsMessages(b, v, &v)); + status.Update(MergeMessages(b, v, &v)); } status.Update(SetField(field_path, v, message_data)); return status; } -// Reads a packet value from a protobuf field. -absl::Status GetField(const FieldPath& field_path, - const FieldData& message_data, FieldData* result) { - if (field_path.empty()) { - *result->mutable_message_value() = message_data.message_value(); - return mediapipe::OkStatus(); - } - ProtoUtilLite proto_util; - const FieldDescriptor* field = field_path.back().first; - FieldType field_type = AsFieldType(field->type()); - std::vector field_values; - ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path); - const std::string& message_bytes = message_data.message_value().value(); - int field_count; - MP_RETURN_IF_ERROR(proto_util.GetFieldCount(message_bytes, proto_path, - field_type, &field_count)); - if (field_count == 0) { - return mediapipe::OkStatus(); - } - MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1, - field_type, &field_values)); - MP_RETURN_IF_ERROR(ReadField(field_values.front(), field, result)); - return mediapipe::OkStatus(); -} - -// Returns the options protobuf for a graph. 
-absl::Status GetOptionsMessage(const CalculatorGraphConfig& config, - FieldData* result) { - return GetOptionsMessage(config.graph_options(), config.options(), result); -} - -// Returns the options protobuf for a node. -absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node, - FieldData* result) { - return GetOptionsMessage(node.node_options(), node.options(), result); -} - // Sets the node_options field in a Node, and clears the options field. void SetOptionsMessage(const FieldData& node_options, CalculatorGraphConfig::Node* node) { @@ -367,6 +560,16 @@ void SetOptionsMessage(const FieldData& node_options, node->clear_options(); } +// Serialize a MessageLite to a FieldData. +FieldData AsFieldData(const proto_ns::MessageLite& message) { + FieldData result; + *result.mutable_message_value()->mutable_value() = + message.SerializePartialAsString(); + *result.mutable_message_value()->mutable_type_url() = + TypeUrl(message.GetTypeName()); + return result; +} + // Represents a protobuf enum value stored in a Packet. struct ProtoEnum { ProtoEnum(int32 v) : value(v) {} @@ -415,7 +618,7 @@ absl::Status AsPacket(const FieldData& data, Packet* result) { case FieldData::VALUE_NOT_SET: *result = Packet(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } absl::Status AsFieldData(Packet packet, FieldData* result) { @@ -436,7 +639,7 @@ absl::Status AsFieldData(Packet packet, FieldData* result) { packet.GetProtoMessageLite().SerializeAsString()); result->mutable_message_value()->set_type_url( TypeUrl(packet.GetProtoMessageLite().GetTypeName())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (kTypeIds->count(packet.GetTypeId()) == 0) { @@ -473,7 +676,7 @@ absl::Status AsFieldData(Packet packet, FieldData* result) { result->set_string_value(packet.Get()); break; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::string TypeUrl(absl::string_view type_name) { diff --git a/mediapipe/framework/tool/options_field_util.h b/mediapipe/framework/tool/options_field_util.h index 2dda09ca3..f3c82e95d 100644 --- a/mediapipe/framework/tool/options_field_util.h +++ b/mediapipe/framework/tool/options_field_util.h @@ -19,8 +19,15 @@ namespace tool { // Utility to read and write Packet data from protobuf fields. namespace options_field_util { -// A chain of nested fields and indexes. -using FieldPath = std::vector>; +// A protobuf field and index description. +struct FieldPathEntry { + const FieldDescriptor* field = nullptr; + int index = -1; + std::string extension_type; +}; + +// A chain of nested protobuf fields and indexes. +using FieldPath = std::vector; // Writes a field value into protobuf field. absl::Status SetField(const FieldPath& field_path, const FieldData& value, @@ -39,21 +46,26 @@ absl::Status ReadMessage(const std::string& value, const std::string& type_name, Packet* result); // Merge two options protobuf field values. -absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over, - FieldData* result); +absl::Status MergeMessages(const FieldData& base, const FieldData& over, + FieldData* result); -// Returns the options protobuf for a graph. -absl::Status GetOptionsMessage(const CalculatorGraphConfig& config, - FieldData* result); +// Returns the requested options protobuf for a graph. +absl::Status GetNodeOptions(const FieldData& message_data, + const std::string& extension_type, + FieldData* result); -// Returns the options protobuf for a node. 
-absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node, - FieldData* result); +// Returns the requested options protobuf for a graph node. +absl::Status GetGraphOptions(const FieldData& message_data, + const std::string& extension_type, + FieldData* result); // Sets the node_options field in a Node, and clears the options field. void SetOptionsMessage(const FieldData& node_options, CalculatorGraphConfig::Node* node); +// Serialize a MessageLite to a FieldData. +FieldData AsFieldData(const proto_ns::MessageLite& message); + // Constructs a Packet for a FieldData proto. absl::Status AsPacket(const FieldData& data, Packet* result); diff --git a/mediapipe/framework/tool/options_syntax_util.cc b/mediapipe/framework/tool/options_syntax_util.cc index 0112189fb..e51b0ac59 100644 --- a/mediapipe/framework/tool/options_syntax_util.cc +++ b/mediapipe/framework/tool/options_syntax_util.cc @@ -5,17 +5,42 @@ #include #include +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/any_proto.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/tool/name_util.h" +#include "mediapipe/framework/tool/options_registry.h" namespace mediapipe { namespace tool { +namespace { + +// StrSplit Delimiter to split strings at single colon tokens, ignoring +// double-colon tokens. +class SingleColonDelimiter { + public: + SingleColonDelimiter() {} + absl::string_view Find(absl::string_view text, size_t pos) const { + while (pos < text.length()) { + size_t p = text.find(':', pos); + p = (p == absl::string_view::npos) ? text.length() : p; + if (p >= text.length() - 1 || text[p + 1] != ':') { + return text.substr(p, 1); + } + pos = p + 2; + } + return text.substr(text.length(), 0); + } +}; + +} // namespace + // Helper functions for parsing the graph options syntax. class OptionsSyntaxUtil::OptionsSyntaxHelper { public: @@ -31,13 +56,32 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper { // Returns the option protobuf field name for a tag or packet name. absl::string_view OptionFieldName(absl::string_view name) { return name; } + // Return the extension-type specified for an option field. + absl::string_view ExtensionType(absl::string_view option_name) { + constexpr absl::string_view kExt = "Ext::"; + if (absl::StartsWithIgnoreCase(option_name, kExt)) { + return option_name.substr(kExt.size()); + } + return ""; + } + + // Returns the field names encoded in an options tag. + std::vector OptionTagNames(absl::string_view tag) { + if (absl::StartsWith(tag, syntax_.tag_name)) { + tag = tag.substr(syntax_.tag_name.length()); + } else if (absl::StartsWith(tag, syntax_.packet_name)) { + tag = tag.substr(syntax_.packet_name.length()); + } + if (absl::StartsWith(tag, syntax_.separator)) { + tag = tag.substr(syntax_.separator.length()); + } + return absl::StrSplit(tag, syntax_.separator); + } + // Returns the field-path for an option stream-tag. 
- FieldPath OptionFieldPath(const std::string& tag, + FieldPath OptionFieldPath(absl::string_view tag, const Descriptor* descriptor) { - int prefix = syntax_.tag_name.length() + syntax_.separator.length(); - std::string suffix = tag.substr(prefix); - std::vector name_tags = - absl::StrSplit(suffix, syntax_.separator); + std::vector name_tags = OptionTagNames(tag); FieldPath result; for (absl::string_view name_tag : name_tags) { if (name_tag.empty()) { @@ -46,8 +90,16 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper { absl::string_view option_name = OptionFieldName(name_tag); int index; if (absl::SimpleAtoi(option_name, &index)) { - result.back().second = index; + result.back().index = index; + } + if (!ExtensionType(option_name).empty()) { + std::string extension_type = std::string(ExtensionType(option_name)); + result.push_back({nullptr, 0, extension_type}); + descriptor = OptionsRegistry::GetProtobufDescriptor(extension_type); } else { + if (descriptor == nullptr) { + break; + } auto field = descriptor->FindFieldByName(std::string(option_name)); descriptor = field ? field->message_type() : nullptr; result.push_back({std::move(field), 0}); @@ -78,7 +130,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper { } // Converts slash-separated field names into a tag name. - std::string OptionFieldsTag(const std::string& option_names) { + std::string OptionFieldsTag(absl::string_view option_names) { std::string tag_prefix = syntax_.tag_name + syntax_.separator; std::vector names = absl::StrSplit(option_names, '/'); if (!names.empty() && names[0] == syntax_.tag_name) { @@ -129,15 +181,18 @@ OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name, OptionsSyntaxUtil::~OptionsSyntaxUtil() {} -std::string OptionsSyntaxUtil::OptionFieldsTag( - const std::string& option_names) { +std::string OptionsSyntaxUtil::OptionFieldsTag(absl::string_view option_names) { return syntax_helper_->OptionFieldsTag(option_names); } OptionsSyntaxUtil::FieldPath OptionsSyntaxUtil::OptionFieldPath( - const std::string& tag, const Descriptor* descriptor) { + absl::string_view tag, const Descriptor* descriptor) { return syntax_helper_->OptionFieldPath(tag, descriptor); } +std::vector OptionsSyntaxUtil::StrSplitTags( + absl::string_view tag_and_name) { + return absl::StrSplit(tag_and_name, SingleColonDelimiter()); +} } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_syntax_util.h b/mediapipe/framework/tool/options_syntax_util.h index a75341b97..09c3fbb78 100644 --- a/mediapipe/framework/tool/options_syntax_util.h +++ b/mediapipe/framework/tool/options_syntax_util.h @@ -28,12 +28,15 @@ class OptionsSyntaxUtil { ~OptionsSyntaxUtil(); // Converts slash-separated field names into a tag name. - std::string OptionFieldsTag(const std::string& option_names); + std::string OptionFieldsTag(absl::string_view option_names); // Returns the field-path for an option stream-tag. - FieldPath OptionFieldPath(const std::string& tag, + FieldPath OptionFieldPath(absl::string_view tag, const Descriptor* descriptor); + // Splits a std::string into "tag" and "name" delimited by a single colon. 
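The tag/name pair in an `option_value` entry is separated by a single colon, while extension-qualified names such as `Ext::some.package.SomeOptions` (type name illustrative) contain double colons that must not split. That is what SingleColonDelimiter in options_syntax_util.cc handles, and what the StrSplitTags() declaration below exposes. A stand-alone illustration of the same splitting rule (not the MediaPipe class itself):

```
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// Splits on ':' but skips '::', mirroring SingleColonDelimiter.
struct SingleColon {
  absl::string_view Find(absl::string_view text, size_t pos) const {
    while (pos < text.length()) {
      size_t p = text.find(':', pos);
      if (p == absl::string_view::npos) break;
      // A lone ':' (not followed by another ':') is a separator.
      if (p + 1 >= text.length() || text[p + 1] != ':') {
        return text.substr(p, 1);
      }
      pos = p + 2;  // Skip over the "::" token.
    }
    return text.substr(text.length(), 0);  // No separator found.
  }
};

std::vector<std::string> SplitTags(absl::string_view tag_and_name) {
  return absl::StrSplit(tag_and_name, SingleColon());
}

// SplitTags("Ext::a/graph/option:Ext::a/node/option")
//   -> {"Ext::a/graph/option", "Ext::a/node/option"}
```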
+ std::vector StrSplitTags(absl::string_view tag_and_name); + private: class OptionsSyntaxHelper; std::unique_ptr syntax_helper_; diff --git a/mediapipe/framework/tool/options_util.cc b/mediapipe/framework/tool/options_util.cc index 5d7c64b75..05bb15415 100644 --- a/mediapipe/framework/tool/options_util.cc +++ b/mediapipe/framework/tool/options_util.cc @@ -7,6 +7,7 @@ #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/input_stream_shard.h" #include "mediapipe/framework/output_side_packet.h" @@ -24,50 +25,77 @@ namespace mediapipe { namespace tool { +using options_field_util::FieldPath; +using options_field_util::GetField; +using options_field_util::GetGraphOptions; +using options_field_util::GetNodeOptions; +using options_field_util::MergeField; +using options_field_util::MergeMessages; + +// Returns the type for the root options message if specified. +std::string ExtensionType(const std::string& option_fields_tag) { + OptionsSyntaxUtil syntax_util; + options_field_util::FieldPath field_path = + syntax_util.OptionFieldPath(option_fields_tag, nullptr); + std::string result = !field_path.empty() ? field_path[0].extension_type : ""; + return !result.empty() ? result : "*"; +} + +// Constructs a FieldPath for field names starting at a message type. +FieldPath GetPath(const std::string& path_tag, + const std::string& message_type) { + OptionsSyntaxUtil syntax_util; + const Descriptor* descriptor = + OptionsRegistry::GetProtobufDescriptor(message_type); + return syntax_util.OptionFieldPath(path_tag, descriptor); +} + +// Returns the message type for a FieldData. +std::string MessageType(FieldData message) { + return options_field_util::ParseTypeUrl( + std::string(message.message_value().type_url())); +} + // Copy literal options from graph_options to node_options. 
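CopyLiteralOptions(), which follows, consumes the `option_value` entries declared on each node. Each entry has the form `"node_field_path:graph_field_path"`: the right-hand path selects a value from the enclosing graph's options, and the left-hand path names the destination field in the node's own options. A minimal way to declare such a mapping in code (the calculator name is hypothetical; the paths come from options_util_test.cc):

```
#include "mediapipe/framework/calculator.pb.h"

// Asks the framework to copy the graph-level option value at
// "options/sub_options/num_lights" into this node's "chain_length" option.
mediapipe::CalculatorGraphConfig::Node MakeNode() {
  mediapipe::CalculatorGraphConfig::Node node;
  node.set_calculator("SomeSubgraph");  // Hypothetical name.
  node.add_option_value("chain_length:options/sub_options/num_lights");
  return node;
}
```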
absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node, CalculatorGraphConfig* config) { Status status; - FieldData config_options, parent_node_options, graph_options; - status.Update( - options_field_util::GetOptionsMessage(*config, &config_options)); - status.Update( - options_field_util::GetOptionsMessage(parent_node, &parent_node_options)); - status.Update(options_field_util::MergeOptionsMessages( - config_options, parent_node_options, &graph_options)); - const Descriptor* options_descriptor = - OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl( - std::string(graph_options.message_value().type_url()))); - if (!options_descriptor) { - return status; - } + FieldData graph_data = options_field_util::AsFieldData(*config); + FieldData parent_data = options_field_util::AsFieldData(parent_node); OptionsSyntaxUtil syntax_util; for (auto& node : *config->mutable_node()) { - FieldData node_data; - status.Update(options_field_util::GetOptionsMessage(node, &node_data)); - if (!node_data.has_message_value() || node.option_value_size() == 0) { - continue; - } - const Descriptor* node_options_descriptor = - OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl( - std::string(node_data.message_value().type_url()))); - if (!node_options_descriptor) { - continue; - } + FieldData node_data = options_field_util::AsFieldData(node); + for (const std::string& option_def : node.option_value()) { - std::vector tag_and_name = absl::StrSplit(option_def, ':'); + std::vector tag_and_name = + syntax_util.StrSplitTags(option_def); std::string graph_tag = syntax_util.OptionFieldsTag(tag_and_name[1]); + std::string graph_extension_type = ExtensionType(graph_tag); std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]); + std::string node_extension_type = ExtensionType(node_tag); + FieldData graph_options; + GetGraphOptions(graph_data, graph_extension_type, &graph_options) + .IgnoreError(); + FieldData parent_options; + GetNodeOptions(parent_data, graph_extension_type, &parent_options) + .IgnoreError(); + status.Update( + MergeMessages(graph_options, parent_options, &graph_options)); + FieldData node_options; + status.Update( + GetNodeOptions(node_data, node_extension_type, &node_options)); + if (!node_options.has_message_value() || + !graph_options.has_message_value()) { + continue; + } + FieldPath graph_path = GetPath(graph_tag, MessageType(graph_options)); + FieldPath node_path = GetPath(node_tag, MessageType(node_options)); FieldData packet_data; - status.Update(options_field_util::GetField( - syntax_util.OptionFieldPath(graph_tag, options_descriptor), - graph_options, &packet_data)); - status.Update(options_field_util::MergeField( - syntax_util.OptionFieldPath(node_tag, node_options_descriptor), - packet_data, &node_data)); + status.Update(GetField(graph_path, graph_options, &packet_data)); + status.Update(MergeField(node_path, packet_data, &node_options)); + options_field_util::SetOptionsMessage(node_options, &node); } - options_field_util::SetOptionsMessage(node_data, &node); } return status; } diff --git a/mediapipe/framework/tool/options_util_test.cc b/mediapipe/framework/tool/options_util_test.cc index 55263d00e..b3ea619b6 100644 --- a/mediapipe/framework/tool/options_util_test.cc +++ b/mediapipe/framework/tool/options_util_test.cc @@ -15,6 +15,7 @@ #include #include +#include "absl/strings/string_view.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/deps/message_matchers.h" #include 
"mediapipe/framework/port/gtest.h" @@ -22,6 +23,7 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/testdata/night_light_calculator.pb.h" #include "mediapipe/framework/tool/node_chain_subgraph.pb.h" +#include "mediapipe/framework/tool/options_field_util.h" #include "mediapipe/framework/tool/options_registry.h" #include "mediapipe/framework/tool/options_syntax_util.h" @@ -51,6 +53,35 @@ class NightLightCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(NightLightCalculator); +using tool::options_field_util::FieldPath; + +// Validates FieldPathEntry contents. +bool Equals(const tool::options_field_util::FieldPathEntry& entry, + const std::string& field_name, int index, + const std::string& extension_type) { + const std::string& name = entry.field ? entry.field->name() : ""; + return name == field_name && entry.index == index && + entry.extension_type == extension_type; +} + +// Serializes a MessageLite into FieldData.message_value. +FieldData AsFieldData(const proto_ns::MessageLite& message) { + FieldData result; + *result.mutable_message_value()->mutable_value() = + message.SerializeAsString(); + result.mutable_message_value()->set_type_url(message.GetTypeName()); + return result; +} + +// Returns the type for the root options message if specified. +std::string ExtensionType(const std::string& option_fields_tag) { + tool::OptionsSyntaxUtil syntax_util; + tool::options_field_util::FieldPath field_path = + syntax_util.OptionFieldPath(option_fields_tag, nullptr); + std::string result = !field_path.empty() ? field_path[0].extension_type : ""; + return !result.empty() ? result : "*"; +} + // Tests for calculator and graph options. // class OptionsUtilTest : public ::testing::Test { @@ -150,8 +181,8 @@ TEST_F(OptionsUtilTest, OptionsSyntaxUtil) { EXPECT_EQ(tag, "OPTIONS/sub_options/num_lights"); field_path = syntax_util.OptionFieldPath(tag, descriptor); EXPECT_EQ(field_path.size(), 2); - EXPECT_EQ(field_path[0].first->name(), "sub_options"); - EXPECT_EQ(field_path[1].first->name(), "num_lights"); + EXPECT_EQ(field_path[0].field->name(), "sub_options"); + EXPECT_EQ(field_path[1].field->name(), "num_lights"); } { // A tag syntax with a text-coded separator. 
@@ -160,10 +191,100 @@ TEST_F(OptionsUtilTest, OptionsSyntaxUtil) { EXPECT_EQ(tag, "OPTIONS_Z0Z_sub_options_Z0Z_num_lights"); field_path = syntax_util.OptionFieldPath(tag, descriptor); EXPECT_EQ(field_path.size(), 2); - EXPECT_EQ(field_path[0].first->name(), "sub_options"); - EXPECT_EQ(field_path[1].first->name(), "num_lights"); + EXPECT_EQ(field_path[0].field->name(), "sub_options"); + EXPECT_EQ(field_path[1].field->name(), "num_lights"); } } +TEST_F(OptionsUtilTest, OptionFieldPath) { + tool::OptionsSyntaxUtil syntax_util; + std::vector split; + split = syntax_util.StrSplitTags("a/graph/option:a/node/option"); + EXPECT_EQ(2, split.size()); + EXPECT_EQ(split[0], "a/graph/option"); + EXPECT_EQ(split[1], "a/node/option"); + split = syntax_util.StrSplitTags("Ext::a/graph/option:Ext::a/node/option"); + EXPECT_EQ(2, split.size()); + EXPECT_EQ(split[0], "Ext::a/graph/option"); + EXPECT_EQ(split[1], "Ext::a/node/option"); + + split = + syntax_util.StrSplitTags("chain_length:options/sub_options/num_lights"); + EXPECT_EQ(2, split.size()); + EXPECT_EQ(split[0], "chain_length"); + EXPECT_EQ(split[1], "options/sub_options/num_lights"); + const tool::Descriptor* descriptor = + tool::OptionsRegistry::GetProtobufDescriptor( + "mediapipe.NightLightCalculatorOptions"); + tool::options_field_util::FieldPath field_path = + syntax_util.OptionFieldPath(split[1], descriptor); + EXPECT_EQ(field_path.size(), 2); + EXPECT_EQ(field_path[0].field->name(), "sub_options"); + EXPECT_EQ(field_path[1].field->name(), "num_lights"); +} + +TEST_F(OptionsUtilTest, FindOptionsMessage) { + tool::OptionsSyntaxUtil syntax_util; + std::vector split; + split = + syntax_util.StrSplitTags("chain_length:options/sub_options/num_lights"); + EXPECT_EQ(2, split.size()); + EXPECT_EQ(split[0], "chain_length"); + EXPECT_EQ(split[1], "options/sub_options/num_lights"); + const tool::Descriptor* descriptor = + tool::OptionsRegistry::GetProtobufDescriptor( + "mediapipe.NightLightCalculatorOptions"); + tool::options_field_util::FieldPath field_path = + syntax_util.OptionFieldPath(split[1], descriptor); + EXPECT_EQ(field_path.size(), 2); + EXPECT_TRUE(Equals(field_path[0], "sub_options", 0, "")); + EXPECT_TRUE(Equals(field_path[1], "num_lights", 0, "")); + + { + // NightLightCalculatorOptions in Node.options. + CalculatorGraphConfig::Node node; + NightLightCalculatorOptions* options = + node.mutable_options()->MutableExtension( + NightLightCalculatorOptions::ext); + options->mutable_sub_options()->add_num_lights(33); + + // Retrieve the specified option. + FieldData node_data = AsFieldData(node); + auto path = field_path; + std::string node_extension_type = ExtensionType(std::string(split[1])); + FieldData node_options; + MP_EXPECT_OK(tool::options_field_util::GetNodeOptions( + node_data, node_extension_type, &node_options)); + FieldData packet_data; + MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options, + &packet_data)); + EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value); + EXPECT_EQ(packet_data.int32_value(), 33); + } + + { + // NightLightCalculatorOptions in Node.node_options. + CalculatorGraphConfig::Node node; + NightLightCalculatorOptions options; + options.mutable_sub_options()->add_num_lights(33); + node.add_node_options()->PackFrom(options); + + // Retrieve the specified option. 
+ FieldData node_data = AsFieldData(node); + auto path = field_path; + std::string node_extension_type = ExtensionType(std::string(split[1])); + FieldData node_options; + MP_EXPECT_OK(tool::options_field_util::GetNodeOptions( + node_data, node_extension_type, &node_options)); + FieldData packet_data; + MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options, + &packet_data)); + EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value); + EXPECT_EQ(packet_data.int32_value(), 33); + } + + // TODO: Test with specified extension_type. +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index da53b98a4..0db044d96 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -207,16 +207,20 @@ cc_library( cc_library( name = "gpu_buffer", + srcs = ["gpu_buffer.cc"], hdrs = ["gpu_buffer.h"], visibility = ["//visibility:public"], deps = [ ":gl_base", + ":gl_context", ":gpu_buffer_format", + "//mediapipe/framework/formats:image_frame", ] + select({ "//conditions:default": [ ":gl_texture_buffer", ], "//mediapipe:ios": [ + "//mediapipe/objc:util", "//mediapipe/objc:CFHolder", ], "//mediapipe:macos": [ @@ -478,6 +482,7 @@ cc_library( "//mediapipe:ios": [ ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", + "//mediapipe/objc:util", ], "//mediapipe:macos": [ ":pixel_buffer_pool_util", @@ -498,55 +503,40 @@ cc_library( ], ) -HELPER_ANDROID_SRCS = [ - "gl_calculator_helper_impl_android.cc", - "gl_calculator_helper_impl_common.cc", -] - -HELPER_ANDROID_HDRS = [ - "egl_surface_holder.h", -] - -HELPER_COMMON_SRCS = [ - "gl_calculator_helper.cc", -] - -HELPER_COMMON_HDRS = [ - "gl_calculator_helper.h", - "gl_calculator_helper_impl.h", -] - -HELPER_IOS_SRCS = [ - "gl_calculator_helper_impl_ios.mm", - "gl_calculator_helper_impl_common.cc", -] - -HELPER_IOS_FRAMEWORKS = [ - "AVFoundation", - "CoreVideo", - "CoreGraphics", - "CoreMedia", - "GLKit", - "QuartzCore", -] + select({ - "//conditions:default": [ - "OpenGLES", +cc_library( + name = "egl_surface_holder", + hdrs = ["egl_surface_holder.h"], + deps = [ + ":gl_base", + "@com_google_absl//absl/synchronization", ], - "//mediapipe:macos": [ - "OpenGL", - "AppKit", - ], -}) +) cc_library( name = "gl_calculator_helper", - srcs = select({ - "//conditions:default": HELPER_COMMON_SRCS + HELPER_ANDROID_SRCS, - "//mediapipe:apple": [], - }), - hdrs = HELPER_COMMON_HDRS + select({ - "//conditions:default": HELPER_ANDROID_HDRS, - "//mediapipe:apple": [], + srcs = [ + "gl_calculator_helper.cc", + "gl_calculator_helper_impl_common.cc", + ], + hdrs = [ + "gl_calculator_helper.h", + "gl_calculator_helper_impl.h", + ], + linkopts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-framework AVFoundation", + "-framework CoreVideo", + "-framework CoreGraphics", + "-framework CoreMedia", + "-framework GLKit", + "-framework QuartzCore", + ], + }) + select({ + "//conditions:default": [], + "//mediapipe:macos": [ + "-framework AppKit", + ], }), visibility = ["//visibility:public"], deps = [ @@ -582,34 +572,20 @@ cc_library( ] + select({ "//conditions:default": [ ], - "//mediapipe:apple": [ - ":gl_calculator_helper_ios", - "//mediapipe/objc:util", - "//mediapipe/objc:CFHolder", - ], + "//mediapipe:apple": [], }), ) +# TODO: remove objc_library( name = "gl_calculator_helper_ios", - srcs = HELPER_COMMON_SRCS + HELPER_IOS_SRCS, - hdrs = HELPER_COMMON_HDRS, copts = [ "-Wno-shorten-64-to-32", "-std=c++17", ], - sdk_frameworks = HELPER_IOS_FRAMEWORKS, visibility = ["//visibility:public"], deps = [ - 
":gl_base", - ":gl_context", - ":gpu_buffer", - ":gpu_buffer_multi_pool", - ":gpu_service", - ":gpu_shared_data_internal", - ":shader_util", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:image", + ":gl_calculator_helper", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:util", ], @@ -769,6 +745,7 @@ cc_library( srcs = ["gl_surface_sink_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":egl_surface_holder", ":gl_calculator_helper", ":gl_quad_renderer", ":gpu_buffer", diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index aa708e731..efd822ae6 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -24,15 +24,8 @@ #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_service.h" -#ifdef __APPLE__ -#include "mediapipe/objc/util.h" -#endif - namespace mediapipe { -GlTexture::GlTexture(GLuint name, int width, int height) - : name_(name), width_(width), height_(height), target_(GL_TEXTURE_2D) {} - // The constructor and destructor need to be defined here so that // std::unique_ptr can see the full definition of GlCalculatorHelperImpl. // In the header, it is an incomplete type. diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index b5cc69990..d6bd71895 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -31,8 +31,6 @@ #ifdef __APPLE__ #include - -#include "mediapipe/objc/CFHolder.h" #endif // __APPLE__ namespace mediapipe { @@ -42,14 +40,6 @@ class GlTexture; class GpuResources; struct GpuSharedData; -#ifdef __APPLE__ -#if TARGET_OS_OSX -typedef CVOpenGLTextureRef CVTextureType; -#else -typedef CVOpenGLESTextureRef CVTextureType; -#endif // TARGET_OS_OSX -#endif // __APPLE__ - using ImageFrameSharedPtr = std::shared_ptr; // TODO: remove this and Process below, or make Process available @@ -174,14 +164,12 @@ class GlCalculatorHelper { class GlTexture { public: GlTexture() {} - GlTexture(GLuint name, int width, int height); - ~GlTexture() { Release(); } - int width() const { return width_; } - int height() const { return height_; } - GLenum target() const { return target_; } - GLuint name() const { return name_; } + int width() const { return view_.width(); } + int height() const { return view_.height(); } + GLenum target() const { return view_.target(); } + GLuint name() const { return view_.name(); } // Returns a buffer that can be sent to another calculator. // & manages sync token @@ -190,26 +178,12 @@ class GlTexture { std::unique_ptr GetFrame() const; // Releases texture memory & manages sync token - void Release(); + void Release() { view_.Release(); } private: + explicit GlTexture(GlTextureView view) : view_(std::move(view)) {} friend class GlCalculatorHelperImpl; - GlCalculatorHelperImpl* helper_impl_ = nullptr; - GLuint name_ = 0; - int width_ = 0; - int height_ = 0; - GLenum target_ = GL_TEXTURE_2D; - -#ifdef MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - // For CVPixelBufferRef-based rendering - CFHolder cv_texture_; -#else - // Keeps track of whether this texture mapping is for read access, so that - // we can create a consumer sync point when releasing it. 
- bool for_reading_ = false; -#endif - GpuBuffer gpu_buffer_; - int plane_ = 0; + GlTextureView view_; }; // Returns the entry with the given tag if the collection uses tags, with the diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h index 438111183..40b53f571 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ b/mediapipe/gpu/gl_calculator_helper_impl.h @@ -58,19 +58,14 @@ class GlCalculatorHelperImpl { GlContext& GetGlContext() const; // For internal use. - void ReadTexture(const GlTexture& texture, void* output, size_t size); + static void ReadTexture(const GlTextureView& view, void* output, size_t size); private: // Makes a GpuBuffer accessible as a texture in the GL context. - GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, int plane); - -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - GlTexture MapGlTextureBuffer(const GlTextureBufferSharedPtr& texture_buffer); - GlTextureBufferSharedPtr MakeGlTextureBuffer(const ImageFrame& image_frame); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - - // Sets default texture filtering parameters. - void SetStandardTextureParams(GLenum target, GLint internal_format); + GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, int plane, + bool for_reading); + void AttachGlTexture(GlTexture& texture, const GpuBuffer& gpu_buffer, + int plane, bool for_reading); // Create the framebuffer for rendering. void CreateFramebuffer(); @@ -80,10 +75,6 @@ class GlCalculatorHelperImpl { GLuint framebuffer_ = 0; GpuResources& gpu_resources_; - - // Necessary to compute for a given GlContext in order to properly enforce the - // SetStandardTextureParams. - bool can_linear_filter_float_textures_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gl_calculator_helper_impl_android.cc b/mediapipe/gpu/gl_calculator_helper_impl_android.cc deleted file mode 100644 index 340734335..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl_android.cc +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "mediapipe/gpu/gl_calculator_helper_impl.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" - -namespace mediapipe { - -// TODO: move this method to GlCalculatorHelper, then we can -// access its framebuffer instead of requiring that one is already set. -template <> -std::unique_ptr GlTexture::GetFrame() const { - auto output = - absl::make_unique(ImageFormat::SRGBA, width_, height_, - ImageFrame::kGlDefaultAlignmentBoundary); - - CHECK(helper_impl_); - helper_impl_->ReadTexture(*this, output->MutablePixelData(), - output->PixelDataSize()); - - return output; -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { -#ifdef __EMSCRIPTEN__ - // When WebGL is used, the GL context may be spontaneously lost which can - // cause GpuBuffer allocations to fail. In that case, return a dummy buffer - // to allow processing of the current frame complete. 
- if (!gpu_buffer_) { - return std::make_unique(); - } -#endif // __EMSCRIPTEN__ - - CHECK(gpu_buffer_); - // Inform the GlTextureBuffer that we have produced new content, and create - // a producer sync point. - gpu_buffer_.GetGlTextureBufferSharedPtr()->Updated( - helper_impl_->GetGlContext().CreateSyncToken()); - -#ifdef __ANDROID__ - // On (some?) Android devices, the texture may need to be explicitly - // detached from the current framebuffer. - // TODO: is this necessary even with the unbind in BindFramebuffer? - // It is not clear if this affected other contexts too, but let's keep it - // while in doubt. - GLint type = GL_NONE; - glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE, - &type); - if (type == GL_TEXTURE) { - GLint color_attachment = 0; - glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME, - &color_attachment); - if (color_attachment == name_) { - glBindFramebuffer(GL_FRAMEBUFFER, 0); - } - } - - // Some Android drivers log a GL_INVALID_ENUM error after the first - // glGetFramebufferAttachmentParameteriv call if there is no bound object, - // even though it should be ok to ask for the type and get back GL_NONE. - // Let's just ignore any pending errors here. - GLenum error; - while ((error = glGetError()) != GL_NO_ERROR) { - } - -#endif // __ANDROID__ - return absl::make_unique(gpu_buffer_); -} - -void GlTexture::Release() { - if (for_reading_ && gpu_buffer_) { - // Inform the GlTextureBuffer that we have finished accessing its contents, - // and create a consumer sync point. - gpu_buffer_.GetGlTextureBufferSharedPtr()->DidRead( - helper_impl_->GetGlContext().CreateSyncToken()); - } - helper_impl_ = nullptr; - for_reading_ = false; - gpu_buffer_ = nullptr; - plane_ = 0; - name_ = 0; - width_ = 0; - height_ = 0; -} - -} // namespace mediapipe diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc index fb2685bcb..fef255049 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_common.cc @@ -25,17 +25,6 @@ GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc, GpuResources* gpu_resources) : gpu_resources_(*gpu_resources) { gl_context_ = gpu_resources_.gl_context(cc); -// GL_ES_VERSION_2_0 and up (at least through ES 3.2) may contain the extension. -// Checking against one also checks against higher ES versions. So this checks -// against GLES >= 2.0. -#if GL_ES_VERSION_2_0 - // No linear float filtering by default, check extensions. - can_linear_filter_float_textures_ = - gl_context_->HasGlExtension("OES_texture_float_linear"); -#else - // Any float32 texture we create should automatically have linear filtering. - can_linear_filter_float_textures_ = true; -#endif // GL_ES_VERSION_2_0 } GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { @@ -101,98 +90,59 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) { #endif } -void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target, - GLint internal_format) { - // Default to using linear filter everywhere. For float32 textures, fall back - // to GL_NEAREST if linear filtering unsupported. 
- GLint filter; - switch (internal_format) { - case GL_R32F: - case GL_RG32F: - case GL_RGBA32F: - // 32F (unlike 16f) textures do not always support texture filtering - // (According to OpenGL ES specification [TEXTURE IMAGE SPECIFICATION]) - filter = can_linear_filter_float_textures_ ? GL_LINEAR : GL_NEAREST; - break; - default: - filter = GL_LINEAR; - } - glTexParameteri(target, GL_TEXTURE_MIN_FILTER, filter); - glTexParameteri(target, GL_TEXTURE_MAG_FILTER, filter); - glTexParameteri(target, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); - glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); -} +GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, + int plane, bool for_reading) { + GlTextureView view = gpu_buffer.GetGlTextureView(plane, for_reading); -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const ImageFrame& image_frame) { - GlTexture texture = MapGlTextureBuffer(MakeGlTextureBuffer(image_frame)); - texture.for_reading_ = true; - return texture; + if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { + // TODO: do the params need to be reset here?? + glBindTexture(view.target(), view.name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + gpu_buffer.format(), view.plane(), GetGlVersion()); + gl_context_->SetStandardTextureParams(view.target(), + info.gl_internal_format); + glBindTexture(view.target(), 0); + } + + return GlTexture(std::move(view)); } GlTexture GlCalculatorHelperImpl::CreateSourceTexture( const GpuBuffer& gpu_buffer) { - GlTexture texture = MapGpuBuffer(gpu_buffer, 0); - texture.for_reading_ = true; - return texture; + return MapGpuBuffer(gpu_buffer, 0, true); } GlTexture GlCalculatorHelperImpl::CreateSourceTexture( const GpuBuffer& gpu_buffer, int plane) { - GlTexture texture = MapGpuBuffer(gpu_buffer, plane); - texture.for_reading_ = true; - return texture; + return MapGpuBuffer(gpu_buffer, plane, true); } -GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, - int plane) { - CHECK_EQ(plane, 0); - return MapGlTextureBuffer(gpu_buffer.GetGlTextureBufferSharedPtr()); -} - -GlTexture GlCalculatorHelperImpl::MapGlTextureBuffer( - const GlTextureBufferSharedPtr& texture_buffer) { - // Insert wait call to sync with the producer. - texture_buffer->WaitOnGpu(); - GlTexture texture; - texture.helper_impl_ = this; - texture.gpu_buffer_ = GpuBuffer(texture_buffer); - texture.plane_ = 0; - texture.width_ = texture_buffer->width_; - texture.height_ = texture_buffer->height_; - texture.target_ = texture_buffer->target_; - texture.name_ = texture_buffer->name_; - - if (texture_buffer->format() != GpuBufferFormat::kUnknown) { - // TODO: do the params need to be reset here?? 
- glBindTexture(texture.target(), texture.name()); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - texture_buffer->format(), texture.plane_, GetGlVersion()); - SetStandardTextureParams(texture.target(), info.gl_internal_format); - glBindTexture(texture.target(), 0); - } - - return texture; -} - -GlTextureBufferSharedPtr GlCalculatorHelperImpl::MakeGlTextureBuffer( +GlTexture GlCalculatorHelperImpl::CreateSourceTexture( const ImageFrame& image_frame) { - CHECK(gl_context_->IsCurrent()); - - auto buffer = GlTextureBuffer::Create(image_frame); - - if (buffer->format_ != GpuBufferFormat::kUnknown) { - glBindTexture(GL_TEXTURE_2D, buffer->name_); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - buffer->format_, /*plane=*/0, GetGlVersion()); - SetStandardTextureParams(buffer->target_, info.gl_internal_format); - glBindTexture(GL_TEXTURE_2D, 0); - } - - return buffer; + GlTexture texture = + MapGpuBuffer(GpuBuffer::CopyingImageFrame(image_frame), 0, true); + return texture; +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + return view_.gpu_buffer().AsImageFrame(); +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + auto gpu_buffer = view_.gpu_buffer(); +#ifdef __EMSCRIPTEN__ + // When WebGL is used, the GL context may be spontaneously lost which can + // cause GpuBuffer allocations to fail. In that case, return a dummy buffer + // to allow processing of the current frame complete. + if (!gpu_buffer) { + return std::make_unique(); + } +#endif // __EMSCRIPTEN__ + view_.DoneWriting(); + return absl::make_unique(gpu_buffer); } -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( int width, int height, GpuBufferFormat format) { @@ -202,44 +152,9 @@ GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( GpuBuffer buffer = gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format); - GlTexture texture = MapGpuBuffer(buffer, 0); + GlTexture texture = MapGpuBuffer(buffer, 0, false); return texture; } -void GlCalculatorHelperImpl::ReadTexture(const GlTexture& texture, void* output, - size_t size) { - CHECK_GE(size, texture.width_ * texture.height_ * 4); - - GLint current_fbo; - glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo); - CHECK_NE(current_fbo, 0); - - GLint color_attachment_name; - glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME, - &color_attachment_name); - if (color_attachment_name != texture.name_) { - // Save the viewport. Note that we assume that the color attachment is a - // GL_TEXTURE_2D texture. - GLint viewport[4]; - glGetIntegerv(GL_VIEWPORT, viewport); - - // Set the data from GLTexture object. - glViewport(0, 0, texture.width_, texture.height_); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - texture.target_, texture.name_, 0); - glReadPixels(0, 0, texture.width_, texture.height_, GL_RGBA, - GL_UNSIGNED_BYTE, output); - - // Restore from the saved viewport and color attachment name. 
- glViewport(viewport[0], viewport[1], viewport[2], viewport[3]); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, - color_attachment_name, 0); - } else { - glReadPixels(0, 0, texture.width_, texture.height_, GL_RGBA, - GL_UNSIGNED_BYTE, output); - } -} - } // namespace mediapipe diff --git a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm deleted file mode 100644 index e91d36e2c..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mediapipe/gpu/gl_calculator_helper_impl.h" - -#if TARGET_OS_OSX -#import -#else -#import -#endif // TARGET_OS_OSX -#import - -#include "absl/memory/memory.h" -#include "mediapipe/gpu/gpu_buffer_multi_pool.h" -#include "mediapipe/gpu/pixel_buffer_pool_util.h" -#include "mediapipe/objc/util.h" - -namespace mediapipe { - -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const mediapipe::ImageFrame& image_frame) { - GlTexture texture; - - texture.helper_impl_ = this; - texture.width_ = image_frame.Width(); - texture.height_ = image_frame.Height(); - auto format = GpuBufferFormatForImageFormat(image_frame.Format()); - - GlTextureInfo info = GlTextureInfoForGpuBufferFormat(format, 0, GetGlVersion()); - - glGenTextures(1, &texture.name_); - glBindTexture(GL_TEXTURE_2D, texture.name_); - glTexImage2D(GL_TEXTURE_2D, 0, info.gl_internal_format, texture.width_, - texture.height_, 0, info.gl_format, info.gl_type, - image_frame.PixelData()); - SetStandardTextureParams(GL_TEXTURE_2D, info.gl_internal_format); - return texture; -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer) { - return MapGpuBuffer(gpu_buffer, 0); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer, int plane) { - return MapGpuBuffer(gpu_buffer, plane); -} - -GlTexture GlCalculatorHelperImpl::MapGpuBuffer( - const GpuBuffer& gpu_buffer, int plane) { - CVReturn err; - GlTexture texture; - texture.helper_impl_ = this; - texture.gpu_buffer_ = gpu_buffer; - texture.plane_ = plane; - - const GlTextureInfo info = - GlTextureInfoForGpuBufferFormat(gpu_buffer.format(), plane, GetGlVersion()); - // When scale is not 1, we still give the nominal size of the image. 
- texture.width_ = gpu_buffer.width(); - texture.height_ = gpu_buffer.height(); - -#if TARGET_OS_OSX - CVOpenGLTextureRef cv_texture_temp; - err = CVOpenGLTextureCacheCreateTextureFromImage( - kCFAllocatorDefault, gl_context_->cv_texture_cache(), gpu_buffer.GetCVPixelBufferRef(), NULL, - &cv_texture_temp); - NSCAssert(cv_texture_temp && !err, - @"Error at CVOpenGLTextureCacheCreateTextureFromImage %d", err); - texture.cv_texture_.adopt(cv_texture_temp); - texture.target_ = CVOpenGLTextureGetTarget(*texture.cv_texture_); - texture.name_ = CVOpenGLTextureGetName(*texture.cv_texture_); -#else - CVOpenGLESTextureRef cv_texture_temp; - err = CVOpenGLESTextureCacheCreateTextureFromImage( - kCFAllocatorDefault, gl_context_->cv_texture_cache(), gpu_buffer.GetCVPixelBufferRef(), NULL, - GL_TEXTURE_2D, info.gl_internal_format, texture.width_ / info.downscale, - texture.height_ / info.downscale, info.gl_format, info.gl_type, plane, - &cv_texture_temp); - NSCAssert(cv_texture_temp && !err, - @"Error at CVOpenGLESTextureCacheCreateTextureFromImage %d", err); - texture.cv_texture_.adopt(cv_texture_temp); - texture.target_ = CVOpenGLESTextureGetTarget(*texture.cv_texture_); - texture.name_ = CVOpenGLESTextureGetName(*texture.cv_texture_); -#endif // TARGET_OS_OSX - - glBindTexture(texture.target(), texture.name()); - SetStandardTextureParams(texture.target(), info.gl_internal_format); - - return texture; -} -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -template<> -std::unique_ptr GlTexture::GetFrame() const { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - if (gpu_buffer_.GetCVPixelBufferRef()) { - return CreateImageFrameForCVPixelBuffer(gpu_buffer_.GetCVPixelBufferRef()); - } - - ImageFormat::Format image_format = - ImageFormatForGpuBufferFormat(gpu_buffer_.format()); - CHECK(helper_impl_); - GlTextureInfo info = - GlTextureInfoForGpuBufferFormat(gpu_buffer_.format(), plane_, helper_impl_->GetGlVersion()); - - auto output = absl::make_unique( - image_format, width_, height_); - - glReadPixels(0, 0, width_, height_, info.gl_format, info.gl_type, - output->MutablePixelData()); - return output; -#else - CHECK(gpu_buffer_.format() == GpuBufferFormat::kBGRA32); - auto output = - absl::make_unique(ImageFormat::SRGBA, width_, height_, - ImageFrame::kGlDefaultAlignmentBoundary); - - CHECK(helper_impl_); - helper_impl_->ReadTexture(*this, output->MutablePixelData(), output->PixelDataSize()); - - return output; -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -} - -template<> -std::unique_ptr GlTexture::GetFrame() const { - NSCAssert(gpu_buffer_, @"gpu_buffer_ must be valid"); -#if TARGET_IPHONE_SIMULATOR - CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef(); - CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); - NSCAssert(err == kCVReturnSuccess, @"CVPixelBufferLockBaseAddress failed: %d", err); - OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); - size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); - uint8_t* pixel_ptr = static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); - if (pixel_format == kCVPixelFormatType_32BGRA) { - // TODO: restore previous framebuffer? Move this to helper so we can - // use BindFramebuffer? 
- glViewport(0, 0, width_, height_); - glFramebufferTexture2D( - GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target_, name_, 0); - - size_t contiguous_bytes_per_row = width_ * 4; - if (bytes_per_row == contiguous_bytes_per_row) { - glReadPixels(0, 0, width_, height_, GL_BGRA, GL_UNSIGNED_BYTE, pixel_ptr); - } else { - std::vector contiguous_buffer(contiguous_bytes_per_row * height_); - uint8_t* temp_ptr = contiguous_buffer.data(); - glReadPixels(0, 0, width_, height_, GL_BGRA, GL_UNSIGNED_BYTE, temp_ptr); - for (int i = 0; i < height_; ++i) { - memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); - temp_ptr += contiguous_bytes_per_row; - pixel_ptr += bytes_per_row; - } - } - } else { - uint32_t format_big = CFSwapInt32HostToBig(pixel_format); - NSLog(@"unsupported pixel format: %.4s", (char*)&format_big); - } - err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); - NSCAssert(err == kCVReturnSuccess, @"CVPixelBufferUnlockBaseAddress failed: %d", err); -#endif - return absl::make_unique(gpu_buffer_); -} - -void GlTexture::Release() { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - if (*cv_texture_) { - cv_texture_.reset(NULL); - } else if (name_) { - // This is only needed because of the glGenTextures in - // CreateSourceTexture(ImageFrame)... change. - glDeleteTextures(1, &name_); - } -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - helper_impl_ = nullptr; - gpu_buffer_ = nullptr; - plane_ = 0; - name_ = 0; - width_ = 0; - height_ = 0; -} - -} // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 0c6865a86..c27a8b44e 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -222,6 +222,9 @@ bool GlContext::HasGlExtension(absl::string_view extension) const { // to work with GL_EXTENSIONS for newer GL versions, so we must maintain both // variations of this function. absl::Status GlContext::GetGlExtensions() { + // RET_CHECK logs by default, but here we just want to check the precondition; + // we'll fall back to the alternative implementation for older versions. + RET_CHECK(gl_major_version_ >= 3).SetNoLogging(); gl_extensions_.clear(); // glGetStringi only introduced in GL 3.0+; so we exit out this function if // we don't have that function defined, regardless of version number reported. @@ -330,13 +333,24 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { LOG(INFO) << "GL version: " << gl_major_version_ << "." << gl_minor_version_ << " (" << glGetString(GL_VERSION) << ")"; - if (gl_major_version_ >= 3) { + { auto status = GetGlExtensions(); - if (status.ok()) { - return absl::OkStatus(); + if (!status.ok()) { + status = GetGlExtensionsCompat(); } + MP_RETURN_IF_ERROR(status); } - return GetGlExtensionsCompat(); + +#if GL_ES_VERSION_2_0 // This actually means "is GLES available". + // No linear float filtering by default, check extensions. + can_linear_filter_float_textures_ = + HasGlExtension("OES_texture_float_linear"); +#else + // Desktop GL should always allow linear filtering. + can_linear_filter_float_textures_ = true; +#endif // GL_ES_VERSION_2_0 + + return absl::OkStatus(); }); } @@ -841,4 +855,25 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, return GlTextureInfoForGpuBufferFormat(format, plane, ctx->GetGlVersion()); } +void GlContext::SetStandardTextureParams(GLenum target, GLint internal_format) { + // Default to using linear filter everywhere. For float32 textures, fall back + // to GL_NEAREST if linear filtering unsupported. 
+ GLint filter; + switch (internal_format) { + case GL_R32F: + case GL_RG32F: + case GL_RGBA32F: + // 32F (unlike 16f) textures do not always support texture filtering + // (According to OpenGL ES specification [TEXTURE IMAGE SPECIFICATION]) + filter = can_linear_filter_float_textures_ ? GL_LINEAR : GL_NEAREST; + break; + default: + filter = GL_LINEAR; + } + glTexParameteri(target, GL_TEXTURE_MIN_FILTER, filter); + glTexParameteri(target, GL_TEXTURE_MAG_FILTER, filter); + glTexParameteri(target, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); +} + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 024ed9e5f..6cab706a5 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -276,6 +276,9 @@ class GlContext : public std::enable_shared_from_this { }).IgnoreError(); } + // Sets default texture filtering parameters. + void SetStandardTextureParams(GLenum target, GLint internal_format); + // These are used for testing specific SyncToken implementations. Do not use // outside of tests. enum class SyncTokenTypeForTest { @@ -342,11 +345,11 @@ class GlContext : public std::enable_shared_from_this { // This wraps a thread_local. static std::weak_ptr& CurrentContext(); - static absl::Status SwitchContext(ContextBinding* old_context, + static absl::Status SwitchContext(ContextBinding* saved_context, const ContextBinding& new_context); - absl::Status EnterContext(ContextBinding* previous_context); - absl::Status ExitContext(const ContextBinding* previous_context); + absl::Status EnterContext(ContextBinding* saved_context); + absl::Status ExitContext(const ContextBinding* saved_context); void DestroyContext(); bool HasContext() const; @@ -383,7 +386,7 @@ class GlContext : public std::enable_shared_from_this { static void GetCurrentContextBinding(ContextBinding* binding); // Makes the context described by new_context current on this thread. static absl::Status SetCurrentContextBinding( - const ContextBinding& new_context); + const ContextBinding& new_binding); // If not null, a dedicated thread used to execute tasks on this context. // Used on Android due to expensive context switching on some configurations. @@ -396,6 +399,10 @@ class GlContext : public std::enable_shared_from_this { // so we should be fine storing the extension pieces as string_view's. std::set gl_extensions_; + // Used by SetStandardTextureParams. Do we want several of these bools, or a + // better mechanism? + bool can_linear_filter_float_textures_; + // Number of glFinish calls completed on the GL thread. // Changes should be guarded by mutex_. However, we use simple atomic // loads for efficiency on the fast path. 
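
Note on the hunk above: SetStandardTextureParams has moved from the calculator helper onto GlContext, and the float32 linear-filtering capability is now probed once per context during FinishInitialization (via the OES_texture_float_linear extension on GLES; desktop GL always allows linear filtering). The sketch below is illustrative only and not part of the change: CreateBgraTexture is a hypothetical helper, while GlContext::GetCurrent, GetGlVersion, GlTextureInfoForGpuBufferFormat, and SetStandardTextureParams are the APIs that appear in this diff. It shows how code running with a MediaPipe GL context current might reuse the shared parameter setup for a texture it creates itself.

// Hypothetical helper, not part of the diff: creates a BGRA texture and applies
// the shared texture parameters from the current MediaPipe GlContext.
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gpu_buffer_format.h"

namespace mediapipe {

GLuint CreateBgraTexture(int width, int height) {
  auto gl_context = GlContext::GetCurrent();
  CHECK(gl_context);  // Must run on a thread with a MediaPipe GL context current.
  const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
      GpuBufferFormat::kBGRA32, /*plane=*/0, gl_context->GetGlVersion());

  GLuint name = 0;
  glGenTextures(1, &name);
  glBindTexture(GL_TEXTURE_2D, name);
  glTexImage2D(GL_TEXTURE_2D, 0, info.gl_internal_format, width, height, 0,
               info.gl_format, info.gl_type, nullptr);
  // Applies linear filtering and clamp-to-edge wrapping; for 32F formats this
  // falls back to GL_NEAREST unless linear float filtering was detected.
  gl_context->SetStandardTextureParams(GL_TEXTURE_2D, info.gl_internal_format);
  glBindTexture(GL_TEXTURE_2D, 0);
  return name;
}

}  // namespace mediapipe
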
diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 26b165f8e..9386f2ce2 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -85,7 +85,7 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, return std::move(context); } -absl::Status GlContext::CreateContextInternal(EGLContext external_context, +absl::Status GlContext::CreateContextInternal(EGLContext share_context, int gl_version) { CHECK(gl_version == 2 || gl_version == 3); @@ -131,8 +131,7 @@ absl::Status GlContext::CreateContextInternal(EGLContext external_context, // clang-format on }; - context_ = - eglCreateContext(display_, config_, external_context, context_attr); + context_ = eglCreateContext(display_, config_, share_context, context_attr); int error = eglGetError(); RET_CHECK(context_ != EGL_NO_CONTEXT) << "Could not create GLES " << gl_version << " context; " @@ -149,7 +148,7 @@ absl::Status GlContext::CreateContextInternal(EGLContext external_context, return absl::OkStatus(); } -absl::Status GlContext::CreateContext(EGLContext external_context) { +absl::Status GlContext::CreateContext(EGLContext share_context) { EGLint major = 0; EGLint minor = 0; @@ -163,11 +162,11 @@ absl::Status GlContext::CreateContext(EGLContext external_context) { LOG(INFO) << "Successfully initialized EGL. Major : " << major << " Minor: " << minor; - auto status = CreateContextInternal(external_context, 3); + auto status = CreateContextInternal(share_context, 3); if (!status.ok()) { LOG(WARNING) << "Creating a context with OpenGL ES 3 failed: " << status; LOG(WARNING) << "Fall back on OpenGL ES 2."; - status = CreateContextInternal(external_context, 2); + status = CreateContextInternal(share_context, 2); } MP_RETURN_IF_ERROR(status); diff --git a/mediapipe/gpu/gl_context_internal.h b/mediapipe/gpu/gl_context_internal.h index 16b7bf9bf..d683d4447 100644 --- a/mediapipe/gpu/gl_context_internal.h +++ b/mediapipe/gpu/gl_context_internal.h @@ -36,7 +36,7 @@ class GlContext::DedicatedThread { DedicatedThread& operator=(DedicatedThread) = delete; absl::Status Run(GlStatusFunction gl_func); - void RunWithoutWaiting(GlVoidFunction gl_fund); + void RunWithoutWaiting(GlVoidFunction gl_func); bool IsCurrentThread(); diff --git a/mediapipe/gpu/gl_ios_test.mm b/mediapipe/gpu/gl_ios_test.mm index 05147566a..674d5f126 100644 --- a/mediapipe/gpu/gl_ios_test.mm +++ b/mediapipe/gpu/gl_ios_test.mm @@ -175,18 +175,16 @@ mediapipe::GlCalculatorHelper helper; helper.InitializeForTest(&gpuData); - std::vector> sizes{ - {200, 300}, - {200, 299}, - {196, 300}, - {194, 300}, - {193, 300}, - }; - for (const auto& width_height : sizes) { - mediapipe::GlTexture texture = - helper.CreateDestinationTexture(width_height.first, width_height.second); - XCTAssertNotEqual(texture.name(), 0); - } + helper.RunInGlContext([&helper] { + std::vector> sizes{ + {200, 300}, {200, 299}, {196, 300}, {194, 300}, {193, 300}, + }; + for (const auto& width_height : sizes) { + mediapipe::GlTexture texture = + helper.CreateDestinationTexture(width_height.first, width_height.second); + XCTAssertNotEqual(texture.name(), 0); + } + }); } - (void)testSimpleConversionFromFormat:(OSType)cvPixelFormat { diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc new file mode 100644 index 000000000..d370d2b53 --- /dev/null +++ b/mediapipe/gpu/gpu_buffer.cc @@ -0,0 +1,260 @@ +#include "mediapipe/gpu/gpu_buffer.h" + +#include "mediapipe/gpu/gl_context.h" + +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include 
"mediapipe/objc/util.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +namespace mediapipe { + +void GlTextureView::Release() { + if (detach_) detach_(*this); + detach_ = nullptr; + gl_context_ = nullptr; + gpu_buffer_ = nullptr; + plane_ = 0; + name_ = 0; + width_ = 0; + height_ = 0; +} + +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#if TARGET_OS_OSX +typedef CVOpenGLTextureRef CVTextureType; +#else +typedef CVOpenGLESTextureRef CVTextureType; +#endif // TARGET_OS_OSX + +GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const { + CVReturn err; + auto gl_context = GlContext::GetCurrent(); + CHECK(gl_context); +#if TARGET_OS_OSX + CVTextureType cv_texture_temp; + err = CVOpenGLTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, gl_context->cv_texture_cache(), + GetCVPixelBufferRef(), NULL, &cv_texture_temp); + CHECK(cv_texture_temp && !err) + << "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err; + CFHolder cv_texture; + cv_texture.adopt(cv_texture_temp); + return GlTextureView( + gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture), + CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane, + [cv_texture]( + mediapipe::GlTextureView&) { /* only retains cv_texture */ }); +#else + const GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + format(), plane, gl_context->GetGlVersion()); + CVTextureType cv_texture_temp; + err = CVOpenGLESTextureCacheCreateTextureFromImage( + kCFAllocatorDefault, gl_context->cv_texture_cache(), + GetCVPixelBufferRef(), NULL, GL_TEXTURE_2D, info.gl_internal_format, + width() / info.downscale, height() / info.downscale, info.gl_format, + info.gl_type, plane, &cv_texture_temp); + CHECK(cv_texture_temp && !err) + << "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err; + CFHolder cv_texture; + cv_texture.adopt(cv_texture_temp); + return GlTextureView( + gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture), + CVOpenGLESTextureGetName(*cv_texture), width(), height(), *this, plane, + [cv_texture]( + mediapipe::GlTextureView&) { /* only retains cv_texture */ }); +#endif // TARGET_OS_OSX +} + +GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) { + auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); + // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only + // deals with absl::Status in MediaPipe OSS. + CHECK_OK(maybe_buffer.status()); + return GpuBuffer(std::move(maybe_buffer).value()); +} + +std::unique_ptr GpuBuffer::AsImageFrame() const { + CHECK(GetCVPixelBufferRef()); + return CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef()); +} + +void GlTextureView::DoneWriting() const { + CHECK(gpu_buffer_); +#if TARGET_IPHONE_SIMULATOR + CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef(); + CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << err; + OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer); + size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer); + uint8_t* pixel_ptr = + static_cast(CVPixelBufferGetBaseAddress(pixel_buffer)); + if (pixel_format == kCVPixelFormatType_32BGRA) { + // TODO: restore previous framebuffer? Move this to helper so we + // can use BindFramebuffer? 
+ glViewport(0, 0, width(), height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target(), + name(), 0); + + size_t contiguous_bytes_per_row = width() * 4; + if (bytes_per_row == contiguous_bytes_per_row) { + glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE, + pixel_ptr); + } else { + std::vector contiguous_buffer(contiguous_bytes_per_row * + height()); + uint8_t* temp_ptr = contiguous_buffer.data(); + glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE, + temp_ptr); + for (int i = 0; i < height(); ++i) { + memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row); + temp_ptr += contiguous_bytes_per_row; + pixel_ptr += bytes_per_row; + } + } + } else { + LOG(ERROR) << "unsupported pixel format: " << pixel_format; + } + err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0); + CHECK(err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << err; +#endif +} +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const { + auto gl_context = GlContext::GetCurrent(); + CHECK(gl_context); + const GlTextureBufferSharedPtr& texture_buffer = + GetGlTextureBufferSharedPtr(); + // Insert wait call to sync with the producer. + texture_buffer->WaitOnGpu(); + CHECK_EQ(plane, 0); + GlTextureView::DetachFn detach; + if (for_reading) { + detach = [](mediapipe::GlTextureView& texture) { + // Inform the GlTextureBuffer that we have finished accessing its + // contents, and create a consumer sync point. + texture.gpu_buffer().GetGlTextureBufferSharedPtr()->DidRead( + texture.gl_context()->CreateSyncToken()); + }; + } + return GlTextureView(gl_context.get(), texture_buffer->target(), + texture_buffer->name(), width(), height(), *this, plane, + std::move(detach)); +} + +GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) { + auto gl_context = GlContext::GetCurrent(); + CHECK(gl_context); + + auto buffer = GlTextureBuffer::Create(image_frame); + + // TODO: does this need to set the texture params? We set them again when the + // texture is actually acccessed via GlTexture[View]. Or should they always be + // set on creation? + if (buffer->format() != GpuBufferFormat::kUnknown) { + glBindTexture(GL_TEXTURE_2D, buffer->name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + buffer->format(), /*plane=*/0, gl_context->GetGlVersion()); + gl_context->SetStandardTextureParams(buffer->target(), + info.gl_internal_format); + glBindTexture(GL_TEXTURE_2D, 0); + } + + return GpuBuffer(std::move(buffer)); +} + +static void ReadTexture(const GlTextureView& view, void* output, size_t size) { + // TODO: check buffer size? We could use glReadnPixels where available + // (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read + // won't overflow the buffer with glReadPixels, we'd also need to check or + // reset several glPixelStore parameters (e.g. what if someone had the + // ill-advised idea of setting GL_PACK_SKIP_PIXELS?). 
+ CHECK(view.gl_context()); + GlTextureInfo info = + GlTextureInfoForGpuBufferFormat(view.gpu_buffer().format(), view.plane(), + view.gl_context()->GetGlVersion()); + + GLint current_fbo; + glGetIntegerv(GL_FRAMEBUFFER_BINDING, ¤t_fbo); + CHECK_NE(current_fbo, 0); + + GLint color_attachment_name; + glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME, + &color_attachment_name); + if (color_attachment_name != view.name()) { + // Save the viewport. Note that we assume that the color attachment is a + // GL_TEXTURE_2D texture. + GLint viewport[4]; + glGetIntegerv(GL_VIEWPORT, viewport); + + // Set the data from GLTextureView object. + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glReadPixels(0, 0, view.width(), view.height(), info.gl_format, + info.gl_type, output); + + // Restore from the saved viewport and color attachment name. + glViewport(viewport[0], viewport[1], viewport[2], viewport[3]); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, + color_attachment_name, 0); + } else { + glReadPixels(0, 0, view.width(), view.height(), info.gl_format, + info.gl_type, output); + } +} + +std::unique_ptr GpuBuffer::AsImageFrame() const { + ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format()); + auto output = absl::make_unique( + image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary); + auto view = GetGlTextureView(0, true); + ReadTexture(view, output->MutablePixelData(), output->PixelDataSize()); + return output; +} + +void GlTextureView::DoneWriting() const { + CHECK(gpu_buffer_); + // Inform the GlTextureBuffer that we have produced new content, and create + // a producer sync point. + gpu_buffer_.GetGlTextureBufferSharedPtr()->Updated( + gl_context()->CreateSyncToken()); + +#ifdef __ANDROID__ + // On (some?) Android devices, the texture may need to be explicitly + // detached from the current framebuffer. + // TODO: is this necessary even with the unbind in BindFramebuffer? + // It is not clear if this affected other contexts too, but let's keep it + // while in doubt. + GLint type = GL_NONE; + glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE, + &type); + if (type == GL_TEXTURE) { + GLint color_attachment = 0; + glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME, + &color_attachment); + if (color_attachment == name()) { + glBindFramebuffer(GL_FRAMEBUFFER, 0); + } + } + + // Some Android drivers log a GL_INVALID_ENUM error after the first + // glGetFramebufferAttachmentParameteriv call if there is no bound object, + // even though it should be ok to ask for the type and get back GL_NONE. + // Let's just ignore any pending errors here. 
+ GLenum error; + while ((error = glGetError()) != GL_NO_ERROR) { + } + +#endif // __ANDROID__ +} + +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 2fd864766..f03c55bab 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -17,6 +17,7 @@ #include +#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_buffer_format.h" @@ -32,6 +33,9 @@ namespace mediapipe { +class GlContext; +class GlTextureView; + // This class wraps a platform-specific buffer of GPU data. // An instance of GpuBuffer acts as an opaque reference to the underlying // data object. @@ -84,6 +88,19 @@ class GpuBuffer { // Allow assignment from nullptr. GpuBuffer& operator=(std::nullptr_t other); + // TODO: split into read and write, remove const from write. + GlTextureView GetGlTextureView(int plane, bool for_reading) const; + + // Make a GpuBuffer copying the data from an ImageFrame. + static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame); + + // Make an ImageFrame, possibly sharing the same data. The data is shared if + // the GpuBuffer's storage supports memory sharing; otherwise, it is copied. + // In order to work correctly across platforms, callers should always treat + // the returned ImageFrame as if it shares memory with the GpuBuffer, i.e. + // treat it as immutable if the GpuBuffer must not be modified. + std::unique_ptr AsImageFrame() const; + private: #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CFHolder pixel_buffer_; @@ -92,6 +109,51 @@ class GpuBuffer { #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; +class GlTextureView { + public: + GlTextureView() {} + ~GlTextureView() { Release(); } + // TODO: make this class move-only. + + GlContext* gl_context() const { return gl_context_; } + int width() const { return width_; } + int height() const { return height_; } + GLenum target() const { return target_; } + GLuint name() const { return name_; } + const GpuBuffer& gpu_buffer() const { return gpu_buffer_; } + int plane() const { return plane_; } + + private: + friend class GpuBuffer; + using DetachFn = std::function; + GlTextureView(GlContext* context, GLenum target, GLuint name, int width, + int height, GpuBuffer gpu_buffer, int plane, DetachFn detach) + : gl_context_(context), + target_(target), + name_(name), + width_(width), + height_(height), + gpu_buffer_(std::move(gpu_buffer)), + plane_(plane), + detach_(std::move(detach)) {} + + // TODO: remove this friend declaration. + friend class GlTexture; + void Release(); + // TODO: make this non-const. + void DoneWriting() const; + + GlContext* gl_context_ = nullptr; + GLenum target_ = GL_TEXTURE_2D; + GLuint name_ = 0; + // Note: when scale is not 1, we still give the nominal size of the image. 
+ int width_ = 0; + int height_ = 0; + GpuBuffer gpu_buffer_; + int plane_ = 0; + DetachFn detach_; +}; + #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER inline int GpuBuffer::width() const { diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 716a3b779..6e4fd38ea 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -21,10 +21,11 @@ #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" -#ifdef __APPLE__ +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #include "CoreFoundation/CFBase.h" #include "mediapipe/objc/CFHolder.h" -#endif // __APPLE__ +#include "mediapipe/objc/util.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER namespace mediapipe { diff --git a/mediapipe/gpu/pixel_buffer_pool_util.h b/mediapipe/gpu/pixel_buffer_pool_util.h index f5ae07df1..c852cf94d 100644 --- a/mediapipe/gpu/pixel_buffer_pool_util.h +++ b/mediapipe/gpu/pixel_buffer_pool_util.h @@ -63,11 +63,6 @@ CVReturn CreateCVPixelBufferWithPool(CVPixelBufferPoolRef pool, CFDictionaryRef CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold( int allocationThreshold); -// Create a CVPixelBuffer without using a pool. -CVReturn CreateCVPixelBufferWithoutPool(int width, int height, - OSType pixelFormat, - CVPixelBufferRef* outBuffer); - } // namespace mediapipe #endif // MEDIAPIPE_GPU_PIXEL_BUFFER_POOL_UTIL_H_ diff --git a/mediapipe/gpu/pixel_buffer_pool_util.mm b/mediapipe/gpu/pixel_buffer_pool_util.mm index 1e006fd48..0b13cb194 100644 --- a/mediapipe/gpu/pixel_buffer_pool_util.mm +++ b/mediapipe/gpu/pixel_buffer_pool_util.mm @@ -121,33 +121,4 @@ CVReturn CreateCVPixelBufferWithPool( return err; } -#if TARGET_IPHONE_SIMULATOR -static void FreeRefConReleaseCallback(void* refCon, const void* baseAddress) { - free(refCon); -} -#endif - -CVReturn CreateCVPixelBufferWithoutPool( - int width, int height, OSType pixelFormat, CVPixelBufferRef* outBuffer) { -#if TARGET_IPHONE_SIMULATOR - // On the simulator, syncing the texture with the pixelbuffer does not work, - // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not - // available in OpenGL ES 2, we should create the buffer so the pixels are - // contiguous. - // - // TODO: verify if we can use kIOSurfaceBytesPerRow to force - // CoreVideo to give us contiguous data. - size_t bytes_per_row = width * 4; - void* data = malloc(bytes_per_row * height); - return CVPixelBufferCreateWithBytes( - kCFAllocatorDefault, width, height, pixelFormat, data, bytes_per_row, - FreeRefConReleaseCallback, data, GetCVPixelBufferAttributesForGlCompatibility(), - outBuffer); -#else - return CVPixelBufferCreate( - kCFAllocatorDefault, width, height, pixelFormat, - GetCVPixelBufferAttributesForGlCompatibility(), outBuffer); -#endif -} - } // namespace mediapipe diff --git a/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc b/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc index db1bc3422..093a7325d 100644 --- a/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc +++ b/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc @@ -28,7 +28,7 @@ namespace mediapipe { namespace { -constexpr int kNumFaceLandmarkConnections = 124; +constexpr int kNumFaceLandmarkConnections = 132; // Pairs of landmark indices to be rendered with connections. constexpr int kFaceLandmarkConnections[] = { // Lips. 
@@ -43,6 +43,8 @@ constexpr int kFaceLandmarkConnections[] = { 133, // Left eyebrow. 46, 53, 53, 52, 52, 65, 65, 55, 70, 63, 63, 105, 105, 66, 66, 107, + // Left iris. + 474, 475, 475, 476, 476, 477, 477, 474, // Right eye. 263, 249, 249, 390, 390, 373, 373, 374, 374, 380, 380, 381, 381, 382, 382, 362, 263, 466, 466, 388, 388, 387, 387, 386, 386, 385, 385, 384, 384, 398, @@ -50,6 +52,8 @@ constexpr int kFaceLandmarkConnections[] = { // Right eyebrow. 276, 283, 283, 282, 282, 295, 295, 285, 300, 293, 293, 334, 334, 296, 296, 336, + // Right iris. + 469, 470, 470, 471, 471, 472, 472, 469, // Face oval. 10, 338, 338, 297, 297, 332, 332, 284, 284, 251, 251, 389, 389, 356, 356, 454, 454, 323, 323, 361, 361, 288, 288, 397, 397, 365, 365, 379, 379, 378, diff --git a/mediapipe/graphs/face_mesh/face_mesh_desktop.pbtxt b/mediapipe/graphs/face_mesh/face_mesh_desktop.pbtxt index c3aa3945d..215791a36 100644 --- a/mediapipe/graphs/face_mesh/face_mesh_desktop.pbtxt +++ b/mediapipe/graphs/face_mesh/face_mesh_desktop.pbtxt @@ -22,10 +22,12 @@ node { # Defines side packets for further use in the graph. node { calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:num_faces" + output_side_packet: "PACKET:0:num_faces" + output_side_packet: "PACKET:1:with_attention" node_options: { [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { packet { int_value: 1 } + packet { bool_value: true } } } } @@ -35,6 +37,7 @@ node { calculator: "FaceLandmarkFrontCpu" input_stream: "IMAGE:input_video" input_side_packet: "NUM_FACES:num_faces" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:multi_face_landmarks" output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" output_stream: "DETECTIONS:face_detections" diff --git a/mediapipe/graphs/face_mesh/face_mesh_desktop_live.pbtxt b/mediapipe/graphs/face_mesh/face_mesh_desktop_live.pbtxt index 57654436a..2cc563424 100644 --- a/mediapipe/graphs/face_mesh/face_mesh_desktop_live.pbtxt +++ b/mediapipe/graphs/face_mesh/face_mesh_desktop_live.pbtxt @@ -33,10 +33,12 @@ node { # Defines side packets for further use in the graph. node { calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:num_faces" + output_side_packet: "PACKET:0:num_faces" + output_side_packet: "PACKET:1:with_attention" node_options: { [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { packet { int_value: 1 } + packet { bool_value: true } } } } @@ -46,6 +48,7 @@ node { calculator: "FaceLandmarkFrontCpu" input_stream: "IMAGE:throttled_input_video" input_side_packet: "NUM_FACES:num_faces" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:multi_face_landmarks" output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" output_stream: "DETECTIONS:face_detections" diff --git a/mediapipe/graphs/face_mesh/face_mesh_desktop_live_gpu.pbtxt b/mediapipe/graphs/face_mesh/face_mesh_desktop_live_gpu.pbtxt index cfa75c2c7..ae03709fa 100644 --- a/mediapipe/graphs/face_mesh/face_mesh_desktop_live_gpu.pbtxt +++ b/mediapipe/graphs/face_mesh/face_mesh_desktop_live_gpu.pbtxt @@ -33,10 +33,12 @@ node { # Defines side packets for further use in the graph. 
node { calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:num_faces" + output_side_packet: "PACKET:0:num_faces" + output_side_packet: "PACKET:1:with_attention" node_options: { [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { packet { int_value: 1 } + packet { bool_value: true } } } } @@ -46,6 +48,7 @@ node { calculator: "FaceLandmarkFrontGpu" input_stream: "IMAGE:throttled_input_video" input_side_packet: "NUM_FACES:num_faces" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:multi_face_landmarks" output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" output_stream: "DETECTIONS:face_detections" diff --git a/mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt b/mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt index bf176765a..e9711e192 100644 --- a/mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt +++ b/mediapipe/graphs/face_mesh/face_mesh_mobile.pbtxt @@ -33,11 +33,23 @@ node { output_stream: "throttled_input_video" } +# Defines side packets for further use in the graph. +node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:with_attention" + node_options: { + [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { + packet { bool_value: true } + } + } +} + # Subgraph that detects faces and corresponding landmarks. node { calculator: "FaceLandmarkFrontGpu" input_stream: "IMAGE:throttled_input_video" input_side_packet: "NUM_FACES:num_faces" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:multi_face_landmarks" output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" output_stream: "DETECTIONS:face_detections" diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java index f0a87ba05..6910d4d7f 100644 --- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java @@ -511,7 +511,7 @@ public class ExternalTextureConverter implements TextureFrameProducer { frame.getHeight(), frame.getTimestamp())); } - frame.waitUntilReleased(); + frame.waitUntilReleasedWithGpuSync(); if (Log.isLoggable(TAG, Log.VERBOSE)) { Log.v( TAG, diff --git a/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java index 20cb81982..4d75cc316 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/AppTextureFrame.java @@ -66,7 +66,9 @@ public class AppTextureFrame implements TextureFrame { /** * Waits until the consumer is done with the texture. - * @throws InterruptedException + * + *
<p>
This does a CPU wait for the texture to be complete. + * Use {@link waitUntilReleasedWithGpuSync} whenever possible. */ public void waitUntilReleased() throws InterruptedException { synchronized (this) { @@ -82,6 +84,26 @@ public class AppTextureFrame implements TextureFrame { } } + /** + * Waits until the consumer is done with the texture. + * + *
<p>
This method must be called within the application's GL context that will overwrite the + * TextureFrame. + */ + public void waitUntilReleasedWithGpuSync() throws InterruptedException { + synchronized (this) { + while (inUse && releaseSyncToken == null) { + wait(); + } + if (releaseSyncToken != null) { + releaseSyncToken.waitOnGpu(); + releaseSyncToken.release(); + inUse = false; + releaseSyncToken = null; + } + } + } + /** * Returns whether the texture is currently in use. * diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index e289ee74e..b3290f70e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -37,9 +37,18 @@ public class GraphTextureFrame implements TextureFrame { this.timestamp = timestamp; } - /** Returns the name of the underlying OpenGL texture. */ + /** + * Returns the name of the underlying OpenGL texture. + * + *
<p>
Note: if this texture has been obtained using getTextureFrameDeferredWait, a GPU wait on the + * producer sync will be done here. That means this method should be called on the GL context that + * will actually use the texture. + */ @Override public int getTextureName() { + // Note that, if a CPU wait has already been done, the sync point will have been + // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait. + nativeGpuWait(nativeBufferHandle); return textureName; } @@ -92,4 +101,6 @@ public class GraphTextureFrame implements TextureFrame { private native int nativeGetTextureName(long nativeHandle); private native int nativeGetWidth(long nativeHandle); private native int nativeGetHeight(long nativeHandle); + + private native void nativeGpuWait(long nativeHandle); } diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java index 3d59c3e0a..109240bb9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketGetter.java @@ -288,7 +288,18 @@ public final class PacketGetter { */ public static GraphTextureFrame getTextureFrame(final Packet packet) { return new GraphTextureFrame( - nativeGetGpuBuffer(packet.getNativeHandle()), packet.getTimestamp()); + nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ true), packet.getTimestamp()); + } + + /** + * Works like {@link #getTextureFrame(Packet)}, but does not insert a CPU wait for the texture's + * producer before returning. Instead, a GPU wait will automatically occur when + * GraphTextureFrame#getTextureName is called. + */ + public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) { + return new GraphTextureFrame( + nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false), + packet.getTimestamp()); } private static native long nativeGetPacketFromReference(long nativePacketHandle); @@ -356,7 +367,7 @@ public final class PacketGetter { private static native int nativeGetGpuBufferName(long nativePacketHandle); - private static native long nativeGetGpuBuffer(long nativePacketHandle); + private static native long nativeGetGpuBuffer(long nativePacketHandle, boolean waitOnCpu); private PacketGetter() {} } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index e16f140b4..5f3a6527c 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -123,6 +123,7 @@ cc_library( "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:graph_support", + "//mediapipe/gpu:egl_surface_holder", ], "//mediapipe/gpu:disable_gpu": [ "//mediapipe/gpu:gpu_shared_data_internal", diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc index 5f41d9487..5c4470809 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.cc @@ -34,6 +34,13 @@ JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( return (*buffer)->name(); } +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGpuWait)( + JNIEnv* env, jobject thiz, jlong nativeHandle) { + GlTextureBufferSharedPtr* buffer = + reinterpret_cast(nativeHandle); + (*buffer)->WaitOnGpu(); +} 
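
The nativeGpuWait hook above backs the new deferred-sync path exposed to Java: PacketGetter#getTextureFrameDeferredSync skips the CPU wait on the producer, and GraphTextureFrame#getTextureName instead performs a GPU wait before handing out the texture id. A rough usage sketch follows; the consumer class and drawTexture method are hypothetical, the PacketGetter and GraphTextureFrame calls come from this diff, and getWidth, getHeight, and release are pre-existing TextureFrame methods.

import com.google.mediapipe.framework.GraphTextureFrame;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;

// Hypothetical consumer of a GpuBuffer packet using the deferred-sync path.
final class DeferredSyncTextureConsumer {
  /** Must be invoked on the GL context that will sample the texture. */
  void onGpuBufferPacket(Packet packet) {
    // No CPU wait happens here; the producer fence is handled later on the GPU.
    GraphTextureFrame frame = PacketGetter.getTextureFrameDeferredSync(packet);
    try {
      // getTextureName() issues the GPU wait on the producer sync before
      // returning the GL texture id, so it must run on the consuming GL context.
      int textureName = frame.getTextureName();
      drawTexture(textureName, frame.getWidth(), frame.getHeight());
    } finally {
      // Signals the producer that this consumer is done with the texture.
      frame.release();
    }
  }

  private void drawTexture(int textureName, int width, int height) {
    // Rendering details omitted; GL drawing that samples textureName goes here.
  }
}
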
+ JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetWidth)( JNIEnv* env, jobject thiz, jlong nativeHandle) { GlTextureBufferSharedPtr* buffer = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h index ce6bbcbc7..4520083f9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h @@ -31,6 +31,9 @@ JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)( JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)( JNIEnv* env, jobject thiz, jlong nativeHandle); +JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGpuWait)( + JNIEnv* env, jobject thiz, jlong nativeHandle); + JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetWidth)( JNIEnv* env, jobject thiz, jlong nativeHandle); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index 1a7fd18b0..397da19eb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -437,9 +437,8 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetGpuBufferName)( return static_cast(gpu_buffer.GetGlTextureBufferSharedPtr()->name()); } -JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)(JNIEnv* env, - jobject thiz, - jlong packet) { +JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)( + JNIEnv* env, jobject thiz, jlong packet, jboolean wait_on_cpu) { mediapipe::Packet mediapipe_packet = mediapipe::android::Graph::GetPacketFromHandle(packet); mediapipe::GlTextureBufferSharedPtr ptr; @@ -459,7 +458,9 @@ JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)(JNIEnv* env, mediapipe_packet.Get(); ptr = buffer.GetGlTextureBufferSharedPtr(); } - ptr->WaitUntilComplete(); + if (wait_on_cpu) { + ptr->WaitUntilComplete(); + } return reinterpret_cast( new mediapipe::GlTextureBufferSharedPtr(ptr)); } diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h index 14b287158..6a20d3daf 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h @@ -154,9 +154,8 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetGpuBufferName)( // Returns a mediapipe::GlTextureBufferSharedPtr*. // This will survive independently of the packet. 
-JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)(JNIEnv* env, - jobject thiz, - jlong packet); +JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)( + JNIEnv* env, jobject thiz, jlong packet, jboolean wait_on_cpu); #ifdef __cplusplus } // extern "C" diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/BUILD b/mediapipe/java/com/google/mediapipe/solutioncore/BUILD index 9add85c69..2957672a4 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/BUILD +++ b/mediapipe/java/com/google/mediapipe/solutioncore/BUILD @@ -22,7 +22,6 @@ android_library( ["*.java"], exclude = [ "CameraInput.java", - "ResultGlBoundary.java", "ResultGlRenderer.java", "SolutionGlSurfaceView.java", "SolutionGlSurfaceViewRenderer.java", @@ -67,7 +66,6 @@ android_library( android_library( name = "solution_rendering", srcs = [ - "ResultGlBoundary.java", "ResultGlRenderer.java", "SolutionGlSurfaceView.java", "SolutionGlSurfaceViewRenderer.java", @@ -78,7 +76,6 @@ android_library( "//mediapipe/java/com/google/mediapipe/components:android_components", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/glutil", - "//third_party:autovalue", "@maven//:com_google_guava_guava", ], ) @@ -91,6 +88,8 @@ cc_binary( # TODO: Add more calculators to support other top-level solutions. deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/modules/face_detection:face_detection_full_range_image", + "//mediapipe/modules/face_detection:face_detection_short_range_image", "//mediapipe/modules/face_landmark:face_landmark_front_cpu_image", "//mediapipe/modules/face_landmark:face_landmark_front_gpu_image", "//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image", diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionBase.java b/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionBase.java index b23cf2e42..b15df8e32 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionBase.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionBase.java @@ -54,7 +54,7 @@ public class ImageSolutionBase extends SolutionBase { eglManager = new EglManager(/*parentContext=*/ null); solutionGraph.setParentGlContext(eglManager.getNativeContext()); } catch (MediaPipeException e) { - throwException("Error occurs when creating MediaPipe image solution graph. ", e); + reportError("Error occurs while creating MediaPipe image solution graph.", e); } } @@ -72,8 +72,8 @@ public class ImageSolutionBase extends SolutionBase { /** Sends a {@link TextureFrame} into solution graph for processing. */ public void send(TextureFrame textureFrame) { if (!staticImageMode && textureFrame.getTimestamp() == Long.MIN_VALUE) { - throwException( - "Error occurs when calling the solution send method. ", + reportError( + "Error occurs while calling the MediaPipe solution send method.", new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "TextureFrame's timestamp needs to be explicitly set if not in static image mode.")); @@ -98,8 +98,8 @@ public class ImageSolutionBase extends SolutionBase { /** Sends a {@link Bitmap} (static image) into solution graph for processing. */ public void send(Bitmap inputBitmap) { if (!staticImageMode) { - throwException( - "Error occurs when calling the solution send method. 
", + reportError( + "Error occurs while calling the solution send method.", new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "When not in static image mode, a timestamp associated with the image is required." @@ -112,7 +112,7 @@ public class ImageSolutionBase extends SolutionBase { /** Internal implementation of sending Bitmap/TextureFrame into the MediaPipe solution. */ private synchronized void sendImage(T imageObj, long timestamp) { if (lastTimestamp >= timestamp) { - throwException( + reportError( "The received frame having a smaller timestamp than the processed timestamp.", new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), @@ -123,7 +123,7 @@ public class ImageSolutionBase extends SolutionBase { if (imageObj instanceof TextureFrame) { ((TextureFrame) imageObj).release(); } - throwException( + reportError( "The solution graph hasn't been successfully started or error occurs during graph" + " initializaton.", new MediaPipeException( @@ -140,8 +140,8 @@ public class ImageSolutionBase extends SolutionBase { } else if (imageObj instanceof Bitmap) { imagePacket = packetCreator.createRgbaImage((Bitmap) imageObj); } else { - throwException( - "The input image type is not supported. ", + reportError( + "The input image type is not supported.", new MediaPipeException( MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(), "The input image type is not supported.")); @@ -164,7 +164,7 @@ public class ImageSolutionBase extends SolutionBase { } } catch (RuntimeException e) { if (errorListener != null) { - errorListener.onError("Mediapipe error: ", e); + errorListener.onError("MediaPipe packet creation error: " + e.getMessage(), e); } else { throw e; } diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java b/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java index 51763431b..314efdcf2 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java @@ -33,6 +33,8 @@ public class OutputHandler { private ResultListener customResultListener; // The user-defined error listener. private ErrorListener customErrorListener; + // Whether the output handler should react to timestamp-bound changes by outputting empty packets. + private boolean handleTimestampBoundChanges = false; /** * Sets a callback to be invoked to convert a packet list to a solution result object. @@ -61,6 +63,20 @@ public class OutputHandler { this.customErrorListener = listener; } + /** + * Sets whether the output handler should react to timestamp-bound changes by outputting empty + * packets. + * + * @param handleTimestampBoundChanges a boolean value. + */ + public void setHandleTimestampBoundChanges(boolean handleTimestampBoundChanges) { + this.handleTimestampBoundChanges = handleTimestampBoundChanges; + } + + public boolean handleTimestampBoundChanges() { + return handleTimestampBoundChanges; + } + /** Handles a list of output packets. Invoked when packet lists become available. */ public void run(List packets) { T solutionResult = null; diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlBoundary.java b/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlBoundary.java deleted file mode 100644 index d01e8ac8f..000000000 --- a/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlBoundary.java +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2021 The MediaPipe Authors. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.solutioncore; - -import com.google.auto.value.AutoValue; - -/** - * The left, right, bottom, and top boundaries of the visible section on the screen. The boundary - * values are typically within the range -1.0 and 1.0. - */ -@AutoValue -public abstract class ResultGlBoundary { - - static ResultGlBoundary create(float left, float right, float bottom, float top) { - return new AutoValue_ResultGlBoundary(left, right, bottom, top); - } - - public abstract float left(); - - public abstract float right(); - - public abstract float bottom(); - - public abstract float top(); -} diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlRenderer.java b/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlRenderer.java index 43f51fefe..839f59aac 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlRenderer.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/ResultGlRenderer.java @@ -20,6 +20,16 @@ public interface ResultGlRenderer { /** Sets up OpenGL rendering when the surface is created or recreated. */ void setupRendering(); - /** Renders the solution result. */ - void renderResult(T result, ResultGlBoundary boundary); + /** + * Renders the solution result. + * + * @param result a solution result object that contains the solution outputs. + * @param projectionMatrix a 4 x 4 column-vector matrix stored in column-major order (see also android.opengl.Matrix). + * It is an orthographic projection matrix that maps x and y coordinates in {@code result}, + * defined in [0, 1]x[0, 1] spanning the entire input image (with a top-left origin), to fit + * into the {@link SolutionGlSurfaceView} (with a bottom-left origin) that the input image is + * rendered into with potential cropping. + */ + void renderResult(T result, float[] projectionMatrix); } diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionBase.java b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionBase.java index 0b4a4c357..3b94e7081 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionBase.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionBase.java @@ -73,21 +73,24 @@ public class SolutionBase { AndroidAssetUtil.getAssetBytes(context.getAssets(), solutionInfo.binaryGraphPath())); } solutionGraph.addMultiStreamCallback( - solutionInfo.outputStreamNames(), outputHandler::run, /*observeTimestampBounds=*/ true); + solutionInfo.outputStreamNames(), + outputHandler::run, + /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges()); packetCreator = new AndroidPacketCreator(solutionGraph); } catch (MediaPipeException e) { - throwException("Error occurs when creating the MediaPipe solution graph. ", e); + reportError("Error occurs while creating the MediaPipe solution graph.", e); } } - /** Throws exception with error message. */ - protected void throwException(String message, MediaPipeException e) { + /** Reports error with the detailed error message. 
*/ + protected void reportError(String message, MediaPipeException e) { + String detailedErrorMessage = String.format("%s Error details: %s", message, e.getMessage()); if (errorListener != null) { - errorListener.onError(message, e); + errorListener.onError(detailedErrorMessage, e); } else { - Log.e(TAG, message, e); + Log.e(TAG, detailedErrorMessage, e); + throw e; } - throw e; } /** @@ -114,7 +117,7 @@ public class SolutionBase { solutionGraph.startRunningGraph(); } } catch (MediaPipeException e) { - throwException("Error occurs when starting the MediaPipe solution graph. ", e); + reportError("Error occurs while starting the MediaPipe solution graph.", e); } } @@ -123,7 +126,7 @@ public class SolutionBase { try { solutionGraph.waitUntilGraphIdle(); } catch (MediaPipeException e) { - throwException("Error occurs when waiting until the MediaPipe graph becomes idle. ", e); + reportError("Error occurs while waiting until the MediaPipe graph becomes idle.", e); } } @@ -137,12 +140,12 @@ public class SolutionBase { // Note: errors during Process are reported at the earliest opportunity, // which may be addPacket or waitUntilDone, depending on timing. For consistency, // we want to always report them using the same async handler if installed. - throwException("Error occurs when closing the Mediapipe solution graph. ", e); + reportError("Error occurs while closing the Mediapipe solution graph.", e); } try { solutionGraph.tearDown(); } catch (MediaPipeException e) { - throwException("Error occurs when closing the Mediapipe solution graph. ", e); + reportError("Error occurs while closing the Mediapipe solution graph.", e); } } } diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java index ccaa1e725..47f7695ae 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java @@ -16,6 +16,7 @@ package com.google.mediapipe.solutioncore; import android.graphics.SurfaceTexture; import android.opengl.GLES20; +import android.opengl.Matrix; import com.google.mediapipe.components.GlSurfaceViewRenderer; import com.google.mediapipe.framework.TextureFrame; import com.google.mediapipe.glutil.ShaderUtil; @@ -91,14 +92,18 @@ public class SolutionGlSurfaceViewRenderer if (nextSolutionResult != null) { solutionResult = nextSolutionResult.getAndSet(null); float[] textureBoundary = calculateTextureBoundary(); - // Scales the values from [0, 1] to [-1, 1]. - ResultGlBoundary resultGlBoundary = - ResultGlBoundary.create( - textureBoundary[0] * 2 - 1, - textureBoundary[1] * 2 - 1, - textureBoundary[2] * 2 - 1, - textureBoundary[3] * 2 - 1); - resultGlRenderer.renderResult(solutionResult, resultGlBoundary); + float[] projectionMatrix = new float[16]; + // See {@link ResultGlRenderer#renderResult}. 
+ Matrix.orthoM( + projectionMatrix, /* result */ + 0, /* offset */ + textureBoundary[0], /* left */ + textureBoundary[1], /* right */ + textureBoundary[3], /* bottom */ + textureBoundary[2], /* top */ + -1, /* near */ + 1 /* far */); + resultGlRenderer.renderResult(solutionResult, projectionMatrix); } flush(frame); if (solutionResult != null) { diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/AndroidManifest.xml b/mediapipe/java/com/google/mediapipe/solutions/facedetection/AndroidManifest.xml new file mode 100644 index 000000000..e134efe19 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/BUILD b/mediapipe/java/com/google/mediapipe/solutions/facedetection/BUILD new file mode 100644 index 000000000..845825210 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/BUILD @@ -0,0 +1,45 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +android_library( + name = "facedetection", + srcs = [ + "FaceDetection.java", + "FaceDetectionOptions.java", + "FaceDetectionResult.java", + "FaceKeypoint.java", + ], + assets = [ + "//mediapipe/modules/face_detection:face_detection_full_range_image.binarypb", + "//mediapipe/modules/face_detection:face_detection_full_range_sparse.tflite", + "//mediapipe/modules/face_detection:face_detection_short_range.tflite", + "//mediapipe/modules/face_detection:face_detection_short_range_image.binarypb", + ], + assets_dir = "", + javacopts = ["-Acom.google.auto.value.AutoBuilderIsUnstable"], + manifest = ":AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/solutioncore:solution_base", + "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_code_findbugs_jsr305", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java new file mode 100644 index 000000000..92e10a9f1 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetection.java @@ -0,0 +1,130 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
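With `ResultGlBoundary` gone, custom renderers receive the ready-made orthographic projection matrix built by `Matrix.orthoM` above instead of raw boundary values. A hedged sketch of what an implementation of the updated `ResultGlRenderer` interface might look like; the `program` and `projectionMatrixHandle` fields are assumed to be set up in `setupRendering()` and are not part of this patch:

```
@Override
public void renderResult(FaceDetectionResult result, float[] projectionMatrix) {
  GLES20.glUseProgram(program);
  // Upload the 4 x 4 column-major matrix so vertices expressed in [0, 1] x [0, 1]
  // image coordinates land in the right place on the SolutionGlSurfaceView.
  GLES20.glUniformMatrix4fv(projectionMatrixHandle, 1, /* transpose= */ false, projectionMatrix, 0);
  // ... issue draw calls for the detections/landmarks in `result` here ...
}
```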
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.solutions.facedetection; + +import android.content.Context; +import com.google.common.collect.ImmutableList; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.RelativeKeypoint; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.solutioncore.ErrorListener; +import com.google.mediapipe.solutioncore.ImageSolutionBase; +import com.google.mediapipe.solutioncore.OutputHandler; +import com.google.mediapipe.solutioncore.ResultListener; +import com.google.mediapipe.solutioncore.SolutionInfo; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nullable; + +/** + * MediaPipe Face Detection Solution API. + * + *
<p>
MediaPipe Face Detection processes a {@link TextureFrame} or a {@link Bitmap} and returns the + * {@link FaceDetectionResult} representing each detected face. Please refer to + * https://solutions.mediapipe.dev/face_detection#android-solution-api for usage examples. + */ +public class FaceDetection extends ImageSolutionBase { + private static final String TAG = "FaceDetection"; + + private static final String SHORT_RANGE_GRAPH_NAME = "face_detection_short_range_image.binarypb"; + private static final String FULL_RANGE_GRAPH_NAME = "face_detection_full_range_image.binarypb"; + private static final String IMAGE_INPUT_STREAM = "image"; + private static final ImmutableList OUTPUT_STREAMS = + ImmutableList.of("detections", "throttled_image"); + private static final int DETECTIONS_INDEX = 0; + private static final int INPUT_IMAGE_INDEX = 1; + private final OutputHandler outputHandler; + + /** + * Initializes MediaPipe Face Detection solution. + * + * @param context an Android {@link Context}. + * @param options the configuration options defined in {@link FaceDetectionOptions}. + */ + public FaceDetection(Context context, FaceDetectionOptions options) { + outputHandler = new OutputHandler<>(); + outputHandler.setOutputConverter( + packets -> { + FaceDetectionResult.Builder faceMeshResultBuilder = FaceDetectionResult.builder(); + try { + faceMeshResultBuilder.setMultiFaceDetections( + getProtoVector(packets.get(DETECTIONS_INDEX), Detection.parser())); + } catch (MediaPipeException e) { + reportError("Error occurs while getting MediaPipe face detection results.", e); + } + return faceMeshResultBuilder + .setImagePacket(packets.get(INPUT_IMAGE_INDEX)) + .setTimestamp( + staticImageMode ? Long.MIN_VALUE : packets.get(INPUT_IMAGE_INDEX).getTimestamp()) + .build(); + }); + + SolutionInfo solutionInfo = + SolutionInfo.builder() + .setBinaryGraphPath( + options.modelSelection() == 0 ? SHORT_RANGE_GRAPH_NAME : FULL_RANGE_GRAPH_NAME) + .setImageInputStreamName(IMAGE_INPUT_STREAM) + .setOutputStreamNames(OUTPUT_STREAMS) + .setStaticImageMode(options.staticImageMode()) + .build(); + + initialize(context, solutionInfo, outputHandler); + Map emptyInputSidePackets = new HashMap<>(); + start(emptyInputSidePackets); + } + + /** + * Sets a callback to be invoked when a {@link FaceDetectionResult} becomes available. + * + * @param listener the {@link ResultListener} callback. + */ + public void setResultListener(ResultListener listener) { + this.outputHandler.setResultListener(listener); + } + + /** + * Sets a callback to be invoked when the Face Detection solution throws errors. + * + * @param listener the {@link ErrorListener} callback. + */ + public void setErrorListener(@Nullable ErrorListener listener) { + this.outputHandler.setErrorListener(listener); + this.errorListener = listener; + } + + /** + * Gets a specific face keypoint by face index and face keypoint type. + * + * @param result the returned {@link FaceDetectionResult} object. + * @param faceIndex the face index. A smaller index maps to a detected face with a higher + * confidence score. + * @param faceKeypointType the face keypoint type defined in {@link FaceKeypoint}. 
+ */ + public static RelativeKeypoint getFaceKeypoint( + FaceDetectionResult result, + int faceIndex, + @FaceKeypoint.FaceKeypointType int faceKeypointType) { + if (result == null + || faceIndex >= result.multiFaceDetections().size() + || faceKeypointType >= FaceKeypoint.NUM_KEY_POINTS) { + return RelativeKeypoint.getDefaultInstance(); + } + Detection detection = result.multiFaceDetections().get(faceIndex); + float x = detection.getLocationData().getRelativeKeypoints(faceKeypointType).getX(); + float y = detection.getLocationData().getRelativeKeypoints(faceKeypointType).getY(); + return RelativeKeypoint.newBuilder().setX(x).setY(y).build(); + } +} diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionOptions.java b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionOptions.java new file mode 100644 index 000000000..4158f9ff7 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionOptions.java @@ -0,0 +1,61 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.solutions.facedetection; + +import com.google.auto.value.AutoValue; + +/** + * MediaPipe Face Detection solution-specific options. + * + *
<p>
staticImageMode: Whether to treat the input images as a batch of static and possibly unrelated + * images, or a video stream. Default to false. See details in + * https://solutions.mediapipe.dev/face_detection#static_image_mode. + * + *
<p>
minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for face detection to be + * considered successful. See details in + * https://solutions.mediapipe.dev/face_detection#min_detection_confidence. + * + *
<p>
modelSelection: 0 or 1. 0 to select a short-range model that works best for faces within 2 + * meters from the camera, and 1 for a full-range model best for faces within 5 meters. See details + * in https://solutions.mediapipe.dev/face_detection#model_selection. + */ +@AutoValue +public abstract class FaceDetectionOptions { + public abstract boolean staticImageMode(); + + public abstract int modelSelection(); + + public abstract float minDetectionConfidence(); + + public static Builder builder() { + return new AutoValue_FaceDetectionOptions.Builder().withDefaultValues(); + } + + /** Builder for {@link FaceDetectionOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + public Builder withDefaultValues() { + return setStaticImageMode(false).setModelSelection(0).setMinDetectionConfidence(0.5f); + } + + public abstract Builder setStaticImageMode(boolean value); + + public abstract Builder setModelSelection(int value); + + public abstract Builder setMinDetectionConfidence(float value); + + public abstract FaceDetectionOptions build(); + } +} diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java new file mode 100644 index 000000000..d665a95f6 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceDetectionResult.java @@ -0,0 +1,65 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.solutions.facedetection; + +import android.graphics.Bitmap; +import com.google.auto.value.AutoBuilder; +import com.google.common.collect.ImmutableList; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.TextureFrame; +import com.google.mediapipe.solutioncore.ImageSolutionResult; +import java.util.List; + +/** + * FaceDetectionResult contains the detected faces, and the input {@link Bitmap} or {@link + * TextureFrame}. If not in static image mode, the timestamp field will be set to the timestamp of + * the corresponding input image. + */ +public class FaceDetectionResult extends ImageSolutionResult { + private final ImmutableList multiFaceDetections; + + FaceDetectionResult( + ImmutableList multiFaceDetections, Packet imagePacket, long timestamp) { + this.multiFaceDetections = multiFaceDetections; + this.timestamp = timestamp; + this.imagePacket = imagePacket; + } + + // Collection of detected faces, where each face is represented as a detection proto message that + // contains a bounding box and 6 {@link FaceKeypoint}s. The bounding box is composed of xmin and + // width (both normalized to [0.0, 1.0] by the image width) and ymin and height (both normalized + // to [0.0, 1.0] by the image height). Each keypoint is composed of x and y, which are normalized + // to [0.0, 1.0] by the image width and height respectively. 
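Putting the new Face Detection classes together, client code along the following lines should work. This is a sketch rather than official sample code: the `context` variable and `android.util.Log` usage are assumptions, and the listener interfaces are treated as single-method callbacks.

```
FaceDetectionOptions options =
    FaceDetectionOptions.builder()
        .setStaticImageMode(false)
        .setModelSelection(0)             // 0: short-range model, 1: full-range model
        .setMinDetectionConfidence(0.5f)
        .build();
FaceDetection faceDetection = new FaceDetection(context, options);
faceDetection.setResultListener(
    result -> {
      if (!result.multiFaceDetections().isEmpty()) {
        // Read the nose tip of the most confident detected face.
        RelativeKeypoint noseTip =
            FaceDetection.getFaceKeypoint(result, 0, FaceKeypoint.NOSE_TIP);
      }
    });
faceDetection.setErrorListener(
    (message, e) -> Log.e("FaceDetection", message, e));
```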
+ public ImmutableList multiFaceDetections() { + return multiFaceDetections; + } + + public static Builder builder() { + return new AutoBuilder_FaceDetectionResult_Builder(); + } + + /** Builder for {@link FaceDetectionResult}. */ + @AutoBuilder + public abstract static class Builder { + abstract Builder setMultiFaceDetections(List value); + + abstract Builder setTimestamp(long value); + + abstract Builder setImagePacket(Packet value); + + abstract FaceDetectionResult build(); + } +} diff --git a/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceKeypoint.java b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceKeypoint.java new file mode 100644 index 000000000..9ac0ffa13 --- /dev/null +++ b/mediapipe/java/com/google/mediapipe/solutions/facedetection/FaceKeypoint.java @@ -0,0 +1,42 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.solutions.facedetection; + +import androidx.annotation.IntDef; + +/** The 6 face keypoints. */ +public final class FaceKeypoint { + public static final int NUM_KEY_POINTS = 6; + + public static final int RIGHT_EYE = 0; + public static final int LEFT_EYE = 1; + public static final int NOSE_TIP = 2; + public static final int MOUTH_CENTER = 3; + public static final int RIGHT_EAR_TRAGION = 4; + public static final int LEFT_EAR_TRAGION = 5; + + /** Represents a face keypoint type. */ + @IntDef({ + RIGHT_EYE, + LEFT_EYE, + NOSE_TIP, + MOUTH_CENTER, + RIGHT_EAR_TRAGION, + LEFT_EAR_TRAGION, + }) + public @interface FaceKeypointType {} + + private FaceKeypoint() {} +} diff --git a/mediapipe/java/com/google/mediapipe/solutions/facemesh/BUILD b/mediapipe/java/com/google/mediapipe/solutions/facemesh/BUILD index 4d843e897..2c0f4b0af 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/facemesh/BUILD +++ b/mediapipe/java/com/google/mediapipe/solutions/facemesh/BUILD @@ -25,6 +25,7 @@ android_library( assets = [ "//mediapipe/modules/face_detection:face_detection_short_range.tflite", "//mediapipe/modules/face_landmark:face_landmark.tflite", + "//mediapipe/modules/face_landmark:face_landmark_with_attention.tflite", "//mediapipe/modules/face_landmark:face_landmark_front_cpu_image.binarypb", "//mediapipe/modules/face_landmark:face_landmark_front_gpu_image.binarypb", ], diff --git a/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMesh.java b/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMesh.java index f98be97bc..0e26df18f 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMesh.java +++ b/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMesh.java @@ -29,41 +29,46 @@ import java.util.Map; import javax.annotation.Nullable; /** - * MediaPipe FaceMesh Solution API. + * MediaPipe Face Mesh Solution API. * - *
<p>
MediaPipe FaceMesh processes a {@link TextureFrame} or a {@link Bitmap} and returns the face + *
<p>
MediaPipe Face Mesh processes a {@link TextureFrame} or a {@link Bitmap} and returns the face * landmarks of each detected face. Please refer to * https://solutions.mediapipe.dev/face_mesh#android-solution-api for usage examples. */ public class FaceMesh extends ImageSolutionBase { private static final String TAG = "FaceMesh"; + public static final int FACEMESH_NUM_LANDMARKS = 468; + public static final int FACEMESH_NUM_LANDMARKS_WITH_IRISES = 478; + private static final String NUM_FACES = "num_faces"; + private static final String WITH_ATTENTION = "with_attention"; + private static final String USE_PREV_LANDMARKS = "use_prev_landmarks"; private static final String GPU_GRAPH_NAME = "face_landmark_front_gpu_image.binarypb"; private static final String CPU_GRAPH_NAME = "face_landmark_front_cpu_image.binarypb"; private static final String IMAGE_INPUT_STREAM = "image"; private static final ImmutableList OUTPUT_STREAMS = - ImmutableList.of("multi_face_landmarks", "image"); + ImmutableList.of("multi_face_landmarks", "throttled_image"); private static final int LANDMARKS_INDEX = 0; private static final int INPUT_IMAGE_INDEX = 1; - private final OutputHandler graphOutputHandler; + private final OutputHandler outputHandler; /** - * Initializes MediaPipe FaceMesh solution. + * Initializes MediaPipe Face Mesh solution. * * @param context an Android {@link Context}. * @param options the configuration options defined in {@link FaceMeshOptions}. */ public FaceMesh(Context context, FaceMeshOptions options) { - graphOutputHandler = new OutputHandler<>(); - graphOutputHandler.setOutputConverter( + outputHandler = new OutputHandler<>(); + outputHandler.setOutputConverter( packets -> { FaceMeshResult.Builder faceMeshResultBuilder = FaceMeshResult.builder(); try { faceMeshResultBuilder.setMultiFaceLandmarks( getProtoVector(packets.get(LANDMARKS_INDEX), NormalizedLandmarkList.parser())); } catch (MediaPipeException e) { - throwException("Error occurs when getting MediaPipe facemesh landmarks. ", e); + reportError("Error occurs when getting MediaPipe facemesh landmarks.", e); } return faceMeshResultBuilder .setImagePacket(packets.get(INPUT_IMAGE_INDEX)) @@ -77,31 +82,33 @@ public class FaceMesh extends ImageSolutionBase { .setBinaryGraphPath(options.runOnGpu() ? GPU_GRAPH_NAME : CPU_GRAPH_NAME) .setImageInputStreamName(IMAGE_INPUT_STREAM) .setOutputStreamNames(OUTPUT_STREAMS) - .setStaticImageMode(options.mode() == FaceMeshOptions.STATIC_IMAGE_MODE) + .setStaticImageMode(options.staticImageMode()) .build(); - initialize(context, solutionInfo, graphOutputHandler); + initialize(context, solutionInfo, outputHandler); Map inputSidePackets = new HashMap<>(); inputSidePackets.put(NUM_FACES, packetCreator.createInt32(options.maxNumFaces())); + inputSidePackets.put(WITH_ATTENTION, packetCreator.createBool(options.refineLandmarks())); + inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode())); start(inputSidePackets); } /** - * Sets a callback to be invoked when the FaceMeshResults become available. + * Sets a callback to be invoked when a {@link FaceMeshResult} becomes available. * * @param listener the {@link ResultListener} callback. */ public void setResultListener(ResultListener listener) { - this.graphOutputHandler.setResultListener(listener); + this.outputHandler.setResultListener(listener); } /** - * Sets a callback to be invoked when the FaceMesh solution throws errors. + * Sets a callback to be invoked when the Face Mesh solution throws errors. 
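Because `refineLandmarks` now drives the `with_attention` side packet, enabling it should yield `FACEMESH_NUM_LANDMARKS_WITH_IRISES` (478) landmarks per face instead of 468. A sketch, assuming the usual `FaceMeshResult#multiFaceLandmarks()` accessor and a `context` variable:

```
FaceMeshOptions meshOptions =
    FaceMeshOptions.builder()
        .setStaticImageMode(false)
        .setMaxNumFaces(1)
        .setRefineLandmarks(true)   // enables the attention model and iris landmarks
        .setRunOnGpu(true)
        .build();
FaceMesh faceMesh = new FaceMesh(context, meshOptions);
faceMesh.setResultListener(
    result -> {
      for (NormalizedLandmarkList face : result.multiFaceLandmarks()) {
        // Expect 478 landmarks with refineLandmarks enabled, 468 otherwise.
        int numLandmarks = face.getLandmarkCount();
      }
    });
```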
* * @param listener the {@link ErrorListener} callback. */ public void setErrorListener(@Nullable ErrorListener listener) { - this.graphOutputHandler.setErrorListener(listener); + this.outputHandler.setErrorListener(listener); this.errorListener = listener; } } diff --git a/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshConnections.java b/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshConnections.java index 0c0d3e7b8..786ce6adf 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshConnections.java +++ b/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshConnections.java @@ -94,7 +94,7 @@ public final class FaceMeshConnections { Connection.create(384, 398), Connection.create(398, 362)); - public static final ImmutableSet FACEMESH_LEFT_EYEBR0W = + public static final ImmutableSet FACEMESH_LEFT_EYEBROW = ImmutableSet.of( Connection.create(276, 283), Connection.create(283, 282), @@ -105,6 +105,13 @@ public final class FaceMeshConnections { Connection.create(334, 296), Connection.create(296, 336)); + public static final ImmutableSet FACEMESH_LEFT_IRIS = + ImmutableSet.of( + Connection.create(474, 475), + Connection.create(475, 476), + Connection.create(476, 477), + Connection.create(477, 474)); + public static final ImmutableSet FACEMESH_RIGHT_EYE = ImmutableSet.of( Connection.create(33, 7), @@ -123,6 +130,7 @@ public final class FaceMeshConnections { Connection.create(158, 157), Connection.create(157, 173), Connection.create(173, 133)); + public static final ImmutableSet FACEMESH_RIGHT_EYEBROW = ImmutableSet.of( Connection.create(46, 53), @@ -134,6 +142,13 @@ public final class FaceMeshConnections { Connection.create(105, 66), Connection.create(66, 107)); + public static final ImmutableSet FACEMESH_RIGHT_IRIS = + ImmutableSet.of( + Connection.create(469, 470), + Connection.create(470, 471), + Connection.create(471, 472), + Connection.create(472, 469)); + public static final ImmutableSet FACEMESH_FACE_OVAL = ImmutableSet.of( Connection.create(10, 338), @@ -177,7 +192,7 @@ public final class FaceMeshConnections { ImmutableSet.builder() .addAll(FACEMESH_LIPS) .addAll(FACEMESH_LEFT_EYE) - .addAll(FACEMESH_LEFT_EYEBR0W) + .addAll(FACEMESH_LEFT_EYEBROW) .addAll(FACEMESH_RIGHT_EYE) .addAll(FACEMESH_RIGHT_EYEBROW) .addAll(FACEMESH_FACE_OVAL) diff --git a/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshOptions.java b/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshOptions.java index b735a9efd..0600a597f 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshOptions.java +++ b/mediapipe/java/com/google/mediapipe/solutions/facemesh/FaceMeshOptions.java @@ -14,18 +14,22 @@ package com.google.mediapipe.solutions.facemesh; -import androidx.annotation.IntDef; import com.google.auto.value.AutoValue; /** * MediaPipe FaceMesh solution-specific options. * - *
<p>
mode: Whether to treat the input images as a batch of static and possibly unrelated images, or - * a video stream. See details in https://solutions.mediapipe.dev/face_mesh#static_image_mode. + *
<p>
staticImageMode: Whether to treat the input images as a batch of static and possibly unrelated + * images, or a video stream. Default to false. See details in + * https://solutions.mediapipe.dev/face_mesh#static_image_mode. * *
<p>
maxNumFaces: Maximum number of faces to detect. See details in * https://solutions.mediapipe.dev/face_mesh#max_num_faces. * + *
<p>
refineLandmarks: Whether to further refine the landmark coordinates around the eyes, lips and + * face oval, and output additional landmarks around the irises. Default to False. See details in + * https://solutions.mediapipe.dev/face_mesh#refine_landmark. + * *
<p>
minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for face detection to be * considered successful. See details in * https://solutions.mediapipe.dev/face_mesh#min_detection_confidence. @@ -39,19 +43,7 @@ import com.google.auto.value.AutoValue; @AutoValue public abstract class FaceMeshOptions { - // TODO: Switch to use boolean variable. - public static final int STREAMING_MODE = 1; - public static final int STATIC_IMAGE_MODE = 2; - - /** - * Indicates whether to treat the input images as a batch of static and possibly unrelated images, - * or a video stream. - */ - @IntDef({STREAMING_MODE, STATIC_IMAGE_MODE}) - public @interface Mode {} - - @Mode - public abstract int mode(); + public abstract boolean staticImageMode(); public abstract int maxNumFaces(); @@ -59,6 +51,8 @@ public abstract class FaceMeshOptions { public abstract float minTrackingConfidence(); + public abstract boolean refineLandmarks(); + public abstract boolean runOnGpu(); public static Builder builder() { @@ -69,13 +63,15 @@ public abstract class FaceMeshOptions { @AutoValue.Builder public abstract static class Builder { public Builder withDefaultValues() { - return setMaxNumFaces(1) + return setStaticImageMode(false) + .setMaxNumFaces(1) .setMinDetectionConfidence(0.5f) .setMinTrackingConfidence(0.5f) + .setRefineLandmarks(false) .setRunOnGpu(true); } - public abstract Builder setMode(int value); + public abstract Builder setStaticImageMode(boolean value); public abstract Builder setMaxNumFaces(int value); @@ -83,6 +79,8 @@ public abstract class FaceMeshOptions { public abstract Builder setMinTrackingConfidence(float value); + public abstract Builder setRefineLandmarks(boolean value); + public abstract Builder setRunOnGpu(boolean value); public abstract FaceMeshOptions build(); diff --git a/mediapipe/java/com/google/mediapipe/solutions/hands/Hands.java b/mediapipe/java/com/google/mediapipe/solutions/hands/Hands.java index fdf8f67f3..d3fe548cf 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/hands/Hands.java +++ b/mediapipe/java/com/google/mediapipe/solutions/hands/Hands.java @@ -79,15 +79,16 @@ public class Hands extends ImageSolutionBase { Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP)); private static final String NUM_HANDS = "num_hands"; + private static final String USE_PREV_LANDMARKS = "use_prev_landmarks"; private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb"; private static final String CPU_GRAPH_NAME = "hand_landmark_tracking_cpu_image.binarypb"; private static final String IMAGE_INPUT_STREAM = "image"; private static final ImmutableList OUTPUT_STREAMS = - ImmutableList.of("multi_hand_landmarks", "multi_handedness", "image"); + ImmutableList.of("multi_hand_landmarks", "multi_handedness", "throttled_image"); private static final int LANDMARKS_INDEX = 0; private static final int HANDEDNESS_INDEX = 1; private static final int INPUT_IMAGE_INDEX = 2; - private final OutputHandler graphOutputHandler; + private final OutputHandler outputHandler; /** * Initializes MediaPipe Hands solution. @@ -96,21 +97,21 @@ public class Hands extends ImageSolutionBase { * @param options the configuration options defined in {@link HandsOptions}. 
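The Hands API follows the same pattern now that the `STREAMING_MODE`/`STATIC_IMAGE_MODE` IntDef is replaced by a boolean (see the `HandsOptions` hunk further below). A minimal usage sketch with an assumed `context` and logging; note that after this patch errors are rethrown only when no `ErrorListener` is installed:

```
HandsOptions handsOptions =
    HandsOptions.builder()
        .setStaticImageMode(false)      // replaces setMode(HandsOptions.STREAMING_MODE)
        .setMaxNumHands(2)
        .setMinDetectionConfidence(0.5f)
        .setRunOnGpu(true)
        .build();
Hands hands = new Hands(context, handsOptions);
hands.setResultListener(result -> { /* consume the HandsResult */ });
// Registering an error listener keeps graph errors from being rethrown on the calling thread.
hands.setErrorListener(
    (message, e) -> Log.e("Hands", message, e));
```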
*/ public Hands(Context context, HandsOptions options) { - graphOutputHandler = new OutputHandler<>(); - graphOutputHandler.setOutputConverter( + outputHandler = new OutputHandler<>(); + outputHandler.setOutputConverter( packets -> { HandsResult.Builder handsResultBuilder = HandsResult.builder(); try { handsResultBuilder.setMultiHandLandmarks( getProtoVector(packets.get(LANDMARKS_INDEX), NormalizedLandmarkList.parser())); } catch (MediaPipeException e) { - throwException("Error occurs when getting MediaPipe hand landmarks. ", e); + reportError("Error occurs while getting MediaPipe hand landmarks.", e); } try { handsResultBuilder.setMultiHandedness( getProtoVector(packets.get(HANDEDNESS_INDEX), Classification.parser())); } catch (MediaPipeException e) { - throwException("Error occurs when getting MediaPipe handedness data. ", e); + reportError("Error occurs while getting MediaPipe handedness data.", e); } return handsResultBuilder .setImagePacket(packets.get(INPUT_IMAGE_INDEX)) @@ -124,22 +125,23 @@ public class Hands extends ImageSolutionBase { .setBinaryGraphPath(options.runOnGpu() ? GPU_GRAPH_NAME : CPU_GRAPH_NAME) .setImageInputStreamName(IMAGE_INPUT_STREAM) .setOutputStreamNames(OUTPUT_STREAMS) - .setStaticImageMode(options.mode() == HandsOptions.STATIC_IMAGE_MODE) + .setStaticImageMode(options.staticImageMode()) .build(); - initialize(context, solutionInfo, graphOutputHandler); + initialize(context, solutionInfo, outputHandler); Map inputSidePackets = new HashMap<>(); inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands())); + inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode())); start(inputSidePackets); } /** - * Sets a callback to be invoked when the HandsResults become available. + * Sets a callback to be invoked when a {@link HandsResult} becomes available. * * @param listener the {@link ResultListener} callback. */ public void setResultListener(ResultListener listener) { - this.graphOutputHandler.setResultListener(listener); + this.outputHandler.setResultListener(listener); } /** @@ -148,7 +150,7 @@ public class Hands extends ImageSolutionBase { * @param listener the {@link ErrorListener} callback. */ public void setErrorListener(@Nullable ErrorListener listener) { - this.graphOutputHandler.setErrorListener(listener); + this.outputHandler.setErrorListener(listener); this.errorListener = listener; } diff --git a/mediapipe/java/com/google/mediapipe/solutions/hands/HandsOptions.java b/mediapipe/java/com/google/mediapipe/solutions/hands/HandsOptions.java index 13ff6e90c..199c4b053 100644 --- a/mediapipe/java/com/google/mediapipe/solutions/hands/HandsOptions.java +++ b/mediapipe/java/com/google/mediapipe/solutions/hands/HandsOptions.java @@ -14,14 +14,14 @@ package com.google.mediapipe.solutions.hands; -import androidx.annotation.IntDef; import com.google.auto.value.AutoValue; /** * MediaPipe Hands solution-specific options. * - *
<p>
mode: Whether to treat the input images as a batch of static and possibly unrelated images, or - * a video stream. See details in https://solutions.mediapipe.dev/hands#static_image_mode. + *
<p>
staticImageMode: Whether to treat the input images as a batch of static and possibly unrelated + * images, or a video stream. Default to false. See details in + * https://solutions.mediapipe.dev/hands#static_image_mode. * *
<p>
maxNumHands: Maximum number of hands to detect. See details in * https://solutions.mediapipe.dev/hands#max_num_hands. @@ -39,19 +39,7 @@ import com.google.auto.value.AutoValue; @AutoValue public abstract class HandsOptions { - // TODO: Switch to use boolean variable. - public static final int STREAMING_MODE = 1; - public static final int STATIC_IMAGE_MODE = 2; - - /** - * Indicates whether to treat the input images as a batch of static and possibly unrelated images, - * or a video stream. - */ - @IntDef({STREAMING_MODE, STATIC_IMAGE_MODE}) - public @interface Mode {} - - @Mode - public abstract int mode(); + public abstract boolean staticImageMode(); public abstract int maxNumHands(); @@ -69,13 +57,14 @@ public abstract class HandsOptions { @AutoValue.Builder public abstract static class Builder { public Builder withDefaultValues() { - return setMaxNumHands(2) + return setStaticImageMode(false) + .setMaxNumHands(2) .setMinDetectionConfidence(0.5f) .setMinTrackingConfidence(0.5f) .setRunOnGpu(true); } - public abstract Builder setMode(int value); + public abstract Builder setStaticImageMode(boolean value); public abstract Builder setMaxNumHands(int value); diff --git a/mediapipe/modules/face_detection/BUILD b/mediapipe/modules/face_detection/BUILD index 839418c77..b1cddeb6f 100644 --- a/mediapipe/modules/face_detection/BUILD +++ b/mediapipe/modules/face_detection/BUILD @@ -117,6 +117,30 @@ mediapipe_simple_subgraph( ], ) +mediapipe_simple_subgraph( + name = "face_detection_short_range_image", + graph = "face_detection_short_range_image.pbtxt", + register_as = "FaceDetectionShortRangeImage", + deps = [ + ":face_detection_short_range_common", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:inference_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "face_detection_full_range_image", + graph = "face_detection_full_range_image.pbtxt", + register_as = "FaceDetectionFullRangeImage", + deps = [ + ":face_detection_full_range_common", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:inference_calculator", + ], +) + exports_files( srcs = [ "face_detection_full_range.tflite", diff --git a/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt new file mode 100644 index 000000000..4e0bc0b4d --- /dev/null +++ b/mediapipe/modules/face_detection/face_detection_full_range_image.pbtxt @@ -0,0 +1,86 @@ +# MediaPipe graph to detect faces. (GPU/CPU input, and inference is executed on +# GPU.) +# +# It is required that "face_detection_full_range_sparse.tflite" is available at +# "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" +# path during execution. + +type: "FaceDetectionFullRangeImage" + +# Image. (Image) +input_stream: "IMAGE:image" + +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" +# Detected faces. (std::vector) +# NOTE: there will not be an output packet in the DETECTIONS stream for this +# particular timestamp if none of faces detected. However, the MediaPipe +# framework will internally inform the downstream calculators of the absence of +# this packet so that they don't wait for it unnecessarily. 
+output_stream: "DETECTIONS:detections" + +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + +# Transforms the input image into a 128x128 tensor while keeping the aspect +# ratio (what is expected by the corresponding face detection model), resulting +# in potential letterboxing in the transformed image. +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:throttled_image" + output_stream: "TENSORS:input_tensors" + output_stream: "MATRIX:transform_matrix" + options: { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: 192 + output_tensor_height: 192 + keep_aspect_ratio: true + output_tensor_float_range { + min: -1.0 + max: 1.0 + } + border_mode: BORDER_ZERO + gpu_origin: CONVENTIONAL + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +# TODO: Use GraphOptions to modify the delegate field to be +# `delegate { xnnpack {} }` for the CPU only use cases. +node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensors" + output_stream: "TENSORS:detection_tensors" + options: { + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" + # + delegate: { gpu { use_advanced_gpu_api: true } } + } + } +} + +# Performs tensor post processing to generate face detections. +node { + calculator: "FaceDetectionFullRangeCommon" + input_stream: "TENSORS:detection_tensors" + input_stream: "MATRIX:transform_matrix" + output_stream: "DETECTIONS:detections" +} diff --git a/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt b/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt new file mode 100644 index 000000000..a2590418b --- /dev/null +++ b/mediapipe/modules/face_detection/face_detection_short_range_image.pbtxt @@ -0,0 +1,94 @@ +# MediaPipe graph to detect faces. (GPU/CPU input, and inference is executed on +# GPU.) +# +# It is required that "face_detection_short_range.tflite" is available at +# "mediapipe/modules/face_detection/face_detection_short_range.tflite" +# path during execution. +# +# EXAMPLE: +# node { +# calculator: "FaceDetectionShortRangeCpu" +# input_stream: "IMAGE:image" +# output_stream: "DETECTIONS:face_detections" +# } + +type: "FaceDetectionShortRangeCpu" + +# Image. (Image) +input_stream: "IMAGE:image" + +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" +# Detected faces. (std::vector) +# NOTE: there will not be an output packet in the DETECTIONS stream for this +# particular timestamp if none of faces detected. However, the MediaPipe +# framework will internally inform the downstream calculators of the absence of +# this packet so that they don't wait for it unnecessarily. 
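This timestamp-bound-only behavior for frames without detections is what the new `OutputHandler#setHandleTimestampBoundChanges` flag in solutioncore is designed for. A hypothetical opt-in sketch; the patch itself does not show a solution enabling the flag, and the result type here is only for illustration:

```
OutputHandler<FaceDetectionResult> handler = new OutputHandler<>();
// Request callbacks (with empty packets) when only the timestamp bound advances,
// e.g. for frames in which no face was detected.
handler.setHandleTimestampBoundChanges(true);
// SolutionBase#initialize forwards handler.handleTimestampBoundChanges() as
// observeTimestampBounds when registering the multi-stream callback.
```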
+output_stream: "DETECTIONS:detections" + +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:detections" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + +# Transforms the input image into a 128x128 tensor while keeping the aspect +# ratio (what is expected by the corresponding face detection model), resulting +# in potential letterboxing in the transformed image. +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:throttled_image" + output_stream: "TENSORS:input_tensors" + output_stream: "MATRIX:transform_matrix" + options: { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: 128 + output_tensor_height: 128 + keep_aspect_ratio: true + output_tensor_float_range { + min: -1.0 + max: 1.0 + } + border_mode: BORDER_ZERO + gpu_origin: CONVENTIONAL + } + } +} + +# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a +# vector of tensors representing, for instance, detection boxes/keypoints and +# scores. +# TODO: Use GraphOptions to modify the delegate field to be +# `delegate { xnnpack {} }` for the CPU only use cases. +node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensors" + output_stream: "TENSORS:detection_tensors" + options: { + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/modules/face_detection/face_detection_short_range.tflite" + + # + delegate: { gpu { use_advanced_gpu_api: true } } + } + } +} + +# Performs tensor post processing to generate face detections. +node { + calculator: "FaceDetectionShortRangeCommon" + input_stream: "TENSORS:detection_tensors" + input_stream: "MATRIX:transform_matrix" + output_stream: "DETECTIONS:detections" +} diff --git a/mediapipe/modules/face_landmark/BUILD b/mediapipe/modules/face_landmark/BUILD index 2560bad9b..f155e46d5 100644 --- a/mediapipe/modules/face_landmark/BUILD +++ b/mediapipe/modules/face_landmark/BUILD @@ -26,14 +26,19 @@ mediapipe_simple_subgraph( graph = "face_landmark_cpu.pbtxt", register_as = "FaceLandmarkCpu", deps = [ + ":face_landmarks_model_loader", + ":tensors_to_face_landmarks", + ":tensors_to_face_landmarks_with_attention", "//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:tensors_to_floats_calculator", "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", + "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", "//mediapipe/calculators/util:landmark_projection_calculator", "//mediapipe/calculators/util:thresholding_calculator", + "//mediapipe/framework/tool:switch_container", ], ) @@ -42,14 +47,19 @@ mediapipe_simple_subgraph( graph = "face_landmark_gpu.pbtxt", register_as = "FaceLandmarkGpu", deps = [ + ":face_landmarks_model_loader", + ":tensors_to_face_landmarks", + ":tensors_to_face_landmarks_with_attention", "//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:tensors_to_floats_calculator", "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", + 
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", "//mediapipe/calculators/util:landmark_projection_calculator", "//mediapipe/calculators/util:thresholding_calculator", + "//mediapipe/framework/tool:switch_container", ], ) @@ -101,6 +111,7 @@ mediapipe_simple_subgraph( register_as = "FaceLandmarkFrontCpuImage", deps = [ ":face_landmark_front_cpu", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:from_image_calculator", ], @@ -112,6 +123,7 @@ mediapipe_simple_subgraph( register_as = "FaceLandmarkFrontGpuImage", deps = [ ":face_landmark_front_gpu", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:from_image_calculator", ], @@ -120,6 +132,7 @@ mediapipe_simple_subgraph( exports_files( srcs = [ "face_landmark.tflite", + "face_landmark_with_attention.tflite", ], ) @@ -143,3 +156,35 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/util:rect_transformation_calculator", ], ) + +mediapipe_simple_subgraph( + name = "face_landmarks_model_loader", + graph = "face_landmarks_model_loader.pbtxt", + register_as = "FaceLandmarksModelLoader", + deps = [ + "//mediapipe/calculators/core:constant_side_packet_calculator", + "//mediapipe/calculators/tflite:tflite_model_calculator", + "//mediapipe/calculators/util:local_file_contents_calculator", + "//mediapipe/framework/tool:switch_container", + ], +) + +mediapipe_simple_subgraph( + name = "tensors_to_face_landmarks", + graph = "tensors_to_face_landmarks.pbtxt", + register_as = "TensorsToFaceLandmarks", + deps = [ + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "tensors_to_face_landmarks_with_attention", + graph = "tensors_to_face_landmarks_with_attention.pbtxt", + register_as = "TensorsToFaceLandmarksWithAttention", + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/tensor:tensors_to_landmarks_calculator", + "//mediapipe/calculators/util:landmarks_refinement_calculator", + ], +) diff --git a/mediapipe/modules/face_landmark/face_landmark.tflite b/mediapipe/modules/face_landmark/face_landmark.tflite index e30e514e1..573285df4 100755 Binary files a/mediapipe/modules/face_landmark/face_landmark.tflite and b/mediapipe/modules/face_landmark/face_landmark.tflite differ diff --git a/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt index a94a8c803..4604fc753 100644 --- a/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_cpu.pbtxt @@ -3,13 +3,18 @@ # # It is required that "face_landmark.tflite" is available at # "mediapipe/modules/face_landmark/face_landmark.tflite" -# path during execution. +# path during execution if `with_attention` is not set or set to `false`. +# +# It is required that "face_landmark_with_attention.tflite" is available at +# "mediapipe/modules/face_landmark/face_landmark_with_attention.tflite" +# path during execution if `with_attention` is set to `true`. # # EXAMPLE: # node { # calculator: "FaceLandmarkCpu" # input_stream: "IMAGE:image" # input_stream: "ROI:face_roi" +# input_side_packet: "WITH_ATTENTION:with_attention" # output_stream: "LANDMARKS:face_landmarks" # } @@ -20,8 +25,17 @@ input_stream: "IMAGE:image" # ROI (region of interest) within the given image where a face is located. 
# (NormalizedRect) input_stream: "ROI:roi" +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" -# 468 face landmarks within the given ROI. (NormalizedLandmarkList) +# 468 or 478 facial landmarks within the given ROI. (NormalizedLandmarkList) +# +# Number of landmarks depends on the WITH_ATTENTION flag. If it's `true` - then +# there will be 478 landmarks with refined lips, eyes and irises (10 extra +# landmarks are for irises), otherwise 468 non-refined landmarks are returned. +# # NOTE: if a face is not present within the given ROI, for this particular # timestamp there will not be an output packet in the LANDMARKS stream. However, # the MediaPipe framework will internally inform the downstream calculators of @@ -46,31 +60,63 @@ node: { } } +# Loads the face landmarks TF Lite model. +node { + calculator: "FaceLandmarksModelLoader" + input_side_packet: "WITH_ATTENTION:with_attention" + output_side_packet: "MODEL:model" +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "op_resolver" +} + # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a # vector of tensors representing, for instance, detection boxes/keypoints and # scores. node { calculator: "InferenceCalculator" input_stream: "TENSORS:input_tensors" + input_side_packet: "MODEL:model" + input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" output_stream: "TENSORS:output_tensors" options: { [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_landmark/face_landmark.tflite" delegate { xnnpack {} } } } } -# Splits a vector of tensors into multiple vectors. +# Splits a vector of tensors into landmark tensors and face flag tensor. node { - calculator: "SplitTensorVectorCalculator" + calculator: "SwitchContainer" + input_side_packet: "ENABLE:with_attention" input_stream: "output_tensors" output_stream: "landmark_tensors" output_stream: "face_flag_tensor" options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 1 } - ranges: { begin: 1 end: 2 } + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "SplitTensorVectorCalculator" + options: { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + } + } + } + contained_node: { + calculator: "SplitTensorVectorCalculator" + options: { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 6 } + ranges: { begin: 6 end: 7 } + } + } + } } } } @@ -112,14 +158,18 @@ node { # Decodes the landmark tensors into a vector of landmarks, where the landmark # coordinates are normalized by the size of the input image to the model. 
node { - calculator: "TensorsToLandmarksCalculator" + calculator: "SwitchContainer" + input_side_packet: "ENABLE:with_attention" input_stream: "TENSORS:ensured_landmark_tensors" - output_stream: "NORM_LANDMARKS:landmarks" + output_stream: "LANDMARKS:landmarks" options: { - [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { - num_landmarks: 468 - input_image_width: 192 - input_image_height: 192 + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "TensorsToFaceLandmarks" + } + contained_node: { + calculator: "TensorsToFaceLandmarksWithAttention" + } } } } diff --git a/mediapipe/modules/face_landmark/face_landmark_front_cpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_front_cpu.pbtxt index f60ca3df7..70a57b0ef 100644 --- a/mediapipe/modules/face_landmark/face_landmark_front_cpu.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_front_cpu.pbtxt @@ -8,13 +8,19 @@ # # It is required that "face_landmark.tflite" is available at # "mediapipe/modules/face_landmark/face_landmark.tflite" -# path during execution. +# path during execution if `with_attention` is not set or set to `false`. +# +# It is required that "face_landmark_with_attention.tflite" is available at +# "mediapipe/modules/face_landmark/face_landmark_with_attention.tflite" +# path during execution if `with_attention` is set to `true`. # # EXAMPLE: # node { # calculator: "FaceLandmarkFrontCpu" # input_stream: "IMAGE:image" # input_side_packet: "NUM_FACES:num_faces" +# input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" +# input_side_packet: "WITH_ATTENTION:with_attention" # output_stream: "LANDMARKS:multi_face_landmarks" # } @@ -26,6 +32,15 @@ input_stream: "IMAGE:image" # Max number of faces to detect/track. (int) input_side_packet: "NUM_FACES:num_faces" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" + # Collection of detected/predicted faces, each represented as a list of 468 face # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -44,23 +59,19 @@ output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" # (std::vector) output_stream: "ROIS_FROM_DETECTIONS:face_rects_from_detections" -# Defines whether landmarks on the previous image should be used to help -# localize landmarks on the current image. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. 
node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_face_rects_from_landmarks" output_stream: "gated_prev_face_rects_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Determines if an input vector of NormalizedRect has a size greater than or @@ -186,6 +197,7 @@ node { calculator: "FaceLandmarkCpu" input_stream: "IMAGE:landmarks_loop_image" input_stream: "ROI:face_rect" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:face_landmarks" } diff --git a/mediapipe/modules/face_landmark/face_landmark_front_cpu_image.pbtxt b/mediapipe/modules/face_landmark/face_landmark_front_cpu_image.pbtxt index 1f5bb3df4..7d0c46a75 100644 --- a/mediapipe/modules/face_landmark/face_landmark_front_cpu_image.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_front_cpu_image.pbtxt @@ -8,8 +8,17 @@ input_stream: "IMAGE:image" # Max number of faces to detect/track. (int) input_side_packet: "NUM_FACES:num_faces" -# The original input image. (Image) -output_stream: "IMAGE:image" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" + +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" # Collection of detected/predicted faces, each represented as a list of 468 face # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -28,10 +37,27 @@ output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" # (std::vector) output_stream: "ROIS_FROM_DETECTIONS:face_rects_from_detections" +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:multi_face_landmarks" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + # Converts Image to ImageFrame for FaceLandmarkFrontCpu to consume. node { calculator: "FromImageCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:throttled_image" output_stream: "IMAGE_CPU:raw_image_frame" output_stream: "SOURCE_ON_GPU:is_gpu_image" } @@ -52,6 +78,8 @@ node { calculator: "FaceLandmarkFrontCpu" input_stream: "IMAGE:image_frame" input_side_packet: "NUM_FACES:num_faces" + input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:multi_face_landmarks" output_stream: "DETECTIONS:face_detections" output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" diff --git a/mediapipe/modules/face_landmark/face_landmark_front_gpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_front_gpu.pbtxt index fe93d1955..fd8956518 100644 --- a/mediapipe/modules/face_landmark/face_landmark_front_gpu.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_front_gpu.pbtxt @@ -8,13 +8,19 @@ # # It is required that "face_landmark.tflite" is available at # "mediapipe/modules/face_landmark/face_landmark.tflite" -# path during execution. +# path during execution if `with_attention` is not set or set to `false`. 
+# +# It is required that "face_landmark_with_attention.tflite" is available at +# "mediapipe/modules/face_landmark/face_landmark_with_attention.tflite" +# path during execution if `with_attention` is set to `true`. # # EXAMPLE: # node { # calculator: "FaceLandmarkFrontGpu" # input_stream: "IMAGE:image" # input_side_packet: "NUM_FACES:num_faces" +# input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" +# input_side_packet: "WITH_ATTENTION:with_attention" # output_stream: "LANDMARKS:multi_face_landmarks" # } @@ -26,6 +32,15 @@ input_stream: "IMAGE:image" # Max number of faces to detect/track. (int) input_side_packet: "NUM_FACES:num_faces" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" + # Collection of detected/predicted faces, each represented as a list of 468 face # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -44,23 +59,19 @@ output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" # (std::vector) output_stream: "ROIS_FROM_DETECTIONS:face_rects_from_detections" -# Defines whether landmarks on the previous image should be used to help -# localize landmarks on the current image. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_face_rects_from_landmarks" output_stream: "gated_prev_face_rects_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Determines if an input vector of NormalizedRect has a size greater than or @@ -186,6 +197,7 @@ node { calculator: "FaceLandmarkGpu" input_stream: "IMAGE:landmarks_loop_image" input_stream: "ROI:face_rect" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:face_landmarks" } diff --git a/mediapipe/modules/face_landmark/face_landmark_front_gpu_image.pbtxt b/mediapipe/modules/face_landmark/face_landmark_front_gpu_image.pbtxt index 4c937bea9..31da4b849 100644 --- a/mediapipe/modules/face_landmark/face_landmark_front_gpu_image.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_front_gpu_image.pbtxt @@ -8,8 +8,17 @@ input_stream: "IMAGE:image" # Max number of faces to detect/track. (int) input_side_packet: "NUM_FACES:num_faces" -# The original input image. (Image) -output_stream: "IMAGE:image" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" + +# The throttled input image. 
(Image) +output_stream: "IMAGE:throttled_image" # Collection of detected/predicted faces, each represented as a list of 468 face # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -28,10 +37,27 @@ output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" # (std::vector) output_stream: "ROIS_FROM_DETECTIONS:face_rects_from_detections" +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:multi_face_landmarks" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + # Converts Image to GpuBuffer for FaceLandmarkFrontGpu to consume. node { calculator: "FromImageCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:throttled_image" output_stream: "IMAGE_GPU:raw_gpu_buffer" output_stream: "SOURCE_ON_GPU:is_gpu_image" } @@ -52,6 +78,8 @@ node { calculator: "FaceLandmarkFrontGpu" input_stream: "IMAGE:gpu_buffer" input_side_packet: "NUM_FACES:num_faces" + input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + input_side_packet: "WITH_ATTENTION:with_attention" output_stream: "LANDMARKS:multi_face_landmarks" output_stream: "DETECTIONS:face_detections" output_stream: "ROIS_FROM_LANDMARKS:face_rects_from_landmarks" diff --git a/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt b/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt index 7d8c3bf7d..854ceaff6 100644 --- a/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt +++ b/mediapipe/modules/face_landmark/face_landmark_gpu.pbtxt @@ -3,13 +3,18 @@ # # It is required that "face_landmark.tflite" is available at # "mediapipe/modules/face_landmark/face_landmark.tflite" -# path during execution. +# path during execution if `with_attention` is not set or set to `false`. +# +# It is required that "face_landmark_with_attention.tflite" is available at +# "mediapipe/modules/face_landmark/face_landmark_with_attention.tflite" +# path during execution if `with_attention` is set to `true`. # # EXAMPLE: # node { # calculator: "FaceLandmarkGpu" # input_stream: "IMAGE:image" # input_stream: "ROI:face_roi" +# input_side_packet: "WITH_ATTENTION:with_attention" # output_stream: "LANDMARKS:face_landmarks" # } @@ -20,8 +25,17 @@ input_stream: "IMAGE:image" # ROI (region of interest) within the given image where a face is located. # (NormalizedRect) input_stream: "ROI:roi" +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" -# 468 face landmarks within the given ROI. (NormalizedLandmarkList) +# 468 or 478 facial landmarks within the given ROI. (NormalizedLandmarkList) +# +# Number of landmarks depends on the WITH_ATTENTION flag. If it's `true` - then +# there will be 478 landmarks with refined lips, eyes and irises (10 extra +# landmarks are for irises), otherwise 468 non-refined landmarks are returned. +# # NOTE: if a face is not present within the given ROI, for this particular # timestamp there will not be an output packet in the LANDMARKS stream. However, # the MediaPipe framework will internally inform the downstream calculators of @@ -47,30 +61,63 @@ node: { } } +# Loads the face landmarks TF Lite model. 
+node { + calculator: "FaceLandmarksModelLoader" + input_side_packet: "WITH_ATTENTION:with_attention" + output_side_packet: "MODEL:model" +} + +# Generates a single side packet containing a TensorFlow Lite op resolver that +# supports custom ops needed by the model used in this graph. +node { + calculator: "TfLiteCustomOpResolverCalculator" + output_side_packet: "op_resolver" +} + # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a # vector of GPU tensors representing, for instance, detection boxes/keypoints # and scores. node { calculator: "InferenceCalculator" input_stream: "TENSORS:input_tensors" + input_side_packet: "MODEL:model" + input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" output_stream: "TENSORS:output_tensors" options: { [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/modules/face_landmark/face_landmark.tflite" + # Do not remove. Used for generation of XNNPACK/NNAPI graphs. } } } -# Splits a vector of tensors into multiple vectors. +# Splits a vector of tensors into landmark tensors and face flag tensor. node { - calculator: "SplitTensorVectorCalculator" + calculator: "SwitchContainer" + input_side_packet: "ENABLE:with_attention" input_stream: "output_tensors" output_stream: "landmark_tensors" output_stream: "face_flag_tensor" - options: { - [mediapipe.SplitVectorCalculatorOptions.ext] { - ranges: { begin: 0 end: 1 } - ranges: { begin: 1 end: 2 } + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "SplitTensorVectorCalculator" + options: { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + } + } + } + contained_node: { + calculator: "SplitTensorVectorCalculator" + options: { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 6 } + ranges: { begin: 6 end: 7 } + } + } + } } } } @@ -81,11 +128,11 @@ node { calculator: "TensorsToFloatsCalculator" input_stream: "TENSORS:face_flag_tensor" output_stream: "FLOAT:face_presence_score" - options { - [mediapipe.TensorsToFloatsCalculatorOptions.ext] { - activation: SIGMOID - } + options: { + [mediapipe.TensorsToFloatsCalculatorOptions.ext] { + activation: SIGMOID } + } } # Applies a threshold to the confidence score to determine whether a face is @@ -112,14 +159,18 @@ node { # Decodes the landmark tensors into a vector of landmarks, where the landmark # coordinates are normalized by the size of the input image to the model. 
node { - calculator: "TensorsToLandmarksCalculator" + calculator: "SwitchContainer" + input_side_packet: "ENABLE:with_attention" input_stream: "TENSORS:ensured_landmark_tensors" - output_stream: "NORM_LANDMARKS:landmarks" + output_stream: "LANDMARKS:landmarks" options: { - [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { - num_landmarks: 468 - input_image_width: 192 - input_image_height: 192 + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "TensorsToFaceLandmarks" + } + contained_node: { + calculator: "TensorsToFaceLandmarksWithAttention" + } } } } diff --git a/mediapipe/modules/face_landmark/face_landmark_with_attention.tflite b/mediapipe/modules/face_landmark/face_landmark_with_attention.tflite new file mode 100755 index 000000000..fe0a93a85 Binary files /dev/null and b/mediapipe/modules/face_landmark/face_landmark_with_attention.tflite differ diff --git a/mediapipe/modules/face_landmark/face_landmarks_model_loader.pbtxt b/mediapipe/modules/face_landmark/face_landmarks_model_loader.pbtxt new file mode 100644 index 000000000..ecac1a6b9 --- /dev/null +++ b/mediapipe/modules/face_landmark/face_landmarks_model_loader.pbtxt @@ -0,0 +1,58 @@ +# MediaPipe graph to load a selected face landmarks TF Lite model. + +type: "FaceLandmarksModelLoader" + +# Whether to run face mesh model with attention on lips and eyes. (bool) +# Attention provides more accuracy on lips and eye regions as well as iris +# landmarks. +input_side_packet: "WITH_ATTENTION:with_attention" + +# TF Lite model represented as a FlatBuffer. +# (std::unique_ptr>) +output_side_packet: "MODEL:model" + +# Determines path to the desired face landmark model file based on specification +# in the input side packet. +node { + calculator: "SwitchContainer" + input_side_packet: "ENABLE:with_attention" + output_side_packet: "PACKET:model_path" + options: { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "ConstantSidePacketCalculator" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { + string_value: "mediapipe/modules/face_landmark/face_landmark.tflite" + } + } + } + } + contained_node: { + calculator: "ConstantSidePacketCalculator" + options: { + [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { + packet { + string_value: "mediapipe/modules/face_landmark/face_landmark_with_attention.tflite" + } + } + } + } + } + } +} + +# Loads the file in the specified path into a blob. +node { + calculator: "LocalFileContentsCalculator" + input_side_packet: "FILE_PATH:model_path" + output_side_packet: "CONTENTS:model_blob" +} + +# Converts the input blob into a TF Lite model. +node { + calculator: "TfLiteModelCalculator" + input_side_packet: "MODEL_BLOB:model_blob" + output_side_packet: "MODEL:model" +} diff --git a/mediapipe/modules/face_landmark/tensors_to_face_landmarks.pbtxt b/mediapipe/modules/face_landmark/tensors_to_face_landmarks.pbtxt new file mode 100644 index 000000000..0adbdf38c --- /dev/null +++ b/mediapipe/modules/face_landmark/tensors_to_face_landmarks.pbtxt @@ -0,0 +1,24 @@ +# MediaPipe graph to transform single tensor into 468 facial landmarks. + +type: "TensorsToFaceLandmarks" + +# Vector with a single tensor that contains 468 landmarks. (std::vector) +input_stream: "TENSORS:tensors" + +# 468 facial landmarks (NormalizedLandmarkList) +output_stream: "LANDMARKS:landmarks" + +# Decodes the landmark tensors into a vector of lanmarks, where the landmark +# coordinates are normalized by the size of the input image to the model. 
+node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:tensors" + output_stream: "NORM_LANDMARKS:landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 468 + input_image_width: 192 + input_image_height: 192 + } + } +} diff --git a/mediapipe/modules/face_landmark/tensors_to_face_landmarks_with_attention.pbtxt b/mediapipe/modules/face_landmark/tensors_to_face_landmarks_with_attention.pbtxt new file mode 100644 index 000000000..4f9b994bc --- /dev/null +++ b/mediapipe/modules/face_landmark/tensors_to_face_landmarks_with_attention.pbtxt @@ -0,0 +1,299 @@ +# MediaPipe graph to transform model output tensors into 478 facial landmarks +# with refined lips, eyes and irises. + +type: "TensorsToFaceLandmarksWithAttention" + +# Vector with a six tensors to parse landmarks from. (std::vector) +# Landmark tensors order: +# - mesh_tensor +# - lips_tensor +# - left_eye_tensor +# - right_eye_tensor +# - left_iris_tensor +# - right_iris_tensor +input_stream: "TENSORS:tensors" + +# 478 facial landmarks (NormalizedLandmarkList) +output_stream: "LANDMARKS:landmarks" + +# Splits a vector of tensors into multiple vectors. +node { + calculator: "SplitTensorVectorCalculator" + input_stream: "tensors" + output_stream: "mesh_tensor" + output_stream: "lips_tensor" + output_stream: "left_eye_tensor" + output_stream: "right_eye_tensor" + output_stream: "left_iris_tensor" + output_stream: "right_iris_tensor" + options: { + [mediapipe.SplitVectorCalculatorOptions.ext] { + ranges: { begin: 0 end: 1 } + ranges: { begin: 1 end: 2 } + ranges: { begin: 2 end: 3 } + ranges: { begin: 3 end: 4 } + ranges: { begin: 4 end: 5 } + ranges: { begin: 5 end: 6 } + } + } +} + +# Decodes mesh landmarks tensor into a vector of normalized lanmarks. +node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:mesh_tensor" + output_stream: "NORM_LANDMARKS:mesh_landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 468 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Decodes lips landmarks tensor into a vector of normalized lanmarks. +node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:lips_tensor" + output_stream: "NORM_LANDMARKS:lips_landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 80 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Decodes left eye landmarks tensor into a vector of normalized lanmarks. +node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:left_eye_tensor" + output_stream: "NORM_LANDMARKS:left_eye_landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 71 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Decodes right eye landmarks tensor into a vector of normalized lanmarks. +node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:right_eye_tensor" + output_stream: "NORM_LANDMARKS:right_eye_landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 71 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Decodes left iris landmarks tensor into a vector of normalized lanmarks. 
+node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:left_iris_tensor" + output_stream: "NORM_LANDMARKS:left_iris_landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 5 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Decodes right iris landmarks tensor into a vector of normalized lanmarks. +node { + calculator: "TensorsToLandmarksCalculator" + input_stream: "TENSORS:right_iris_tensor" + output_stream: "NORM_LANDMARKS:right_iris_landmarks" + options: { + [mediapipe.TensorsToLandmarksCalculatorOptions.ext] { + num_landmarks: 5 + input_image_width: 192 + input_image_height: 192 + } + } +} + +# Refine mesh landmarks with lips, eyes and irises. +node { + calculator: "LandmarksRefinementCalculator" + input_stream: "LANDMARKS:0:mesh_landmarks" + input_stream: "LANDMARKS:1:lips_landmarks" + input_stream: "LANDMARKS:2:left_eye_landmarks" + input_stream: "LANDMARKS:3:right_eye_landmarks" + input_stream: "LANDMARKS:4:left_iris_landmarks" + input_stream: "LANDMARKS:5:right_iris_landmarks" + output_stream: "REFINED_LANDMARKS:landmarks" + options: { + [mediapipe.LandmarksRefinementCalculatorOptions.ext] { + # 0 - mesh + refinement: { + indexes_mapping: [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, + 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, + 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, + 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, + 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, + 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, + 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, + 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, + 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, + 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, + 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, + 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, + 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, + 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, + 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, + 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, + 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, + 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, + 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, + 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, + 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, + 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, + 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, + 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, + 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, + 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, + 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, + 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, + 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467 + ] + 
z_refinement: { copy {} } + } + # 1 - lips + refinement: { + indexes_mapping: [ + # Lower outer. + 61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, + # Upper outer (excluding corners). + 185, 40, 39, 37, 0, 267, 269, 270, 409, + # Lower inner. + 78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, + # Upper inner (excluding corners). + 191, 80, 81, 82, 13, 312, 311, 310, 415, + # Lower semi-outer. + 76, 77, 90, 180, 85, 16, 315, 404, 320, 307, 306, + # Upper semi-outer (excluding corners). + 184, 74, 73, 72, 11, 302, 303, 304, 408, + # Lower semi-inner. + 62, 96, 89, 179, 86, 15, 316, 403, 319, 325, 292, + # Upper semi-inner (excluding corners). + 183, 42, 41, 38, 12, 268, 271, 272, 407 + ] + z_refinement: { none {} } + } + # 2 - left eye + refinement: { + indexes_mapping: [ + # Lower contour. + 33, 7, 163, 144, 145, 153, 154, 155, 133, + # upper contour (excluding corners). + 246, 161, 160, 159, 158, 157, 173, + # Halo x2 lower contour. + 130, 25, 110, 24, 23, 22, 26, 112, 243, + # Halo x2 upper contour (excluding corners). + 247, 30, 29, 27, 28, 56, 190, + # Halo x3 lower contour. + 226, 31, 228, 229, 230, 231, 232, 233, 244, + # Halo x3 upper contour (excluding corners). + 113, 225, 224, 223, 222, 221, 189, + # Halo x4 upper contour (no lower because of mesh structure) or + # eyebrow inner contour. + 35, 124, 46, 53, 52, 65, + # Halo x5 lower contour. + 143, 111, 117, 118, 119, 120, 121, 128, 245, + # Halo x5 upper contour (excluding corners) or eyebrow outer contour. + 156, 70, 63, 105, 66, 107, 55, 193 + ] + z_refinement: { none {} } + } + # 3 - right eye + refinement: { + indexes_mapping: [ + # Lower contour. + 263, 249, 390, 373, 374, 380, 381, 382, 362, + # Upper contour (excluding corners). + 466, 388, 387, 386, 385, 384, 398, + # Halo x2 lower contour. + 359, 255, 339, 254, 253, 252, 256, 341, 463, + # Halo x2 upper contour (excluding corners). + 467, 260, 259, 257, 258, 286, 414, + # Halo x3 lower contour. + 446, 261, 448, 449, 450, 451, 452, 453, 464, + # Halo x3 upper contour (excluding corners). + 342, 445, 444, 443, 442, 441, 413, + # Halo x4 upper contour (no lower because of mesh structure) or + # eyebrow inner contour. + 265, 353, 276, 283, 282, 295, + # Halo x5 lower contour. + 372, 340, 346, 347, 348, 349, 350, 357, 465, + # Halo x5 upper contour (excluding corners) or eyebrow outer contour. + 383, 300, 293, 334, 296, 336, 285, 417 + ] + z_refinement: { none {} } + } + # 4 - left iris + refinement: { + indexes_mapping: [ + # Center. + 468, + # Iris right edge. + 469, + # Iris top edge. + 470, + # Iris left edge. + 471, + # Iris bottom edge. + 472 + ] + z_refinement: { + assign_average: { + indexes_for_average: [ + # Lower contour. + 33, 7, 163, 144, 145, 153, 154, 155, 133, + # Upper contour (excluding corners). + 246, 161, 160, 159, 158, 157, 173 + ] + } + } + } + # 5 - right iris + refinement: { + indexes_mapping: [ + # Center. + 473, + # Iris right edge. + 474, + # Iris top edge. + 475, + # Iris left edge. + 476, + # Iris bottom edge. + 477 + ] + z_refinement: { + assign_average: { + indexes_for_average: [ + # Lower contour. + 263, 249, 390, 373, 374, 380, 381, 382, 362, + # Upper contour (excluding corners). 
+ 466, 388, 387, 386, 385, 384, 398 + ] + } + } + } + } + } +} diff --git a/mediapipe/modules/hand_landmark/BUILD b/mediapipe/modules/hand_landmark/BUILD index dfbb34125..3bb726f95 100644 --- a/mediapipe/modules/hand_landmark/BUILD +++ b/mediapipe/modules/hand_landmark/BUILD @@ -92,6 +92,7 @@ mediapipe_simple_subgraph( register_as = "HandLandmarkTrackingCpuImage", deps = [ ":hand_landmark_tracking_cpu", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:from_image_calculator", ], @@ -103,6 +104,7 @@ mediapipe_simple_subgraph( register_as = "HandLandmarkTrackingGpuImage", deps = [ ":hand_landmark_tracking_gpu", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:from_image_calculator", ], diff --git a/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite b/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite index 28d2730a3..f69ae8e6f 100755 Binary files a/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite and b/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite differ diff --git a/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt index fbbdaa098..efc6e1936 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.pbtxt @@ -14,6 +14,10 @@ input_stream: "IMAGE:image" # Max number of hands to detect/track. (int) input_side_packet: "NUM_HANDS:num_hands" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Collection of detected/predicted hands, each represented as a list of # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -38,23 +42,19 @@ output_stream: "HAND_ROIS_FROM_LANDMARKS:hand_rects" # (std::vector) output_stream: "HAND_ROIS_FROM_PALM_DETECTIONS:hand_rects_from_palm_detections" -# Defines whether landmarks on the previous image should be used to help -# localize landmarks on the current image. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_hand_rects_from_landmarks" output_stream: "gated_prev_hand_rects_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Determines if an input vector of NormalizedRect has a size greater than or diff --git a/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu_image.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu_image.pbtxt index a7ead52a2..913123ec3 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu_image.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu_image.pbtxt @@ -14,8 +14,12 @@ input_stream: "IMAGE:image" # Max number of hands to detect/track. 
(int) input_side_packet: "NUM_HANDS:num_hands" -# The original input image. (Image) -output_stream: "IMAGE:image" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" # Collection of detected/predicted hands, each represented as a list of # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -40,10 +44,27 @@ output_stream: "HAND_ROIS_FROM_LANDMARKS:hand_rects" # (std::vector) output_stream: "HAND_ROIS_FROM_PALM_DETECTIONS:hand_rects_from_palm_detections" +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:multi_hand_landmarks" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + # Converts Image to ImageFrame for HandLandmarkTrackingCpu to consume. node { calculator: "FromImageCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:throttled_image" output_stream: "IMAGE_CPU:raw_image_frame" output_stream: "SOURCE_ON_GPU:is_gpu_image" } @@ -64,6 +85,7 @@ node { calculator: "HandLandmarkTrackingCpu" input_stream: "IMAGE:image_frame" input_side_packet: "NUM_HANDS:num_hands" + input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" output_stream: "LANDMARKS:multi_hand_landmarks" output_stream: "HANDEDNESS:multi_handedness" output_stream: "PALM_DETECTIONS:palm_detections" diff --git a/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu.pbtxt index fa8c5c172..56e7fdb00 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu.pbtxt @@ -14,6 +14,10 @@ input_stream: "IMAGE:image" # Max number of hands to detect/track. (int) input_side_packet: "NUM_HANDS:num_hands" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Collection of detected/predicted hands, each represented as a list of # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -38,23 +42,19 @@ output_stream: "HAND_ROIS_FROM_LANDMARKS:hand_rects" # (std::vector) output_stream: "HAND_ROIS_FROM_PALM_DETECTIONS:hand_rects_from_palm_detections" -# Defines whether landmarks on the previous image should be used to help -# localize landmarks on the current image. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. 
node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_hand_rects_from_landmarks" output_stream: "gated_prev_hand_rects_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Determines if an input vector of NormalizedRect has a size greater than or diff --git a/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu_image.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu_image.pbtxt index 006a2aaf6..269fe3770 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu_image.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_tracking_gpu_image.pbtxt @@ -14,6 +14,10 @@ input_stream: "IMAGE:image" # Max number of hands to detect/track. (int) input_side_packet: "NUM_HANDS:num_hands" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Collection of detected/predicted hands, each represented as a list of # landmarks. (std::vector) # NOTE: there will not be an output packet in the LANDMARKS stream for this @@ -28,8 +32,8 @@ output_stream: "LANDMARKS:multi_hand_landmarks" # horizontally. output_stream: "HANDEDNESS:multi_handedness" -# The original input image. (Image) -output_stream: "IMAGE:image" +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" # Extra outputs (for debugging, for instance). # Detected palms. (std::vector) output_stream: "PALM_DETECTIONS:palm_detections" @@ -40,10 +44,27 @@ output_stream: "HAND_ROIS_FROM_LANDMARKS:hand_rects" # (std::vector) output_stream: "HAND_ROIS_FROM_PALM_DETECTIONS:hand_rects_from_palm_detections" +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:multi_hand_landmarks" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + # Converts Image to GpuBuffer for HandLandmarkTrackingGpu to consume. node { calculator: "FromImageCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:throttled_image" output_stream: "IMAGE_GPU:raw_gpu_buffer" output_stream: "SOURCE_ON_GPU:is_gpu_image" } @@ -64,6 +85,7 @@ node { calculator: "HandLandmarkTrackingGpu" input_stream: "IMAGE:gpu_buffer" input_side_packet: "NUM_HANDS:num_hands" + input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" output_stream: "LANDMARKS:multi_hand_landmarks" output_stream: "HANDEDNESS:multi_handedness" output_stream: "PALM_DETECTIONS:palm_detections" diff --git a/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt b/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt index d198bb1e7..878e2304d 100644 --- a/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt +++ b/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt @@ -35,6 +35,7 @@ # input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" # input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" +# input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" # output_stream: "POSE_LANDMARKS:pose_landmarks" # output_stream: "FACE_LANDMARKS:face_landmarks" # output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" @@ -69,6 +70,10 @@ input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" # jitter. If unspecified, functions as set to true. 
(bool) input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Pose landmarks. (NormalizedLandmarkList) # 33 pose landmarks. output_stream: "POSE_LANDMARKS:pose_landmarks" @@ -81,6 +86,9 @@ output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" # 468 face landmarks. (NormalizedLandmarkList) output_stream: "FACE_LANDMARKS:face_landmarks" +# Segmentation mask. (ImageFrame in ImageFormat::VEC32F1) +output_stream: "SEGMENTATION_MASK:segmentation_mask" + # Debug outputs output_stream: "POSE_ROI:pose_landmarks_roi" output_stream: "POSE_DETECTION:pose_detection" @@ -93,8 +101,10 @@ node { input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" + input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" output_stream: "LANDMARKS:pose_landmarks" output_stream: "WORLD_LANDMARKS:pose_world_landmarks" + output_stream: "SEGMENTATION_MASK:segmentation_mask" output_stream: "ROI_FROM_LANDMARKS:pose_landmarks_roi" output_stream: "DETECTION:pose_detection" } diff --git a/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt b/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt index 49cfa4677..dc2a7b931 100644 --- a/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt +++ b/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt @@ -35,6 +35,7 @@ # input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" # input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" +# input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" # output_stream: "POSE_LANDMARKS:pose_landmarks" # output_stream: "FACE_LANDMARKS:face_landmarks" # output_stream: "LEFT_HAND_LANDMARKS:left_hand_landmarks" @@ -69,6 +70,10 @@ input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" # jitter. If unspecified, functions as set to true. (bool) input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Pose landmarks. (NormalizedLandmarkList) # 33 pose landmarks. output_stream: "POSE_LANDMARKS:pose_landmarks" @@ -81,6 +86,9 @@ output_stream: "RIGHT_HAND_LANDMARKS:right_hand_landmarks" # 468 face landmarks. (NormalizedLandmarkList) output_stream: "FACE_LANDMARKS:face_landmarks" +# Segmentation mask. 
(GpuBuffer in RGBA, with the same mask values in R and A) +output_stream: "SEGMENTATION_MASK:segmentation_mask" + # Debug outputs output_stream: "POSE_ROI:pose_landmarks_roi" output_stream: "POSE_DETECTION:pose_detection" @@ -93,8 +101,10 @@ node { input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" + input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" output_stream: "LANDMARKS:pose_landmarks" output_stream: "WORLD_LANDMARKS:pose_world_landmarks" + output_stream: "SEGMENTATION_MASK:segmentation_mask" output_stream: "ROI_FROM_LANDMARKS:pose_landmarks_roi" output_stream: "DETECTION:pose_detection" } diff --git a/mediapipe/modules/objectron/objectron_cpu.pbtxt b/mediapipe/modules/objectron/objectron_cpu.pbtxt index 834c56464..884da057b 100644 --- a/mediapipe/modules/objectron/objectron_cpu.pbtxt +++ b/mediapipe/modules/objectron/objectron_cpu.pbtxt @@ -9,6 +9,9 @@ input_side_packet: "MODEL_PATH:box_landmark_model_path" input_side_packet: "LABELS_CSV:allowed_labels" # Max number of objects to detect/track. (int) input_side_packet: "MAX_NUM_OBJECTS:max_num_objects" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" # Bounding box landmarks topology definition. # The numbers are indices in the box_landmarks list. # @@ -48,24 +51,19 @@ node { output_side_packet: "MODEL:box_landmark_model" } -# Defines whether landmarks from the previous video frame should be used to help -# predict landmarks on the current video frame. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} - +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_box_rects_from_landmarks" output_stream: "gated_prev_box_rects_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Determines if an input vector of NormalizedRect has a size greater than or diff --git a/mediapipe/modules/objectron/objectron_gpu.pbtxt b/mediapipe/modules/objectron/objectron_gpu.pbtxt index 16187deae..7ef2b6710 100644 --- a/mediapipe/modules/objectron/objectron_gpu.pbtxt +++ b/mediapipe/modules/objectron/objectron_gpu.pbtxt @@ -8,28 +8,26 @@ input_stream: "IMAGE_GPU:image" input_side_packet: "LABELS_CSV:allowed_labels" # Max number of objects to detect/track. (int) input_side_packet: "MAX_NUM_OBJECTS:max_num_objects" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" # Collection of detected 3D objects, represented as a FrameAnnotation. output_stream: "FRAME_ANNOTATION:detected_objects" -# Defines whether landmarks from the previous video frame should be used to help -# predict landmarks on the current video frame. 
-node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} - +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_box_rects_from_landmarks" output_stream: "gated_prev_box_rects_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Determines if an input vector of NormalizedRect has a size greater than or diff --git a/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt index bbec6cc26..5faf08a76 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_cpu.pbtxt @@ -21,6 +21,7 @@ # input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" # input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" +# input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" # input_stream: "IMAGE:image" # output_stream: "LANDMARKS:pose_landmarks" # output_stream: "SEGMENTATION_MASK:segmentation_mask" @@ -48,6 +49,10 @@ input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" # functions as set to 1. (int) input_side_packet: "MODEL_COMPLEXITY:model_complexity" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Pose landmarks. (NormalizedLandmarkList) # We have 33 landmarks (see pose_landmark_topology.svg), and there are other # auxiliary key points. @@ -110,23 +115,19 @@ output_stream: "ROI_FROM_LANDMARKS:pose_rect_from_landmarks" # Regions of interest calculated based on pose detections. (NormalizedRect) output_stream: "ROI_FROM_DETECTION:pose_rect_from_detection" -# Defines whether landmarks on the previous image should be used to help -# localize landmarks on the current image. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_pose_rect_from_landmarks" output_stream: "gated_prev_pose_rect_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Checks if there's previous pose rect calculated from landmarks. 
diff --git a/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt index 312d98f47..3ff9ac9fe 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_gpu.pbtxt @@ -21,6 +21,7 @@ # input_side_packet: "SMOOTH_LANDMARKS:smooth_landmarks" # input_side_packet: "ENABLE_SEGMENTATION:enable_segmentation" # input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" +# input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" # input_stream: "IMAGE:image" # output_stream: "LANDMARKS:pose_landmarks" # output_stream: "SEGMENTATION_MASK:segmentation_mask" @@ -48,6 +49,10 @@ input_side_packet: "SMOOTH_SEGMENTATION:smooth_segmentation" # functions as set to 1. (int) input_side_packet: "MODEL_COMPLEXITY:model_complexity" +# Whether landmarks on the previous image should be used to help localize +# landmarks on the current image. (bool) +input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" + # Pose landmarks. (NormalizedLandmarkList) # We have 33 landmarks (see pose_landmark_topology.svg), and there are other # auxiliary key points. @@ -110,23 +115,19 @@ output_stream: "ROI_FROM_LANDMARKS:pose_rect_from_landmarks" # Regions of interest calculated based on pose detections. (NormalizedRect) output_stream: "ROI_FROM_DETECTION:pose_rect_from_detection" -# Defines whether landmarks on the previous image should be used to help -# localize landmarks on the current image. -node { - name: "ConstantSidePacketCalculator" - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:use_prev_landmarks" - options: { - [mediapipe.ConstantSidePacketCalculatorOptions.ext]: { - packet { bool_value: true } - } - } -} +# When the optional input side packet "use_prev_landmarks" is either absent or +# set to true, uses the landmarks on the previous image to help localize +# landmarks on the current image. node { calculator: "GateCalculator" input_side_packet: "ALLOW:use_prev_landmarks" input_stream: "prev_pose_rect_from_landmarks" output_stream: "gated_prev_pose_rect_from_landmarks" + options: { + [mediapipe.GateCalculatorOptions.ext] { + allow: true + } + } } # Checks if there's previous pose rect calculated from landmarks. 
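The attention-model plumbing above (the WITH_ATTENTION side packet, the FaceLandmarksModelLoader and TensorsToFaceLandmarksWithAttention subgraphs) is exposed to Python as the refine_landmarks option, together with the FACEMESH_IRISES connections and a default iris drawing style added later in this patch. A minimal usage sketch under those additions (the image path is a placeholder):

```
# Sketch only: refine_landmarks=True selects face_landmark_with_attention.tflite
# and returns 478 landmarks (468 mesh + 10 iris), drawable via FACEMESH_IRISES.
import cv2
import mediapipe as mp

mp_face_mesh = mp.solutions.face_mesh
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles

image = cv2.imread('face.jpg')  # placeholder path
with mp_face_mesh.FaceMesh(static_image_mode=True, refine_landmarks=True) as face_mesh:
  results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

if results.multi_face_landmarks:
  for face_landmarks in results.multi_face_landmarks:
    # len(face_landmarks.landmark) == FACEMESH_NUM_LANDMARKS_WITH_IRISES == 478.
    mp_drawing.draw_landmarks(
        image=image,
        landmark_list=face_landmarks,
        connections=mp_face_mesh.FACEMESH_IRISES,
        landmark_drawing_spec=None,  # draw iris connections only, no landmark dots
        connection_drawing_spec=mp_drawing_styles.get_default_face_mesh_iris_connections_style())
```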
diff --git a/mediapipe/modules/selfie_segmentation/BUILD b/mediapipe/modules/selfie_segmentation/BUILD index d51af427c..7fc271a67 100644 --- a/mediapipe/modules/selfie_segmentation/BUILD +++ b/mediapipe/modules/selfie_segmentation/BUILD @@ -71,6 +71,7 @@ mediapipe_simple_subgraph( register_as = "SelfieSegmentationCpuImage", deps = [ ":selfie_segmentation_cpu", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:from_image_calculator", "//mediapipe/calculators/util:to_image_calculator", @@ -83,6 +84,7 @@ mediapipe_simple_subgraph( register_as = "SelfieSegmentationGpuImage", deps = [ ":selfie_segmentation_gpu", + "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/util:from_image_calculator", "//mediapipe/calculators/util:to_image_calculator", diff --git a/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu_image.pbtxt b/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu_image.pbtxt index 008d1814f..a35ff0e69 100644 --- a/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu_image.pbtxt +++ b/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu_image.pbtxt @@ -5,8 +5,8 @@ type: "SelfieSegmentationCpuImage" # Input image. (Image) input_stream: "IMAGE:image" -# The original input image. (Image) -output_stream: "IMAGE:image" +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" # An integer 0 or 1. Use 0 to select a general-purpose model (operating on a # 256x256 tensor), and 1 to select a model (operating on a 256x144 tensor) more @@ -16,10 +16,27 @@ input_side_packet: "MODEL_SELECTION:model_selection" # Segmentation mask. (Image) output_stream: "SEGMENTATION_MASK:segmentation_mask" +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:segmentation_mask" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + # Converts Image to ImageFrame for SelfieSegmentationCpu to consume. node { calculator: "FromImageCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:throttled_image" output_stream: "IMAGE_CPU:raw_image_frame" output_stream: "SOURCE_ON_GPU:is_gpu_image" } diff --git a/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu_image.pbtxt b/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu_image.pbtxt index 7ec5406f0..d5c0935a5 100644 --- a/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu_image.pbtxt +++ b/mediapipe/modules/selfie_segmentation/selfie_segmentation_gpu_image.pbtxt @@ -5,8 +5,8 @@ type: "SelfieSegmentationGpuImage" # Input image. (Image) input_stream: "IMAGE:image" -# The original input image. (Image) -output_stream: "IMAGE:image" +# The throttled input image. (Image) +output_stream: "IMAGE:throttled_image" # An integer 0 or 1. Use 0 to select a general-purpose model (operating on a # 256x256 tensor), and 1 to select a model (operating on a 256x144 tensor) more @@ -16,10 +16,27 @@ input_side_packet: "MODEL_SELECTION:model_selection" # Segmentation mask. 
(Image) output_stream: "SEGMENTATION_MASK:segmentation_mask" +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + input_stream: "FINISHED:segmentation_mask" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_image" + options: { + [mediapipe.FlowLimiterCalculatorOptions.ext] { + max_in_flight: 1 + max_in_queue: 1 + } + } +} + # Converts Image to ImageFrame for SelfieSegmentationGpu to consume. node { calculator: "FromImageCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:throttled_image" output_stream: "IMAGE_GPU:raw_gpu_buffer" output_stream: "SOURCE_ON_GPU:is_gpu_image" } diff --git a/mediapipe/objc/MPPPlayerInputSource.m b/mediapipe/objc/MPPPlayerInputSource.m index d9c78054a..30331b345 100644 --- a/mediapipe/objc/MPPPlayerInputSource.m +++ b/mediapipe/objc/MPPPlayerInputSource.m @@ -31,6 +31,8 @@ CVDisplayLinkRef _videoDisplayLink; #endif // TARGET_OS_OSX id _videoEndObserver; + id _audioInterruptedObserver; + BOOL _playing; } - (instancetype)initWithAVAsset:(AVAsset*)video { @@ -76,12 +78,19 @@ usingBlock:^(NSNotification* note) { [weakSelf playerItemDidPlayToEnd:note]; }]; + _audioInterruptedObserver = [center addObserverForName:AVAudioSessionInterruptionNotification + object:nil + queue:nil + usingBlock:^(NSNotification* note) { + [weakSelf audioSessionInterruption:note]; + }]; } return self; } - (void)start { [_videoPlayer play]; + _playing = YES; #if !TARGET_OS_OSX _videoDisplayLink.paused = NO; #else @@ -96,6 +105,7 @@ CVDisplayLinkStop(_videoDisplayLink); #endif [_videoPlayer pause]; + _playing = NO; } - (BOOL)isRunning { @@ -156,6 +166,17 @@ static CVReturn renderCallback(CVDisplayLinkRef displayLink, const CVTimeStamp* }); } +- (void)audioSessionInterruption:(NSNotification*)notification { + if ([notification.userInfo[AVAudioSessionInterruptionTypeKey] intValue] == + AVAudioSessionInterruptionTypeEnded) { + if ([notification.userInfo[AVAudioSessionInterruptionOptionKey] intValue] == + AVAudioSessionInterruptionOptionShouldResume && _playing) { + // AVVideoPlayer does not automatically resume on this notification. + [_videoPlayer play]; + } + } +} + - (void)seekToTime:(CMTime)time tolerance:(CMTime)tolerance { [_videoPlayer seekToTime:time toleranceBefore:tolerance toleranceAfter:tolerance]; } diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc index e02b8aba7..15d84cd07 100644 --- a/mediapipe/objc/util.cc +++ b/mediapipe/objc/util.cc @@ -241,6 +241,44 @@ vImage_Error vImageConvertCVPixelBuffers(CVPixelBufferRef src, return error; } +#if TARGET_IPHONE_SIMULATOR +static void FreeRefConReleaseCallback(void* refCon, const void* baseAddress) { + free(refCon); +} +#endif + +CVReturn CreateCVPixelBufferWithoutPool(int width, int height, OSType cv_format, + CVPixelBufferRef* out_buffer) { +#if TARGET_IPHONE_SIMULATOR + // On the simulator, syncing the texture with the pixelbuffer does not work, + // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not + // available in OpenGL ES 2, we should create the buffer so the pixels are + // contiguous. + // + // TODO: verify if we can use kIOSurfaceBytesPerRow to force + // CoreVideo to give us contiguous data. 
+ size_t bytes_per_row = width * 4; + void* data = malloc(bytes_per_row * height); + return CVPixelBufferCreateWithBytes( + kCFAllocatorDefault, width, height, cv_format, data, bytes_per_row, + FreeRefConReleaseCallback, data, + GetCVPixelBufferAttributesForGlCompatibility(), out_buffer); +#else + return CVPixelBufferCreate(kCFAllocatorDefault, width, height, cv_format, + GetCVPixelBufferAttributesForGlCompatibility(), + out_buffer); +#endif +} + +absl::StatusOr> CreateCVPixelBufferWithoutPool( + int width, int height, OSType cv_format) { + CVPixelBufferRef buffer; + CVReturn err = + CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); + RET_CHECK(err == kCVReturnSuccess) << "Error creating pixel buffer: " << err; + return MakeCFHolderAdopting(buffer); +} + void ReleaseMediaPipePacket(void* refcon, const void* base_address) { auto packet = (mediapipe::Packet*)refcon; delete packet; @@ -288,13 +326,9 @@ absl::Status CreateCVPixelBufferForImageFramePacket( if (can_overwrite) { v_dest = v_image; } else { - CVPixelBufferRef pixel_buffer_temp; - status = CVPixelBufferCreate( - kCFAllocatorDefault, frame.Width(), frame.Height(), pixel_format, - GetCVPixelBufferAttributesForGlCompatibility(), &pixel_buffer_temp); - RET_CHECK(status == kCVReturnSuccess) - << "CVPixelBufferCreate failed: " << status; - pixel_buffer.adopt(pixel_buffer_temp); + ASSIGN_OR_RETURN(pixel_buffer, + CreateCVPixelBufferWithoutPool( + frame.Width(), frame.Height(), pixel_format)); status = CVPixelBufferLockBaseAddress(*pixel_buffer, kCVPixelBufferLock_ReadOnly); RET_CHECK(status == kCVReturnSuccess) @@ -345,6 +379,83 @@ absl::Status CreateCVPixelBufferForImageFramePacket( return absl::OkStatus(); } +absl::StatusOr> CreateCVPixelBufferCopyingImageFrame( + const mediapipe::ImageFrame& image_frame) { + CFHolder pixel_buffer; + OSType pixel_format = 0; + std::function copy_fun = + [](const vImage_Buffer& src, vImage_Buffer& dst) -> absl::Status { + const char* src_row = reinterpret_cast(src.data); + char* dst_row = reinterpret_cast(dst.data); + if (src.rowBytes == dst.rowBytes) { + memcpy(dst_row, src_row, src.height * src.rowBytes); + } else { + for (int i = src.height; i > 0; --i) { + memcpy(dst_row, src_row, src.rowBytes); + src_row += src.rowBytes; + dst_row += dst.rowBytes; + } + } + return {}; + }; + + // TODO: unify some code with CreateCVPixelBufferForImageFramePacket? + mediapipe::ImageFormat::Format image_format = image_frame.Format(); + switch (image_format) { + case mediapipe::ImageFormat::SRGBA: + pixel_format = kCVPixelFormatType_32BGRA; + copy_fun = [](const vImage_Buffer& src, + vImage_Buffer& dst) -> absl::Status { + // Swap R and B channels. 
+ const uint8_t permute_map[4] = {2, 1, 0, 3}; + vImage_Error vError = vImagePermuteChannels_ARGB8888( + &src, &dst, permute_map, kvImageNoFlags); + RET_CHECK(vError == kvImageNoError) + << "vImagePermuteChannels failed: " << vError; + return {}; + }; + break; + + case mediapipe::ImageFormat::GRAY8: + pixel_format = kCVPixelFormatType_OneComponent8; + break; + + case mediapipe::ImageFormat::VEC32F1: + pixel_format = kCVPixelFormatType_OneComponent32Float; + break; + + case mediapipe::ImageFormat::VEC32F2: + pixel_format = kCVPixelFormatType_TwoComponent32Float; + break; + + default: + return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) + << "unsupported ImageFrame format: " << image_format; + } + + CVReturn cv_err; + ASSIGN_OR_RETURN(pixel_buffer, CreateCVPixelBufferWithoutPool( + image_frame.Width(), image_frame.Height(), + pixel_format)); + cv_err = + CVPixelBufferLockBaseAddress(*pixel_buffer, kCVPixelBufferLock_ReadOnly); + RET_CHECK(cv_err == kCVReturnSuccess) + << "CVPixelBufferLockBaseAddress failed: " << cv_err; + + vImage_Buffer v_image = vImageForImageFrame(image_frame); + vImage_Buffer v_dest = vImageForCVPixelBuffer(*pixel_buffer); + auto status = copy_fun(v_image, v_dest); + + cv_err = CVPixelBufferUnlockBaseAddress(*pixel_buffer, + kCVPixelBufferLock_ReadOnly); + RET_CHECK(cv_err == kCVReturnSuccess) + << "CVPixelBufferUnlockBaseAddress failed: " << cv_err; + + MP_RETURN_IF_ERROR(status); + + return pixel_buffer; +} + absl::Status CreateCGImageFromCVPixelBuffer(CVPixelBufferRef image_buffer, CFHolder* image) { CVReturn status = diff --git a/mediapipe/objc/util.h b/mediapipe/objc/util.h index e499162f3..221ba6b1e 100644 --- a/mediapipe/objc/util.h +++ b/mediapipe/objc/util.h @@ -63,6 +63,13 @@ vImage_Error vImageConvertCVPixelBuffers(CVPixelBufferRef src, /// alive while the CVPixelBuffer is in use. void ReleaseMediaPipePacket(void* refcon, const void* base_address); +// Create a CVPixelBuffer without using a pool. See pixel_buffer_pool_util.h +// for creation functions that use pools. +CVReturn CreateCVPixelBufferWithoutPool(int width, int height, OSType cv_format, + CVPixelBufferRef* out_buffer); +absl::StatusOr> CreateCVPixelBufferWithoutPool( + int width, int height, OSType cv_format); + /// Returns a CVPixelBuffer that references the data inside the packet. The /// packet must contain an ImageFrame. The CVPixelBuffer manages a copy of /// the packet, so that the packet's data is kept alive as long as the @@ -78,6 +85,8 @@ absl::Status CreateCVPixelBufferForImageFramePacket( absl::Status CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet, bool can_overwrite, CFHolder* out_buffer); +absl::StatusOr> CreateCVPixelBufferCopyingImageFrame( + const mediapipe::ImageFrame& image_frame); /// Creates a CVPixelBuffer with a copy of the contents of the CGImage. absl::Status CreateCVPixelBufferFromCGImage( diff --git a/mediapipe/python/solutions/drawing_styles.py b/mediapipe/python/solutions/drawing_styles.py index f84148e2d..b43bca8d3 100644 --- a/mediapipe/python/solutions/drawing_styles.py +++ b/mediapipe/python/solutions/drawing_styles.py @@ -183,6 +183,23 @@ def get_default_face_mesh_tesselation_style() -> DrawingSpec: return DrawingSpec(color=_GRAY, thickness=_THICKNESS_TESSELATION) +def get_default_face_mesh_iris_connections_style( +) -> Mapping[Tuple[int, int], DrawingSpec]: + """Returns the default face mesh iris connections drawing style. + + Returns: + A mapping from each iris connection to its default drawing spec. 
+ """ + face_mesh_iris_connections_style = {} + left_spec = DrawingSpec(color=_GREEN, thickness=_THICKNESS_CONTOURS) + for connection in face_mesh_connections.FACEMESH_LEFT_IRIS: + face_mesh_iris_connections_style[connection] = left_spec + right_spec = DrawingSpec(color=_RED, thickness=_THICKNESS_CONTOURS) + for connection in face_mesh_connections.FACEMESH_RIGHT_IRIS: + face_mesh_iris_connections_style[connection] = right_spec + return face_mesh_iris_connections_style + + def get_default_pose_landmarks_style() -> Mapping[int, DrawingSpec]: """Returns the default pose landmarks drawing style. diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index 5576e606f..ea5d881cb 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -138,9 +138,12 @@ def draw_landmarks( landmark_drawing_spec: Either a DrawingSpec object or a mapping from hand landmarks to the DrawingSpecs that specifies the landmarks' drawing settings such as color, line thickness, and circle radius. + If this argument is explicitly set to None, no landmarks will be drawn. connection_drawing_spec: Either a DrawingSpec object or a mapping from hand connections to the DrawingSpecs that specifies the connections' drawing settings such as color and line thickness. + If this argument is explicitly set to None, no landmark connections will + be drawn. Raises: ValueError: If one of the followings: diff --git a/mediapipe/python/solutions/face_detection.py b/mediapipe/python/solutions/face_detection.py index 6f62ae924..7d4da8fe9 100644 --- a/mediapipe/python/solutions/face_detection.py +++ b/mediapipe/python/solutions/face_detection.py @@ -28,8 +28,8 @@ from mediapipe.calculators.util import non_max_suppression_calculator_pb2 # pylint: enable=unused-import from mediapipe.python.solution_base import SolutionBase -SHORT_RANGE_GRAPH_FILE_PATH = 'mediapipe/modules/face_detection/face_detection_short_range_cpu.binarypb' -FULL_RANGE_GRAPH_FILE_PATH = 'mediapipe/modules/face_detection/face_detection_full_range_cpu.binarypb' +_SHORT_RANGE_GRAPH_FILE_PATH = 'mediapipe/modules/face_detection/face_detection_short_range_cpu.binarypb' +_FULL_RANGE_GRAPH_FILE_PATH = 'mediapipe/modules/face_detection/face_detection_full_range_cpu.binarypb' def get_key_point( @@ -83,7 +83,7 @@ class FaceDetection(SolutionBase): https://solutions.mediapipe.dev/face_detection#model_selection. """ - binary_graph_path = FULL_RANGE_GRAPH_FILE_PATH if model_selection == 1 else SHORT_RANGE_GRAPH_FILE_PATH + binary_graph_path = _FULL_RANGE_GRAPH_FILE_PATH if model_selection == 1 else _SHORT_RANGE_GRAPH_FILE_PATH subgraph_name = 'facedetectionfullrangecommon' if model_selection == 1 else 'facedetectionshortrangecommon' super().__init__( diff --git a/mediapipe/python/solutions/face_mesh.py b/mediapipe/python/solutions/face_mesh.py index f09ec6225..1fe9d91cc 100644 --- a/mediapipe/python/solutions/face_mesh.py +++ b/mediapipe/python/solutions/face_mesh.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""MediaPipe FaceMesh.""" +"""MediaPipe Face Mesh.""" from typing import NamedTuple import numpy as np -from mediapipe.calculators.core import constant_side_packet_calculator_pb2 # pylint: disable=unused-import +from mediapipe.calculators.core import constant_side_packet_calculator_pb2 from mediapipe.calculators.core import gate_calculator_pb2 from mediapipe.calculators.core import split_vector_calculator_pb2 from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2 @@ -30,6 +30,7 @@ from mediapipe.calculators.tensor import tensors_to_landmarks_calculator_pb2 from mediapipe.calculators.tflite import ssd_anchors_calculator_pb2 from mediapipe.calculators.util import association_calculator_pb2 from mediapipe.calculators.util import detections_to_rects_calculator_pb2 +from mediapipe.calculators.util import landmarks_refinement_calculator_pb2 from mediapipe.calculators.util import logic_calculator_pb2 from mediapipe.calculators.util import non_max_suppression_calculator_pb2 from mediapipe.calculators.util import rect_transformation_calculator_pb2 @@ -39,22 +40,26 @@ from mediapipe.python.solution_base import SolutionBase # pylint: disable=unused-import from mediapipe.python.solutions.face_mesh_connections import FACEMESH_CONTOURS from mediapipe.python.solutions.face_mesh_connections import FACEMESH_FACE_OVAL +from mediapipe.python.solutions.face_mesh_connections import FACEMESH_IRISES from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LEFT_EYE from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LEFT_EYEBROW +from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LEFT_IRIS from mediapipe.python.solutions.face_mesh_connections import FACEMESH_LIPS from mediapipe.python.solutions.face_mesh_connections import FACEMESH_RIGHT_EYE from mediapipe.python.solutions.face_mesh_connections import FACEMESH_RIGHT_EYEBROW +from mediapipe.python.solutions.face_mesh_connections import FACEMESH_RIGHT_IRIS from mediapipe.python.solutions.face_mesh_connections import FACEMESH_TESSELATION # pylint: enable=unused-import - -BINARYPB_FILE_PATH = 'mediapipe/modules/face_landmark/face_landmark_front_cpu.binarypb' +FACEMESH_NUM_LANDMARKS = 468 +FACEMESH_NUM_LANDMARKS_WITH_IRISES = 478 +_BINARYPB_FILE_PATH = 'mediapipe/modules/face_landmark/face_landmark_front_cpu.binarypb' class FaceMesh(SolutionBase): - """MediaPipe FaceMesh. + """MediaPipe Face Mesh. - MediaPipe FaceMesh processes an RGB image and returns the face landmarks on + MediaPipe Face Mesh processes an RGB image and returns the face landmarks on each detected face. Please refer to https://solutions.mediapipe.dev/face_mesh#python-solution-api @@ -64,9 +69,10 @@ class FaceMesh(SolutionBase): def __init__(self, static_image_mode=False, max_num_faces=1, + refine_landmarks=False, min_detection_confidence=0.5, min_tracking_confidence=0.5): - """Initializes a MediaPipe FaceMesh object. + """Initializes a MediaPipe Face Mesh object. Args: static_image_mode: Whether to treat the input images as a batch of static @@ -74,6 +80,10 @@ class FaceMesh(SolutionBase): https://solutions.mediapipe.dev/face_mesh#static_image_mode. max_num_faces: Maximum number of faces to detect. See details in https://solutions.mediapipe.dev/face_mesh#max_num_faces. + refine_landmarks: Whether to further refine the landmark coordinates + around the eyes and lips, and output additional landmarks around the + irises. Default to False. See details in + https://solutions.mediapipe.dev/face_mesh#refine_landmarks. 
min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for face detection to be considered successful. See details in https://solutions.mediapipe.dev/face_mesh#min_detection_confidence. @@ -82,16 +92,13 @@ class FaceMesh(SolutionBase): https://solutions.mediapipe.dev/face_mesh#min_tracking_confidence. """ super().__init__( - binary_graph_path=BINARYPB_FILE_PATH, + binary_graph_path=_BINARYPB_FILE_PATH, side_inputs={ 'num_faces': max_num_faces, + 'with_attention': refine_landmarks, + 'use_prev_landmarks': not static_image_mode, }, calculator_params={ - 'ConstantSidePacketCalculator.packet': [ - constant_side_packet_calculator_pb2 - .ConstantSidePacketCalculatorOptions.ConstantSidePacket( - bool_value=not static_image_mode) - ], 'facedetectionshortrangecpu__facedetectionshortrangecommon__TensorsToDetectionsCalculator.min_score_thresh': min_detection_confidence, 'facelandmarkcpu__ThresholdingCalculator.threshold': diff --git a/mediapipe/python/solutions/face_mesh_connections.py b/mediapipe/python/solutions/face_mesh_connections.py index 0980c2259..1ebd541df 100644 --- a/mediapipe/python/solutions/face_mesh_connections.py +++ b/mediapipe/python/solutions/face_mesh_connections.py @@ -29,6 +29,9 @@ FACEMESH_LEFT_EYE = frozenset([(263, 249), (249, 390), (390, 373), (373, 374), (263, 466), (466, 388), (388, 387), (387, 386), (386, 385), (385, 384), (384, 398), (398, 362)]) +FACEMESH_LEFT_IRIS = frozenset([(474, 475), (475, 476), (476, 477), + (477, 474)]) + FACEMESH_LEFT_EYEBROW = frozenset([(276, 283), (283, 282), (282, 295), (295, 285), (300, 293), (293, 334), (334, 296), (296, 336)]) @@ -41,6 +44,9 @@ FACEMESH_RIGHT_EYE = frozenset([(33, 7), (7, 163), (163, 144), (144, 145), FACEMESH_RIGHT_EYEBROW = frozenset([(46, 53), (53, 52), (52, 65), (65, 55), (70, 63), (63, 105), (105, 66), (66, 107)]) +FACEMESH_RIGHT_IRIS = frozenset([(469, 470), (470, 471), (471, 472), + (472, 469)]) + FACEMESH_FACE_OVAL = frozenset([(10, 338), (338, 297), (297, 332), (332, 284), (284, 251), (251, 389), (389, 356), (356, 454), (454, 323), (323, 361), (361, 288), (288, 397), @@ -56,6 +62,8 @@ FACEMESH_CONTOURS = frozenset().union(*[ FACEMESH_RIGHT_EYEBROW, FACEMESH_FACE_OVAL ]) +FACEMESH_IRISES = frozenset().union(*[FACEMESH_LEFT_IRIS, FACEMESH_RIGHT_IRIS]) + FACEMESH_TESSELATION = frozenset([ (127, 34), (34, 139), (139, 127), (11, 0), (0, 37), (37, 11), (232, 231), (231, 120), (120, 232), (72, 37), (37, 39), (39, 72), diff --git a/mediapipe/python/solutions/face_mesh_test.py b/mediapipe/python/solutions/face_mesh_test.py index be875dd9c..263bb2c0b 100644 --- a/mediapipe/python/solutions/face_mesh_test.py +++ b/mediapipe/python/solutions/face_mesh_test.py @@ -67,10 +67,24 @@ EYE_INDICES_TO_LANDMARKS = { 398: [432, 175] } +IRIS_INDICES_TO_LANDMARKS = { + 468: [362, 175], + 469: [371, 175], + 470: [362, 167], + 471: [354, 175], + 472: [363, 182], + 473: [449, 174], + 474: [458, 174], + 475: [449, 167], + 476: [440, 174], + 477: [449, 181] +} + class FaceMeshTest(parameterized.TestCase): - def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int): + def _annotate(self, frame: np.ndarray, results: NamedTuple, idx: int, + draw_iris: bool): for face_landmarks in results.multi_face_landmarks: mp_drawing.draw_landmarks( frame, @@ -86,6 +100,14 @@ class FaceMeshTest(parameterized.TestCase): landmark_drawing_spec=None, connection_drawing_spec=drawing_styles .get_default_face_mesh_contours_style()) + if draw_iris: + mp_drawing.draw_landmarks( + frame, + face_landmarks, + mp_faces.FACEMESH_IRISES, + 
landmark_drawing_spec=None, + connection_drawing_spec=drawing_styles + .get_default_face_mesh_iris_connections_style()) path = os.path.join(tempfile.gettempdir(), self.id().split('.')[-1] + '_frame_{}.png'.format(idx)) cv2.imwrite(path, frame) @@ -103,22 +125,29 @@ class FaceMeshTest(parameterized.TestCase): results = faces.process(image) self.assertIsNone(results.multi_face_landmarks) - @parameterized.named_parameters(('static_image_mode', True, 1), - ('video_mode', False, 5)) - def test_face(self, static_image_mode: bool, num_frames: int): + @parameterized.named_parameters( + ('static_image_mode_no_attention', True, False, 5), + ('static_image_mode_with_attention', True, True, 5), + ('streaming_mode_no_attention', False, False, 10), + ('streaming_mode_with_attention', False, True, 10)) + def test_face(self, static_image_mode: bool, refine_landmarks: bool, + num_frames: int): image_path = os.path.join(os.path.dirname(__file__), 'testdata/portrait.jpg') image = cv2.imread(image_path) rows, cols, _ = image.shape with mp_faces.FaceMesh( static_image_mode=static_image_mode, + refine_landmarks=refine_landmarks, min_detection_confidence=0.5) as faces: for idx in range(num_frames): results = faces.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._annotate(image.copy(), results, idx) + self._annotate(image.copy(), results, idx, refine_landmarks) multi_face_landmarks = [] for landmarks in results.multi_face_landmarks: - self.assertLen(landmarks.landmark, 468) + self.assertLen( + landmarks.landmark, mp_faces.FACEMESH_NUM_LANDMARKS_WITH_IRISES + if refine_landmarks else mp_faces.FACEMESH_NUM_LANDMARKS) x = [landmark.x * cols for landmark in landmarks.landmark] y = [landmark.y * rows for landmark in landmarks.landmark] face_landmarks = np.column_stack((x, y)) @@ -129,6 +158,12 @@ class FaceMeshTest(parameterized.TestCase): prediction_error = np.abs( np.asarray(multi_face_landmarks[0][eye_idx]) - np.asarray(gt_lds)) npt.assert_array_less(prediction_error, DIFF_THRESHOLD) + if refine_landmarks: + for iris_idx, gt_lds in IRIS_INDICES_TO_LANDMARKS.items(): + prediction_error = np.abs( + np.asarray(multi_face_landmarks[0][iris_idx]) - + np.asarray(gt_lds)) + npt.assert_array_less(prediction_error, DIFF_THRESHOLD) if __name__ == '__main__': diff --git a/mediapipe/python/solutions/hands.py b/mediapipe/python/solutions/hands.py index 01578a0cb..08f2d7340 100644 --- a/mediapipe/python/solutions/hands.py +++ b/mediapipe/python/solutions/hands.py @@ -19,8 +19,8 @@ from typing import NamedTuple import numpy as np -from mediapipe.calculators.core import constant_side_packet_calculator_pb2 # pylint: disable=unused-import +from mediapipe.calculators.core import constant_side_packet_calculator_pb2 from mediapipe.calculators.core import gate_calculator_pb2 from mediapipe.calculators.core import split_vector_calculator_pb2 from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2 @@ -67,7 +67,7 @@ class HandLandmark(enum.IntEnum): PINKY_TIP = 20 -BINARYPB_FILE_PATH = 'mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.binarypb' +_BINARYPB_FILE_PATH = 'mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.binarypb' class Hands(SolutionBase): @@ -107,16 +107,12 @@ class Hands(SolutionBase): https://solutions.mediapipe.dev/hands#min_tracking_confidence. 
""" super().__init__( - binary_graph_path=BINARYPB_FILE_PATH, + binary_graph_path=_BINARYPB_FILE_PATH, side_inputs={ 'num_hands': max_num_hands, + 'use_prev_landmarks': not static_image_mode, }, calculator_params={ - 'ConstantSidePacketCalculator.packet': [ - constant_side_packet_calculator_pb2 - .ConstantSidePacketCalculatorOptions.ConstantSidePacket( - bool_value=not static_image_mode) - ], 'palmdetectioncpu__TensorsToDetectionsCalculator.min_score_thresh': min_detection_confidence, 'handlandmarkcpu__ThresholdingCalculator.threshold': diff --git a/mediapipe/python/solutions/hands_test.py b/mediapipe/python/solutions/hands_test.py index cdab95e84..113992aea 100644 --- a/mediapipe/python/solutions/hands_test.py +++ b/mediapipe/python/solutions/hands_test.py @@ -32,20 +32,20 @@ from mediapipe.python.solutions import hands as mp_hands TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -DIFF_THRESHOLD = 15 # pixels -EXPECTED_HAND_COORDINATES_PREDICTION = [[[144, 345], [211, 323], [257, 286], +DIFF_THRESHOLD = 20 # pixels +EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286], [289, 237], [322, 203], [219, 216], [238, 138], [249, 90], [253, 51], [177, 204], [184, 115], [187, 60], [185, 19], [138, 208], [131, 127], [124, 77], [117, 36], [106, 222], [92, 159], [79, 124], [68, 93]], - [[577, 37], [504, 56], [459, 94], - [429, 146], [397, 182], [496, 167], + [[580, 36], [504, 50], [459, 94], + [429, 146], [397, 182], [507, 167], [479, 245], [469, 292], [464, 330], - [540, 177], [534, 265], [533, 319], + [545, 180], [534, 265], [533, 319], [536, 360], [581, 172], [587, 252], - [593, 304], [599, 346], [615, 157], + [593, 304], [599, 346], [615, 168], [628, 223], [638, 258], [648, 288]]] diff --git a/mediapipe/python/solutions/holistic.py b/mediapipe/python/solutions/holistic.py index 85313d8cf..70ce491c2 100644 --- a/mediapipe/python/solutions/holistic.py +++ b/mediapipe/python/solutions/holistic.py @@ -17,10 +17,10 @@ from typing import NamedTuple import numpy as np -from mediapipe.calculators.core import constant_side_packet_calculator_pb2 # The following imports are needed because python pb2 silently discards # unknown protobuf fields. # pylint: disable=unused-import +from mediapipe.calculators.core import constant_side_packet_calculator_pb2 from mediapipe.calculators.core import gate_calculator_pb2 from mediapipe.calculators.core import split_vector_calculator_pb2 from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2 @@ -49,7 +49,7 @@ from mediapipe.python.solutions.pose import PoseLandmark from mediapipe.python.solutions.pose_connections import POSE_CONNECTIONS # pylint: enable=unused-import -BINARYPB_FILE_PATH = 'mediapipe/modules/holistic_landmark/holistic_landmark_cpu.binarypb' +_BINARYPB_FILE_PATH = 'mediapipe/modules/holistic_landmark/holistic_landmark_cpu.binarypb' def _download_oss_pose_landmark_model(model_complexity): @@ -78,6 +78,8 @@ class Holistic(SolutionBase): static_image_mode=False, model_complexity=1, smooth_landmarks=True, + enable_segmentation=False, + smooth_segmentation=True, min_detection_confidence=0.5, min_tracking_confidence=0.5): """Initializes a MediaPipe Holistic object. @@ -91,6 +93,11 @@ class Holistic(SolutionBase): smooth_landmarks: Whether to filter landmarks across different input images to reduce jitter. See details in https://solutions.mediapipe.dev/holistic#smooth_landmarks. + enable_segmentation: Whether to predict segmentation mask. See details in + https://solutions.mediapipe.dev/holistic#enable_segmentation. 
+ smooth_segmentation: Whether to filter segmentation across different input + images to reduce jitter. See details in + https://solutions.mediapipe.dev/holistic#smooth_segmentation. min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for person detection to be considered successful. See details in https://solutions.mediapipe.dev/holistic#min_detection_confidence. @@ -100,18 +107,16 @@ class Holistic(SolutionBase): """ _download_oss_pose_landmark_model(model_complexity) super().__init__( - binary_graph_path=BINARYPB_FILE_PATH, + binary_graph_path=_BINARYPB_FILE_PATH, side_inputs={ 'model_complexity': model_complexity, 'smooth_landmarks': smooth_landmarks and not static_image_mode, - 'smooth_segmentation': not static_image_mode, + 'enable_segmentation': enable_segmentation, + 'smooth_segmentation': + smooth_segmentation and not static_image_mode, + 'use_prev_landmarks': not static_image_mode, }, calculator_params={ - 'poselandmarkcpu__ConstantSidePacketCalculator.packet': [ - constant_side_packet_calculator_pb2 - .ConstantSidePacketCalculatorOptions.ConstantSidePacket( - bool_value=not static_image_mode) - ], 'poselandmarkcpu__posedetectioncpu__TensorsToDetectionsCalculator.min_score_thresh': min_detection_confidence, 'poselandmarkcpu__poselandmarkbyroicpu__tensorstoposelandmarksandsegmentation__ThresholdingCalculator.threshold': @@ -119,7 +124,7 @@ class Holistic(SolutionBase): }, outputs=[ 'pose_landmarks', 'pose_world_landmarks', 'left_hand_landmarks', - 'right_hand_landmarks', 'face_landmarks' + 'right_hand_landmarks', 'face_landmarks', 'segmentation_mask' ]) def process(self, image: np.ndarray) -> NamedTuple: @@ -133,8 +138,8 @@ class Holistic(SolutionBase): ValueError: If the input image is not three channel RGB. Returns: - A NamedTuple that has five fields describing the landmarks on the most - prominate person detected: + A NamedTuple with fields describing the landmarks on the most prominate + person detected: 1) "pose_landmarks" field that contains the pose landmarks. 2) "pose_world_landmarks" field that contains the pose landmarks in real-world 3D coordinates that are in meters with the origin at the @@ -142,6 +147,8 @@ class Holistic(SolutionBase): 3) "left_hand_landmarks" field that contains the left-hand landmarks. 4) "right_hand_landmarks" field that contains the right-hand landmarks. 5) "face_landmarks" field that contains the face landmarks. + 6) "segmentation_mask" field that contains the segmentation mask if + "enable_segmentation" is set to true. 
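A minimal sketch of consuming the new segmentation output described above (illustrative, not part of the patch; `rgb_frame` is an assumed RGB NumPy array):

```python
import mediapipe as mp

mp_holistic = mp.solutions.holistic

with mp_holistic.Holistic(enable_segmentation=True) as holistic:
  results = holistic.process(rgb_frame)
  # Float mask with the same height/width as the input; values close to 1.0
  # mark the person and values close to 0.0 mark the background. It may be
  # None when enable_segmentation is False or nothing is detected.
  mask = results.segmentation_mask
```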
""" results = super().process(input_data={'image': image}) diff --git a/mediapipe/python/solutions/objectron.py b/mediapipe/python/solutions/objectron.py index 195c2b8c7..28cc026aa 100644 --- a/mediapipe/python/solutions/objectron.py +++ b/mediapipe/python/solutions/objectron.py @@ -20,8 +20,8 @@ from typing import List, Tuple, NamedTuple, Optional import attr import numpy as np -from mediapipe.calculators.core import constant_side_packet_calculator_pb2 # pylint: disable=unused-import +from mediapipe.calculators.core import constant_side_packet_calculator_pb2 from mediapipe.calculators.core import gate_calculator_pb2 from mediapipe.calculators.core import split_vector_calculator_pb2 from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2 @@ -75,7 +75,7 @@ class BoxLandmark(enum.IntEnum): BACK_TOP_RIGHT = 7 FRONT_TOP_RIGHT = 8 -BINARYPB_FILE_PATH = 'mediapipe/modules/objectron/objectron_cpu.binarypb' +_BINARYPB_FILE_PATH = 'mediapipe/modules/objectron/objectron_cpu.binarypb' BOX_CONNECTIONS = frozenset([ (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.FRONT_BOTTOM_LEFT), (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.BACK_TOP_LEFT), @@ -216,18 +216,14 @@ class Objectron(SolutionBase): # Create and init model. model = get_model_by_name(model_name) super().__init__( - binary_graph_path=BINARYPB_FILE_PATH, + binary_graph_path=_BINARYPB_FILE_PATH, side_inputs={ 'box_landmark_model_path': model.model_path, 'allowed_labels': model.label_name, 'max_num_objects': max_num_objects, + 'use_prev_landmarks': not static_image_mode, }, calculator_params={ - 'ConstantSidePacketCalculator.packet': [ - constant_side_packet_calculator_pb2 - .ConstantSidePacketCalculatorOptions.ConstantSidePacket( - bool_value=not static_image_mode) - ], ('objectdetectionoidv4subgraph' '__TensorsToDetectionsCalculator.min_score_thresh'): min_detection_confidence, diff --git a/mediapipe/python/solutions/pose.py b/mediapipe/python/solutions/pose.py index 7788d4700..d4b499faa 100644 --- a/mediapipe/python/solutions/pose.py +++ b/mediapipe/python/solutions/pose.py @@ -19,10 +19,10 @@ from typing import NamedTuple import numpy as np -from mediapipe.calculators.core import constant_side_packet_calculator_pb2 # The following imports are needed because python pb2 silently discards # unknown protobuf fields. 
# pylint: disable=unused-import +from mediapipe.calculators.core import constant_side_packet_calculator_pb2 from mediapipe.calculators.core import gate_calculator_pb2 from mediapipe.calculators.core import split_vector_calculator_pb2 from mediapipe.calculators.image import warp_affine_calculator_pb2 @@ -87,7 +87,7 @@ class PoseLandmark(enum.IntEnum): RIGHT_FOOT_INDEX = 32 -BINARYPB_FILE_PATH = 'mediapipe/modules/pose_landmark/pose_landmark_cpu.binarypb' +_BINARYPB_FILE_PATH = 'mediapipe/modules/pose_landmark/pose_landmark_cpu.binarypb' def _download_oss_pose_landmark_model(model_complexity): @@ -144,20 +144,16 @@ class Pose(SolutionBase): """ _download_oss_pose_landmark_model(model_complexity) super().__init__( - binary_graph_path=BINARYPB_FILE_PATH, + binary_graph_path=_BINARYPB_FILE_PATH, side_inputs={ 'model_complexity': model_complexity, 'smooth_landmarks': smooth_landmarks and not static_image_mode, 'enable_segmentation': enable_segmentation, 'smooth_segmentation': smooth_segmentation and not static_image_mode, + 'use_prev_landmarks': not static_image_mode, }, calculator_params={ - 'ConstantSidePacketCalculator.packet': [ - constant_side_packet_calculator_pb2 - .ConstantSidePacketCalculatorOptions.ConstantSidePacket( - bool_value=not static_image_mode) - ], 'posedetectioncpu__TensorsToDetectionsCalculator.min_score_thresh': min_detection_confidence, 'poselandmarkbyroicpu__tensorstoposelandmarksandsegmentation__ThresholdingCalculator.threshold': @@ -176,12 +172,14 @@ class Pose(SolutionBase): ValueError: If the input image is not three channel RGB. Returns: - A NamedTuple that has two fields describing the landmarks on the most - prominate person detected: + A NamedTuple with fields describing the landmarks on the most prominate + person detected: 1) "pose_landmarks" field that contains the pose landmarks. 2) "pose_world_landmarks" field that contains the pose landmarks in real-world 3D coordinates that are in meters with the origin at the center between hips. + 3) "segmentation_mask" field that contains the segmentation mask if + "enable_segmentation" is set to true. """ results = super().process(input_data={'image': image}) diff --git a/mediapipe/python/solutions/selfie_segmentation.py b/mediapipe/python/solutions/selfie_segmentation.py index 8aa07569c..1334e9f13 100644 --- a/mediapipe/python/solutions/selfie_segmentation.py +++ b/mediapipe/python/solutions/selfie_segmentation.py @@ -29,7 +29,7 @@ from mediapipe.framework.tool import switch_container_pb2 from mediapipe.python.solution_base import SolutionBase -BINARYPB_FILE_PATH = 'mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.binarypb' +_BINARYPB_FILE_PATH = 'mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.binarypb' class SelfieSegmentation(SolutionBase): @@ -52,7 +52,7 @@ class SelfieSegmentation(SolutionBase): https://solutions.mediapipe.dev/selfie_segmentation#model_selection. 
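`Pose.process()` gains the same `segmentation_mask` field (item 3 in its Returns list above); a common, hedged way to use it is background replacement (assumes `rgb_frame` and an arbitrary `BG_COLOR`):

```python
import numpy as np
import mediapipe as mp

mp_pose = mp.solutions.pose
BG_COLOR = (192, 192, 192)  # gray, purely illustrative

with mp_pose.Pose(enable_segmentation=True) as pose:
  results = pose.process(rgb_frame)
  if results.segmentation_mask is not None:
    # Keep pixels where the mask is confident; fill the rest with BG_COLOR.
    condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
    background = np.full(rgb_frame.shape, BG_COLOR, dtype=np.uint8)
    output_frame = np.where(condition, rgb_frame, background)
```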
""" super().__init__( - binary_graph_path=BINARYPB_FILE_PATH, + binary_graph_path=_BINARYPB_FILE_PATH, side_inputs={ 'model_selection': model_selection, }, diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 8c9687723..3c0f1f35c 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -34,10 +34,12 @@ cc_library( hdrs = ["cpu_op_resolver.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", + "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:max_pool_argmax", "//mediapipe/util/tflite/operations:max_unpooling", + "//mediapipe/util/tflite/operations:transform_landmarks", + "//mediapipe/util/tflite/operations:transform_tensor_bilinear", "//mediapipe/util/tflite/operations:transpose_conv_bias", "@org_tensorflow//tensorflow/lite:builtin_op_data", "@org_tensorflow//tensorflow/lite:framework", diff --git a/mediapipe/util/tflite/cpu_op_resolver.cc b/mediapipe/util/tflite/cpu_op_resolver.cc index 935bed08b..588a237b2 100644 --- a/mediapipe/util/tflite/cpu_op_resolver.cc +++ b/mediapipe/util/tflite/cpu_op_resolver.cc @@ -15,8 +15,11 @@ #include "mediapipe/util/tflite/cpu_op_resolver.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" #include "mediapipe/util/tflite/operations/max_unpooling.h" +#include "mediapipe/util/tflite/operations/transform_landmarks.h" +#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h" #include "mediapipe/util/tflite/operations/transpose_conv_bias.h" #include "tensorflow/lite/builtin_op_data.h" #include "tensorflow/lite/mutable_op_resolver.h" @@ -31,6 +34,16 @@ void MediaPipe_RegisterTfLiteOpResolver(tflite::MutableOpResolver *resolver) { tflite_operations::RegisterMaxUnpooling2D()); resolver->AddCustom("Convolution2DTransposeBias", tflite_operations::RegisterConvolution2DTransposeBias()); + + resolver->AddCustom("TransformTensorBilinear", + tflite_operations::RegisterTransformTensorBilinearV2(), + /*version=*/2); + resolver->AddCustom("TransformLandmarks", + tflite_operations::RegisterTransformLandmarksV2(), + /*version=*/2); + resolver->AddCustom("Landmarks2TransformMatrix", + tflite_operations::RegisterLandmarksToTransformMatrixV2(), + /*version=*/2); } } // namespace mediapipe diff --git a/mediapipe/util/tflite/operations/BUILD b/mediapipe/util/tflite/operations/BUILD index aa4691b9a..e902906c9 100644 --- a/mediapipe/util/tflite/operations/BUILD +++ b/mediapipe/util/tflite/operations/BUILD @@ -22,6 +22,24 @@ package(default_visibility = [ "//learning/brain/models/app_benchmarks/camera_models:__subpackages__", ]) +CUSTOM_OPS_PACKAGE = "@org_tensorflow//tensorflow/lite/delegates/gpu/common/mediapipe" + +cc_library( + name = "landmarks_to_transform_matrix", + srcs = ["landmarks_to_transform_matrix.cc"], + hdrs = ["landmarks_to_transform_matrix.h"], + deps = [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels:padding", + "@org_tensorflow//tensorflow/lite/kernels/internal:common", + "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + CUSTOM_OPS_PACKAGE + ":landmarks_to_transform_matrix", + ], +) 
+ cc_library( name = "max_pool_argmax", srcs = ["max_pool_argmax.cc"], @@ -48,6 +66,38 @@ cc_library( ], ) +cc_library( + name = "transform_landmarks", + srcs = ["transform_landmarks.cc"], + hdrs = ["transform_landmarks.h"], + deps = [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels:padding", + "@org_tensorflow//tensorflow/lite/kernels/internal:common", + "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + CUSTOM_OPS_PACKAGE + ":transform_landmarks", + ], +) + +cc_library( + name = "transform_tensor_bilinear", + srcs = ["transform_tensor_bilinear.cc"], + hdrs = ["transform_tensor_bilinear.h"], + deps = [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels:padding", + "@org_tensorflow//tensorflow/lite/kernels/internal:common", + "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + CUSTOM_OPS_PACKAGE + ":transform_tensor_bilinear", + ], +) + cc_library( name = "transpose_conv_bias", srcs = ["transpose_conv_bias.cc"], diff --git a/third_party/org_tensorflow_custom_ops.diff b/third_party/org_tensorflow_custom_ops.diff new file mode 100644 index 000000000..4d92ba95c --- /dev/null +++ b/third_party/org_tensorflow_custom_ops.diff @@ -0,0 +1,3056 @@ +diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD +index c49f2ce731d..d72773c0a5b 100644 +--- a/tensorflow/lite/delegates/gpu/common/BUILD ++++ b/tensorflow/lite/delegates/gpu/common/BUILD +@@ -173,7 +173,7 @@ cc_library( + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:tensor", +- ] + tf_platform_alias("custom_parsers", "//tensorflow/lite/delegates/gpu/common/"), ++ ] + ["//tensorflow/lite/delegates/gpu/common/mediapipe:custom_parsers"], + ) + + cc_test( +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/BUILD b/tensorflow/lite/delegates/gpu/common/mediapipe/BUILD +new file mode 100644 +index 00000000000..58967ddbb66 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/BUILD +@@ -0,0 +1,93 @@ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++cc_library( ++ name = "custom_parsers", ++ srcs = ["custom_parsers.cc"], ++ hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_parsers.h"], ++ deps = [ ++ ":landmarks_to_transform_matrix", ++ ":transform_landmarks", ++ ":transform_tensor_bilinear", ++ "//tensorflow/lite/delegates/gpu/common:operation_parser", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:unimplemented_operation_parser", ++ "@com_google_absl//absl/memory", ++ "@com_google_absl//absl/strings", ++ "@com_google_absl//absl/types:any", ++ ], ++) ++ ++cc_library( ++ name = "custom_transformations", ++ srcs = ["custom_transformations.cc"], ++ hdrs = ["//tensorflow/lite/delegates/gpu/common:custom_transformations.h"], ++ deps = [ ++ ":landmarks_to_transform_matrix", ++ ":transform_landmarks", ++ ":transform_tensor_bilinear", ++ 
"//tensorflow/lite/delegates/gpu/common:model_transformer", ++ "@com_google_absl//absl/memory", ++ ], ++) ++ ++cc_library( ++ name = "landmarks_to_transform_matrix", ++ srcs = ["landmarks_to_transform_matrix.cc"], ++ hdrs = ["landmarks_to_transform_matrix.h"], ++ deps = [ ++ "//tensorflow/lite/c:common", ++ "//tensorflow/lite/delegates/gpu/common:model", ++ "//tensorflow/lite/delegates/gpu/common:model_builder_helper", ++ "//tensorflow/lite/delegates/gpu/common:model_transformer", ++ "//tensorflow/lite/delegates/gpu/common:object_reader", ++ "//tensorflow/lite/delegates/gpu/common:operation_parser", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:tensor", ++ "//tensorflow/lite/delegates/gpu/common:types", ++ "@com_google_absl//absl/types:any", ++ "@flatbuffers", ++ ], ++) ++ ++cc_library( ++ name = "transform_landmarks", ++ srcs = ["transform_landmarks.cc"], ++ hdrs = ["transform_landmarks.h"], ++ deps = [ ++ "//tensorflow/lite/c:common", ++ "//tensorflow/lite/delegates/gpu/common:model", ++ "//tensorflow/lite/delegates/gpu/common:model_builder_helper", ++ "//tensorflow/lite/delegates/gpu/common:model_transformer", ++ "//tensorflow/lite/delegates/gpu/common:object_reader", ++ "//tensorflow/lite/delegates/gpu/common:operation_parser", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:tensor", ++ "@com_google_absl//absl/types:any", ++ "@flatbuffers", ++ ], ++) ++ ++cc_library( ++ name = "transform_tensor_bilinear", ++ srcs = ["transform_tensor_bilinear.cc"], ++ hdrs = ["transform_tensor_bilinear.h"], ++ deps = [ ++ "//tensorflow/lite/c:common", ++ "//tensorflow/lite/delegates/gpu/common:model", ++ "//tensorflow/lite/delegates/gpu/common:model_builder_helper", ++ "//tensorflow/lite/delegates/gpu/common:model_transformer", ++ "//tensorflow/lite/delegates/gpu/common:object_reader", ++ "//tensorflow/lite/delegates/gpu/common:operation_parser", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:tensor", ++ "@com_google_absl//absl/types:any", ++ "@flatbuffers", ++ ], ++) +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/custom_parsers.cc b/tensorflow/lite/delegates/gpu/common/mediapipe/custom_parsers.cc +new file mode 100644 +index 00000000000..52c11b90fc8 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/custom_parsers.cc +@@ -0,0 +1,34 @@ ++#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h" ++ ++#include ++#include ++ ++#include "absl/memory/memory.h" ++#include "absl/strings/string_view.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/unimplemented_operation_parser.h" ++ ++namespace tflite { ++namespace gpu { ++ ++std::unique_ptr NewCustomOperationParser( ++ absl::string_view op_name) { ++ if (op_name == "Landmarks2TransformMatrix" || ++ op_name == "Landmarks2TransformMatrixV2") { ++ return std::make_unique(); ++ } ++ if (op_name == "TransformLandmarks") { ++ return std::make_unique(); ++ } ++ if (op_name == "TransformTensor" /*for version 
1*/ || ++ op_name == "TransformTensorBilinear" /*for version 2*/) { ++ return std::make_unique(); ++ } ++ return absl::make_unique(op_name); ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/custom_transformations.cc b/tensorflow/lite/delegates/gpu/common/mediapipe/custom_transformations.cc +new file mode 100644 +index 00000000000..1509ea3bcf3 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/custom_transformations.cc +@@ -0,0 +1,24 @@ ++#include "tensorflow/lite/delegates/gpu/common/custom_transformations.h" ++ ++#include "absl/memory/memory.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h" ++#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" ++ ++namespace tflite { ++namespace gpu { ++bool ApplyCustomTransformations(ModelTransformer* transformer) { ++ return transformer->Apply( ++ "transform_landmarks_v2_to_v1", ++ absl::make_unique().get()) && ++ transformer->Apply( ++ "transform_tensor_bilinear_v2_to_v1", ++ absl::make_unique().get()) && ++ transformer->Apply( ++ "landmarks_to_transform_matrix_v2_with_mul", ++ absl::make_unique() ++ .get()); ++} ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.cc b/tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.cc +new file mode 100644 +index 00000000000..4e73cf649e6 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.cc +@@ -0,0 +1,182 @@ ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h" ++ ++#include ++#include ++#include ++ ++#include "absl/types/any.h" ++#include "flatbuffers/flexbuffers.h" ++#include "tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" ++#include "tensorflow/lite/delegates/gpu/common/object_reader.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/tensor.h" ++#include "tensorflow/lite/delegates/gpu/common/types.h" ++ ++namespace tflite { ++namespace gpu { ++ ++absl::Status LandmarksToTransformMatrixOperationParser::IsSupported( ++ const TfLiteContext* context, const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration) { ++ RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); ++ return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, ++ /*outputs=*/1); ++} ++ ++absl::Status LandmarksToTransformMatrixOperationParser::Parse( ++ const TfLiteNode* tflite_node, const TfLiteRegistration* registration, ++ GraphFloat32* graph, ObjectReader* reader) { ++ Node* node = graph->NewNode(); ++ RETURN_IF_ERROR(reader->AddInput(node, 0)); // landmarks ++ RETURN_IF_ERROR(reader->AddOutputs(node)); // transform matrix ++ ++ node->operation.type = kLandmarksToTransformMatrixType; ++ BHWC output_shape; ++ if (registration->version == 2) { ++ LandmarksToTransformMatrixV2Attributes attr; ++ RETURN_IF_ERROR(ParseLandmarksToTransformMatrixV2Attributes( ++ tflite_node->custom_initial_data, tflite_node->custom_initial_data_size, ++ &attr, &output_shape)); ++ 
node->operation.attributes = attr; ++ } else if (registration->version == 1) { ++ LandmarksToTransformMatrixV1Attributes attr; ++ RETURN_IF_ERROR(ParseLandmarksToTransformMatrixV1Attributes( ++ tflite_node->custom_initial_data, tflite_node->custom_initial_data_size, ++ &attr, &output_shape)); ++ node->operation.attributes = attr; ++ } else { ++ return absl::UnimplementedError( ++ "Landmarks To Transform Matrix operation can be of version 1 or 2 " ++ "only."); ++ } ++ ++ auto output_value = graph->FindOutputs(node->id)[0]; ++ output_value->tensor.shape = output_shape; ++ return absl::OkStatus(); ++} ++ ++absl::Status ParseLandmarksToTransformMatrixV1Attributes( ++ const void* data, uint32_t data_size, ++ LandmarksToTransformMatrixV1Attributes* attr, BHWC* output_shape) { ++ const flexbuffers::Map m = ++ flexbuffers::GetRoot(reinterpret_cast(data), data_size) ++ .AsMap(); ++ ++ const auto input_hw = m["input_hw"].AsTypedVector(); ++ attr->input_hw = HW(input_hw[0].AsInt32(), input_hw[1].AsInt32()); ++ ++ const auto output_hw = m["output_hw"].AsTypedVector(); ++ attr->output_hw = HW(output_hw[0].AsInt32(), output_hw[1].AsInt32()); ++ ++ attr->dimensions = m["dimensions"].AsInt32(); ++ attr->landmarks_range = m["landmarks_range"].AsInt32(); ++ attr->bbox_size_multiplier = m["bbox_size_multiplier"].AsFloat(); ++ attr->left_rotation_idx = m["left_rotation_idx"].AsInt32(); ++ attr->right_rotation_idx = m["right_rotation_idx"].AsInt32(); ++ ++ const auto subset = m["subset"].AsTypedVector(); ++ for (int i = 0; i < subset.size() / 2; i++) { ++ attr->subset.emplace_back(subset[i * 2].AsInt32(), ++ subset[i * 2 + 1].AsInt32()); ++ } ++ if (subset.size() % 2 != 0) { ++ attr->subset.emplace_back(subset[subset.size() - 1].AsInt32(), ++ subset[subset.size() - 1].AsInt32()); ++ } ++ *output_shape = BHWC(1, 1, 4, 4); ++ return absl::OkStatus(); ++} ++ ++absl::Status ParseLandmarksToTransformMatrixV2Attributes( ++ const void* data, uint32_t data_size, ++ LandmarksToTransformMatrixV2Attributes* attr, BHWC* output_shape) { ++ const flexbuffers::Map m = ++ flexbuffers::GetRoot(reinterpret_cast(data), data_size) ++ .AsMap(); ++ const auto subset_idxs = m["subset_idxs"].AsTypedVector(); ++ int amount = subset_idxs.size(); ++ for (int i = 0; i < amount / 2; i++) { ++ attr->subset_idxs.emplace_back(subset_idxs[i * 2].AsInt32(), ++ subset_idxs[i * 2 + 1].AsInt32()); ++ } ++ if (amount % 2 != 0) { ++ int previous = amount - 1; ++ attr->subset_idxs.emplace_back(subset_idxs[previous].AsInt32(), ++ subset_idxs[previous].AsInt32()); ++ } ++ attr->left_rotation_idx = m["left_rotation_idx"].AsInt32(); ++ attr->right_rotation_idx = m["right_rotation_idx"].AsInt32(); ++ attr->target_rotation_radians = m["target_rotation_radians"].AsFloat(); ++ attr->output_height = m["output_height"].AsInt32(); ++ attr->output_width = m["output_width"].AsInt32(); ++ attr->scale_x = m["scale_x"].AsFloat(); ++ attr->scale_y = m["scale_y"].AsFloat(); ++ ++ *output_shape = BHWC(1, 1, 4, 4); ++ return absl::OkStatus(); ++} ++ ++TransformResult LandmarksToTransformMatrixV2ToV2WithMul::ApplyToNode( ++ Node* node, GraphFloat32* graph) { ++ // Recognize Landmarks2TransformMatrix.v2 as a root operation of this ++ // transformation. 
++ if (node->operation.type != kLandmarksToTransformMatrixType) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ auto* landmarks2tm_attr = ++ absl::any_cast( ++ &node->operation.attributes); ++ if (!landmarks2tm_attr) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ auto node_inputs = graph->FindInputs(node->id); ++ if (node_inputs.size() != 1) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ // Recognize preeceding scalar Mul operation and save the value. ++ auto mul = graph->FindProducer(node_inputs[0]->id); ++ if (mul->operation.type != ToString(OperationType::MUL)) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ const auto& mul_attr = ++ absl::any_cast(mul->operation.attributes); ++ float scalar = 0.0; ++ if (!absl::holds_alternative(mul_attr.param)) { ++ return {TransformStatus::SKIPPED, ""}; ++ } else { ++ scalar = absl::get(mul_attr.param); ++ } ++ auto mul_inputs = graph->FindInputs(mul->id); ++ if (mul_inputs.size() != 1) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ // Recognize preceding reshape. ++ auto reshape = graph->FindProducer(mul_inputs[0]->id); ++ if (reshape->operation.type != ToString(OperationType::RESHAPE)) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ // Start modifying the graph. ++ { ++ absl::Status status = RemoveSimpleNodeKeepInput(graph, reshape); ++ if (!status.ok()) { ++ return {TransformStatus::INVALID, ++ "Unable to remove a node: " + std::string(status.message())}; ++ } ++ } ++ { ++ absl::Status status = RemoveSimpleNodeKeepInput(graph, mul); ++ if (!status.ok()) { ++ return {TransformStatus::INVALID, ++ "Unable to remove a node: " + std::string(status.message())}; ++ } ++ } ++ // Update LandmarksToTransformMatrix attributes with a stored multiplier. ++ landmarks2tm_attr->multiplier = scalar; ++ return {TransformStatus::APPLIED, ""}; ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h b/tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h +new file mode 100644 +index 00000000000..78c72aea123 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h +@@ -0,0 +1,96 @@ ++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_LANDMARKS_TO_TRANSFORM_MATRIX_H_ ++#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_LANDMARKS_TO_TRANSFORM_MATRIX_H_ ++ ++#include ++#include ++ ++#include "absl/types/any.h" ++#include "tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" ++#include "tensorflow/lite/delegates/gpu/common/object_reader.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/types.h" ++ ++namespace tflite { ++namespace gpu { ++ ++constexpr const char kLandmarksToTransformMatrixType[] = ++ "landmarks_to_transform_matrix"; ++ ++struct LandmarksToTransformMatrixV1Attributes { ++ int dimensions; ++ int landmarks_range; ++ int left_rotation_idx; ++ int right_rotation_idx; ++ float bbox_size_multiplier; ++ HW input_hw; ++ HW output_hw; ++ std::vector subset; ++}; ++ ++struct LandmarksToTransformMatrixV2Attributes { ++ std::vector subset_idxs; ++ int left_rotation_idx; ++ int right_rotation_idx; ++ float target_rotation_radians; ++ int output_height; ++ int output_width; ++ float scale_x; ++ float scale_y; ++ float multiplier = 1.0; 
++}; ++ ++class LandmarksToTransformMatrixOperationParser : public TFLiteOperationParser { ++ public: ++ absl::Status IsSupported(const TfLiteContext* context, ++ const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration) final; ++ absl::Status Parse(const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration, ++ GraphFloat32* graph, ObjectReader* reader) final; ++}; ++ ++absl::Status ParseLandmarksToTransformMatrixV1Attributes( ++ const void* data, uint32_t data_size, ++ LandmarksToTransformMatrixV1Attributes* attr, BHWC* output_shape); ++ ++absl::Status ParseLandmarksToTransformMatrixV2Attributes( ++ const void* data, uint32_t data_size, ++ LandmarksToTransformMatrixV2Attributes* attr, BHWC* output_shape); ++ ++// Converts subgraph of Reshape + Mul + Landmarks2TransformMatrix.v2 into ++// Landmarks2TransformMatrix.v2 with multiplier: ++// Source subgraph: ++// ++// Value_0 [1, 1, 1, 30] ++// | ++// Reshape ++// | ++// Value_1 [1, 10, 3] ++// | ++// Mul (* 0.25) ++// | ++// Value_2 [1, 10, 3] ++// | ++// Landmarks2TransformMatrix.v2 ++// | ++// Value_3 [1, 1, 4] ++// ++// Resulting subgraph: ++// ++// Value_0 [1, 1, 1, 30] ++// | ++// Landmarks2TransformMatrix.v2 ++// | ++// Value_3 [1, 1, 4] ++class LandmarksToTransformMatrixV2ToV2WithMul : public NodeTransformation { ++ public: ++ TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final; ++}; ++ ++} // namespace gpu ++} // namespace tflite ++ ++#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_LANDMARKS_TO_TRANSFORM_MATRIX_H_ +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.cc b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.cc +new file mode 100644 +index 00000000000..fba7e742998 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.cc +@@ -0,0 +1,169 @@ ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h" ++ ++#include ++#include ++#include ++ ++#include "absl/types/any.h" ++#include "flatbuffers/flexbuffers.h" ++#include "tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" ++#include "tensorflow/lite/delegates/gpu/common/object_reader.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/tensor.h" ++ ++namespace tflite { ++namespace gpu { ++ ++absl::Status TransformLandmarksOperationParser::IsSupported( ++ const TfLiteContext* context, const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration) { ++ RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); ++ RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, ++ /*runtime_inputs=*/2, /*outputs=*/1)); ++ return absl::OkStatus(); ++} ++ ++absl::Status TransformLandmarksOperationParser::Parse( ++ const TfLiteNode* tflite_node, const TfLiteRegistration* registration, ++ GraphFloat32* graph, ObjectReader* reader) { ++ Node* node = graph->NewNode(); ++ RETURN_IF_ERROR(reader->AddInput(node, 0)); // data ++ RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox ++ RETURN_IF_ERROR(reader->AddOutputs(node)); ++ node->operation.type = kTransformLandmarksType; ++ BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape; ++ if (registration->version == 2) { ++ TransformLandmarksAttributes attr; ++ RETURN_IF_ERROR(ParseTransformLandmarksV2Attributes( ++ 
tflite_node->custom_initial_data, tflite_node->custom_initial_data_size, ++ &attr, &output_shape)); ++ node->operation.attributes = attr; ++ } else if (registration->version == 1) { ++ TransformLandmarksAttributes attr; ++ RETURN_IF_ERROR(ParseTransformLandmarksV1Attributes( ++ tflite_node->custom_initial_data, tflite_node->custom_initial_data_size, ++ &attr, &output_shape)); ++ node->operation.attributes = attr; ++ } else { ++ return absl::UnimplementedError( ++ "Transform Landmarks operation can be of version 1 or 2 only."); ++ } ++ ++ auto output_value = graph->FindOutputs(node->id)[0]; ++ ++ output_value->tensor.shape = graph->FindInputs(node->id)[0]->tensor.shape; ++ return absl::OkStatus(); ++} ++ ++absl::Status ParseTransformLandmarksV1Attributes( ++ const void* data, uint32_t data_size, TransformLandmarksAttributes* attr, ++ BHWC* output_shape) { ++ attr->version = 1; ++ ++ const flexbuffers::Map m = ++ flexbuffers::GetRoot(reinterpret_cast(data), data_size) ++ .AsMap(); ++ const flexbuffers::TypedVector keys = m.Keys(); ++ ++ for (int k = 0; k < keys.size(); ++k) { ++ const std::string key = keys[k].ToString(); ++ const auto value = m[key]; ++ if (key == "dimensions") { ++ attr->dimensions = value.AsInt32(); ++ } ++ if (key == "scale") { ++ attr->scale = value.AsFloat(); ++ } ++ } ++ return absl::OkStatus(); ++} ++ ++absl::Status ParseTransformLandmarksV2Attributes( ++ const void* data, uint32_t data_size, TransformLandmarksAttributes* attr, ++ BHWC* output_shape) { ++ attr->version = 2; ++ attr->dimensions = output_shape->c; ++ attr->scale = 1.0; ++ ++ return absl::OkStatus(); ++} ++ ++TransformResult TransformLandmarksV2ToV1::ApplyToNode(Node* node, ++ GraphFloat32* graph) { ++ // Recognize suitable Transform Landmarks operation. ++ if (node->operation.type != kTransformLandmarksType) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ TransformLandmarksAttributes transform_landmarks_attr = ++ absl::any_cast(node->operation.attributes); ++ if (transform_landmarks_attr.version != 2) { ++ return {TransformStatus::SKIPPED, ++ "Transform Landmarks operation should be of version 2."}; ++ } ++ ++ // Recognize suitable preceding Reshape. ++ std::vector transform_landmarks_inputs = graph->FindInputs(node->id); ++ if (transform_landmarks_inputs.size() != 2) { ++ return {TransformStatus::SKIPPED, ++ "Transform Landmarks operation should have two inputs."}; ++ } ++ Value* landmarks_input_tensor = transform_landmarks_inputs[1]; ++ if (transform_landmarks_inputs[1]->tensor.shape == BHWC(1, 1, 4, 4)) { ++ landmarks_input_tensor = transform_landmarks_inputs[0]; ++ } ++ Node* preceding_reshape = graph->FindProducer(landmarks_input_tensor->id); ++ if (preceding_reshape->operation.type != ToString(OperationType::RESHAPE)) { ++ return {TransformStatus::SKIPPED, ++ "Expected Reshape node to be a producer of the transformation " ++ "matrix input."}; ++ } ++ ++ // Recognize suitable succeeding Reshape. 
++ std::vector transform_landmarks_outputs = ++ graph->FindOutputs(node->id); ++ if (transform_landmarks_outputs.size() != 1) { ++ return {TransformStatus::SKIPPED, ++ "Transform Landmarks operation should have one output."}; ++ } ++ Value* landmarks_output_tensor = transform_landmarks_outputs[0]; ++ std::vector landmarks__output_consumers = ++ graph->FindConsumers(landmarks_output_tensor->id); ++ if (landmarks__output_consumers.size() != 1) { ++ return {TransformStatus::SKIPPED, ++ "Transform Landmarks output should be consumed by one operation."}; ++ } ++ Node* succeeding_reshape = landmarks__output_consumers[0]; ++ if (succeeding_reshape->operation.type != ToString(OperationType::RESHAPE)) { ++ return {TransformStatus::SKIPPED, ++ "Expected Reshape node to be a consumer of the Transform " ++ "Landmarks operation's output value."}; ++ } ++ ++ // Delete preceding and succeding Reshape operations. ++ absl::Status removed_preceding = ++ RemoveSimpleNodeKeepInput(graph, preceding_reshape); ++ if (!removed_preceding.ok()) { ++ return {TransformStatus::INVALID, ++ "Unable to remove a preceding Reshape node: " + ++ std::string(removed_preceding.message())}; ++ } ++ absl::Status removed_succeeding = ++ RemoveSimpleNodeKeepOutput(graph, succeeding_reshape); ++ if (!removed_succeeding.ok()) { ++ return {TransformStatus::INVALID, ++ "Unable to remove a succeeding Reshape node: " + ++ std::string(removed_succeeding.message())}; ++ } ++ ++ // Switch Transform Landmarks operation back to version 1. ++ transform_landmarks_attr.version = 1; ++ node->operation.attributes = transform_landmarks_attr; ++ ++ return {TransformStatus::APPLIED, ""}; ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h +new file mode 100644 +index 00000000000..f804e14e55d +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h +@@ -0,0 +1,74 @@ ++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_TRANSFORM_LANDMARKS_H_ ++#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_TRANSFORM_LANDMARKS_H_ ++ ++#include ++ ++#include "absl/types/any.h" ++#include "tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" ++#include "tensorflow/lite/delegates/gpu/common/object_reader.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++ ++namespace tflite { ++namespace gpu { ++ ++constexpr const char kTransformLandmarksType[] = "transform_landmarks"; ++ ++struct TransformLandmarksAttributes { ++ int dimensions = 3; ++ float scale = 1.0; ++ int version = 0; ++}; ++ ++class TransformLandmarksOperationParser : public TFLiteOperationParser { ++ public: ++ absl::Status IsSupported(const TfLiteContext* context, ++ const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration) final; ++ absl::Status Parse(const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration, ++ GraphFloat32* graph, ObjectReader* reader) final; ++}; ++ ++absl::Status ParseTransformLandmarksV1Attributes( ++ const void* data, uint32_t data_size, TransformLandmarksAttributes* attr, ++ BHWC* output_shape); ++ ++absl::Status ParseTransformLandmarksV2Attributes( ++ const void* data, uint32_t data_size, TransformLandmarksAttributes* attr, ++ BHWC* output_shape); ++ ++// 
Removes reshapes from subgraph: ++// ++// Value_0 [1, 1, 1, 240] ++// | ++// Reshape ++// | ++// Value_1 [1, 1, 80, 3] Value_2 [1, 1, 4, 4] ++// \ / ++// TransformLandmarks.version_2 ++// | ++// Value_3 [1, 1, 80, 3] ++// | ++// Reshape ++// | ++// Value_4 [1, 1, 1, 240] ++// ++// Resulting subgraph is: ++// ++// Value_0 [1, 1, 1, 240] Value_2 [1, 1, 4, 4] ++// \ / ++// TransformLandmarks.version_1 ++// | ++// Value_4 [1, 1, 1, 240] ++class TransformLandmarksV2ToV1 : public NodeTransformation { ++ public: ++ TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final; ++}; ++ ++} // namespace gpu ++} // namespace tflite ++ ++#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_TRANSFORM_LANDMARKS_H_ +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.cc b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.cc +new file mode 100644 +index 00000000000..704ce7d4a47 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.cc +@@ -0,0 +1,142 @@ ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h" ++ ++#include ++#include ++#include ++#include ++ ++#include "absl/types/any.h" ++#include "flatbuffers/flexbuffers.h" ++#include "tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" ++#include "tensorflow/lite/delegates/gpu/common/object_reader.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/tensor.h" ++ ++namespace tflite { ++namespace gpu { ++ ++absl::Status TransformTensorBilinearOperationParser::IsSupported( ++ const TfLiteContext* context, const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration) { ++ RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2)); ++ RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, ++ /*runtime_inputs=*/2, /*outputs=*/1)); ++ return absl::OkStatus(); ++} ++ ++absl::Status TransformTensorBilinearOperationParser::Parse( ++ const TfLiteNode* tflite_node, const TfLiteRegistration* registration, ++ GraphFloat32* graph, ObjectReader* reader) { ++ Node* node = graph->NewNode(); ++ RETURN_IF_ERROR(reader->AddInput(node, 0)); // data ++ RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox ++ RETURN_IF_ERROR(reader->AddOutputs(node)); ++ ++ node->operation.type = kTransformTensorBilinearType; ++ BHWC output_shape; ++ if (registration->version == 2) { ++ TransformTensorBilinearAttributes attr; ++ RETURN_IF_ERROR(ParseTransformTensorBilinearV2Attributes( ++ tflite_node->custom_initial_data, tflite_node->custom_initial_data_size, ++ &attr, &output_shape)); ++ node->operation.attributes = attr; ++ } else if (registration->version == 1) { ++ TransformTensorBilinearAttributes attr; ++ RETURN_IF_ERROR(ParseTransformTensorBilinearV1Attributes( ++ tflite_node->custom_initial_data, tflite_node->custom_initial_data_size, ++ &attr, &output_shape)); ++ node->operation.attributes = attr; ++ } else { ++ return absl::UnimplementedError( ++ "Transform Tensor Bilinear operation can be of version 1 or 2 only."); ++ } ++ ++ auto output_value = graph->FindOutputs(node->id)[0]; ++ ++ output_value->tensor.shape = ++ BHWC(1, output_shape.h, output_shape.w, ++ graph->FindInputs(node->id)[0]->tensor.shape.c); ++ return absl::OkStatus(); ++} ++ ++absl::Status 
ParseTransformTensorBilinearV1Attributes( ++ const void* data, uint32_t data_size, ++ TransformTensorBilinearAttributes* attr, BHWC* output_shape) { ++ attr->version = 1; ++ ++ const flexbuffers::Map m = ++ flexbuffers::GetRoot(reinterpret_cast(data), data_size) ++ .AsMap(); ++ const flexbuffers::TypedVector keys = m.Keys(); ++ ++ for (int k = 0; k < keys.size(); ++k) { ++ const std::string key = keys[k].ToString(); ++ const auto value = m[key]; ++ if (key == "mode") { ++ if (value.AsString().str() != "bilinear") { ++ return absl::UnimplementedError( ++ "TransformTensor operation supports only bilinear interpolation."); ++ } ++ } ++ ++ if (key == "output_size") { ++ attr->output_size = HW(value.AsTypedVector()[0].AsInt32(), ++ value.AsTypedVector()[1].AsInt32()); ++ } ++ } ++ attr->align_corners = false; ++ *output_shape = BHWC(1, attr->output_size.h, attr->output_size.w, 1); ++ return absl::OkStatus(); ++} ++ ++absl::Status ParseTransformTensorBilinearV2Attributes( ++ const void* data, uint32_t data_size, ++ TransformTensorBilinearAttributes* attr, BHWC* output_shape) { ++ attr->version = 2; ++ ++ const flexbuffers::Map m = ++ flexbuffers::GetRoot(reinterpret_cast(data), data_size) ++ .AsMap(); ++ const flexbuffers::TypedVector keys = m.Keys(); ++ HW output_size; ++ for (int k = 0; k < keys.size(); ++k) { ++ const std::string key = keys[k].ToString(); ++ const auto value = m[key]; ++ if (key == "output_height") { ++ output_size.h = value.AsInt32(); ++ } ++ if (key == "output_width") { ++ output_size.w = value.AsInt32(); ++ } ++ } ++ attr->output_size = std::move(output_size); ++ attr->align_corners = true; ++ *output_shape = BHWC(1, attr->output_size.h, attr->output_size.w, 1); ++ return absl::OkStatus(); ++} ++ ++TransformResult TransformTensorBilinearV2ToV1::ApplyToNode( ++ Node* node, GraphFloat32* graph) { ++ if (node->operation.type != kTransformTensorBilinearType) { ++ return {TransformStatus::SKIPPED, ""}; ++ } ++ TransformTensorBilinearAttributes transform_tensor_attr = ++ absl::any_cast( ++ node->operation.attributes); ++ ++ if (transform_tensor_attr.version != 2) { ++ return {TransformStatus::SKIPPED, ++ "Transform Tensor Bilinear operation should be of version 2."}; ++ } ++ transform_tensor_attr.version = 1; ++ transform_tensor_attr.align_corners = true; ++ node->operation.attributes = transform_tensor_attr; ++ ++ return {TransformStatus::APPLIED, ""}; ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h +new file mode 100644 +index 00000000000..8a1f840c12f +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h +@@ -0,0 +1,54 @@ ++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_TRANSFORM_TENSOR_BILINEAR_H_ ++#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_TRANSFORM_TENSOR_BILINEAR_H_ ++ ++#include ++ ++#include "absl/types/any.h" ++#include "tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" ++#include "tensorflow/lite/delegates/gpu/common/object_reader.h" ++#include "tensorflow/lite/delegates/gpu/common/operation_parser.h" ++#include "tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++ ++namespace tflite { ++namespace gpu { ++ ++constexpr const char kTransformTensorBilinearType[] = ++ "transform_tensor_bilinear"; ++ ++struct 
TransformTensorBilinearAttributes { ++ HW output_size; ++ bool align_corners = false; ++ int version = 0; ++}; ++ ++class TransformTensorBilinearOperationParser : public TFLiteOperationParser { ++ public: ++ absl::Status IsSupported(const TfLiteContext* context, ++ const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration) final; ++ absl::Status Parse(const TfLiteNode* tflite_node, ++ const TfLiteRegistration* registration, ++ GraphFloat32* graph, ObjectReader* reader) final; ++}; ++ ++absl::Status ParseTransformTensorBilinearV1Attributes( ++ const void* data, uint32_t data_size, ++ TransformTensorBilinearAttributes* attr, BHWC* output_shape); ++ ++absl::Status ParseTransformTensorBilinearV2Attributes( ++ const void* data, uint32_t data_size, ++ TransformTensorBilinearAttributes* attr, BHWC* output_shape); ++ ++// Converts Transform Tensor Bilinear operation of version 2 to version 1 with ++// align corners parameter set to true. ++class TransformTensorBilinearV2ToV1 : public NodeTransformation { ++ public: ++ TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final; ++}; ++ ++} // namespace gpu ++} // namespace tflite ++ ++#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MEDIAPIPE_TRANSFORM_TENSOR_BILINEAR_H_ +diff --git a/tensorflow/lite/delegates/gpu/common/selectors/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/BUILD +index ec6c2281b9e..26cf9aab1a9 100644 +--- a/tensorflow/lite/delegates/gpu/common/selectors/BUILD ++++ b/tensorflow/lite/delegates/gpu/common/selectors/BUILD +@@ -45,9 +45,9 @@ cc_library( + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_hints", + "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common/selectors/mediapipe:default_selector", + "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", + "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", +- _selectors_package + ":default_selector", + ], + ) + +diff --git a/tensorflow/lite/delegates/gpu/common/selectors/mediapipe/BUILD b/tensorflow/lite/delegates/gpu/common/selectors/mediapipe/BUILD +new file mode 100644 +index 00000000000..d5a28d6f72e +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/selectors/mediapipe/BUILD +@@ -0,0 +1,21 @@ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++cc_library( ++ name = "default_selector", ++ srcs = ["default_selector.cc"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:model", ++ "//tensorflow/lite/delegates/gpu/common:model_hints", ++ "//tensorflow/lite/delegates/gpu/common:operations", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common/selectors:subgraph", ++ "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", ++ "//tensorflow/lite/delegates/gpu/common/tasks/mediapipe:landmarks_to_transform_matrix", ++ "//tensorflow/lite/delegates/gpu/common/tasks/mediapipe:transform_landmarks", ++ "//tensorflow/lite/delegates/gpu/common/tasks/mediapipe:transform_tensor_bilinear", ++ "@com_google_absl//absl/strings", ++ ], ++) +diff --git a/tensorflow/lite/delegates/gpu/common/selectors/mediapipe/default_selector.cc b/tensorflow/lite/delegates/gpu/common/selectors/mediapipe/default_selector.cc +new file mode 100644 +index 00000000000..9c93149f95b +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/selectors/mediapipe/default_selector.cc +@@ -0,0 +1,48 @@ ++#include ++ ++#include "absl/strings/str_cat.h" ++#include 
"tensorflow/lite/delegates/gpu/common/model.h" ++#include "tensorflow/lite/delegates/gpu/common/model_hints.h" ++#include "tensorflow/lite/delegates/gpu/common/operations.h" ++#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" ++#include "tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.h" ++#include "tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.h" ++#include "tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.h" ++ ++namespace tflite { ++namespace gpu { ++namespace { ++ ++absl::Status CustomGPUOperationFromNode( ++ const GpuInfo& gpu_info, const OperationDef& op_def, ModelHints hints, ++ const std::vector& inputs, const std::vector& outputs, ++ const Node& node, GPUOperationsSubgraph* gpu_subgraph) { ++ std::unique_ptr* gpu_op = ++ InitSingleOpSubgraph(inputs, outputs, gpu_subgraph); ++ if (node.operation.type == kLandmarksToTransformMatrixType) { ++ return CreateLandmarksToTransformMatrixFromNode(op_def, node, gpu_op); ++ } ++ if (node.operation.type == kTransformLandmarksType) { ++ return CreateTransformLandmarksFromNode(op_def, node, gpu_op); ++ } ++ if (node.operation.type == kTransformTensorBilinearType) { ++ return CreateTransformTensorBilinearFromNode(op_def, node, gpu_op); ++ } ++ ++ return absl::UnimplementedError( ++ absl::StrCat("No selector for ", node.operation.type)); ++} ++} // namespace ++ ++absl::Status SelectDefault(const GpuInfo& gpu_info, const OperationDef& op_def, ++ ModelHints hints, const std::vector& inputs, ++ const std::vector& outputs, const Node& node, ++ GPUOperationsSubgraph* gpu_subgraph) { ++ return CustomGPUOperationFromNode(gpu_info, op_def, hints, inputs, outputs, ++ node, gpu_subgraph); ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/BUILD b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/BUILD +new file mode 100644 +index 00000000000..9df0735f0eb +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/BUILD +@@ -0,0 +1,39 @@ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++cc_library( ++ name = "landmarks_to_transform_matrix", ++ srcs = ["landmarks_to_transform_matrix.cc"], ++ hdrs = ["landmarks_to_transform_matrix.h"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common/mediapipe:landmarks_to_transform_matrix", ++ "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", ++ ], ++) ++ ++cc_library( ++ name = "transform_landmarks", ++ srcs = ["transform_landmarks.cc"], ++ hdrs = ["transform_landmarks.h"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common/mediapipe:transform_landmarks", ++ "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", ++ "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ++ ], ++) ++ ++cc_library( ++ name = "transform_tensor_bilinear", ++ srcs = ["transform_tensor_bilinear.cc"], ++ hdrs = ["transform_tensor_bilinear.h"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common/mediapipe:transform_tensor_bilinear", ++ "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", ++ "//tensorflow/lite/delegates/gpu/common/task:work_group_picking", ++ ], ++) +diff --git 
a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.cc b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.cc +new file mode 100644 +index 00000000000..18f28b19361 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.cc +@@ -0,0 +1,368 @@ ++#include "tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.h" ++ ++#include ++#include ++ ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++ ++namespace tflite { ++namespace gpu { ++namespace { ++ ++std::string GetLandmarksToTransformMatrixV1KernelCode( ++ const OperationDef& op_def, ++ const LandmarksToTransformMatrixV1Attributes& attr) { ++ const std::string batch_id = op_def.IsBatchSupported() ? "B" : ""; ++ std::string c; ++ c += "#define MAT_MUL_3x3(R0, R1, R2, A0, A1, A2, B0, B1, B2) \\\n"; ++ c += " R0.x = A0.x * B0.x + A1.x * B0.y + A2.x * B0.z; \\\n"; ++ c += " R0.y = A0.y * B0.x + A1.y * B0.y + A2.y * B0.z; \\\n"; ++ c += " R0.z = A0.z * B0.x + A1.z * B0.y + A2.z * B0.z; \\\n"; ++ c += " R1.x = A0.x * B1.x + A1.x * B1.y + A2.x * B1.z; \\\n"; ++ c += " R1.y = A0.y * B1.x + A1.y * B1.y + A2.y * B1.z; \\\n"; ++ c += " R1.z = A0.z * B1.x + A1.z * B1.y + A2.z * B1.z; \\\n"; ++ c += " R2.x = A0.x * B2.x + A1.x * B2.y + A2.x * B2.z; \\\n"; ++ c += " R2.y = A0.y * B2.x + A1.y * B2.y + A2.y * B2.z; \\\n"; ++ c += " R2.z = A0.z * B2.x + A1.z * B2.y + A2.z * B2.z; \n"; ++ ++ c += "MAIN_FUNCTION($0) {\n"; ++ // temporary ++ c += " int dummy_var = GLOBAL_ID_0;\n"; ++ if (op_def.IsBatchSupported()) { ++ c += " int B = GLOBAL_ID_0;\n"; ++ c += " if (B >= args.dst_tensor.Batch()) return;\n"; ++ c += " args.dst_tensor.SetBatchRef(B);\n"; ++ c += " args.src_tensor.SetBatchRef(B);\n"; ++ } ++ // reads x and y coords only. 
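// Editorial note, not part of the upstream patch: the landmarks arrive packed
// along the channel dimension as [1, 1, 1, N * dimensions], so each float4
// slice of src_tensor holds four consecutive coordinates. The read_landmark
// lambda defined below recovers landmark `id` from that packing:
//   start = id * dimensions, ZC = start / 4, rem = start % 4,
// and when rem == 3 the y coordinate spills into the next slice.
// Worked example, assuming dimensions == 3 and id == 5:
//   start = 15, ZC = 3, rem = 3 -> x is slice 3's .w, y is slice 4's .x.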
++ auto read_landmark = [&](const std::string& result, const std::string& id) { ++ c += " {\n"; ++ c += " int start = " + id + " * " + std::to_string(attr.dimensions) + ++ ";\n"; ++ c += " int ZC = start / 4;\n"; ++ if (attr.dimensions == 2) { ++ c += " float4 t_res = args.src_tensor.Read(0, 0, ZC);\n"; ++ c += " " + result + ".xy = t_res.xy;\n"; ++ } else if (attr.dimensions == 3) { ++ c += " float4 t_res = args.src_tensor.Read(0, 0, ZC);\n"; ++ c += " int rem = start % 4;\n"; ++ c += " if (rem == 0) {\n"; ++ c += " " + result + ".xy = t_res.xy;\n"; ++ c += " } else if (rem == 1) {\n"; ++ c += " " + result + ".xy = t_res.yz;\n"; ++ c += " } else if (rem == 2) {\n"; ++ c += " " + result + ".xy = t_res.zw;\n"; ++ c += " } else {\n"; ++ c += " float4 t_res_next = args.src_tensor.Read(0, 0, ZC + " ++ "1);\n"; ++ c += " " + result + ".xy = INIT_FLOAT2v2(t_res.w, t_res_next.x);\n"; ++ c += " }\n"; ++ } ++ c += " }\n"; ++ }; ++ c += " float2 l_pt, r_pt;\n"; ++ read_landmark("l_pt", "args.rotations_idx_x"); ++ read_landmark("r_pt", "args.rotations_idx_y"); ++ c += " float alpha = -atan2(r_pt.y - l_pt.y, r_pt.x - l_pt.x);\n"; ++ c += " float cosa = cos(alpha);\n"; ++ c += " float sina = sin(alpha);\n"; ++ c += " float2 max_value = INIT_FLOAT2v2(-100000.0f, -100000.0f);\n"; ++ c += " float2 min_value = INIT_FLOAT2v2(100000.0f, 100000.0f);\n"; ++ c += " for (int i = 0; i < args.subset_size; i++) {\n"; ++ c += " float2 p0, p1;\n"; ++ c += " int2 subset_v = args.subset.Read(i);\n"; ++ read_landmark("p0", "subset_v.x"); ++ read_landmark("p1", "subset_v.y"); ++ c += " // rotation\n"; ++ c += ++ " p0 = INIT_FLOAT2v2(p0.x*cosa - p0.y*sina, p0.x*sina + p0.y*cosa);\n"; ++ c += ++ " p1 = INIT_FLOAT2v2(p1.x*cosa - p1.y*sina, p1.x*sina + p1.y*cosa);\n"; ++ c += " max_value.x = max(max(p0.x, p1.x), max_value.x);\n"; ++ c += " max_value.y = max(max(p0.y, p1.y), max_value.y);\n"; ++ c += " min_value.x = min(min(p0.x, p1.x), min_value.x);\n"; ++ c += " min_value.y = min(min(p0.y, p1.y), min_value.y);\n"; ++ c += " }\n"; ++ c += " float2 bbox_size = (max_value - min_value) * " ++ "args.bbox_size_multiplier;\n"; ++ c += ++ " float3 scale_mat_c0 = INIT_FLOAT3v3(bbox_size.x / args.l_range, 0.0f, " ++ "0.0f);\n"; ++ c += ++ " float3 scale_mat_c1 = INIT_FLOAT3v3(0.0f, bbox_size.y / args.l_range, " ++ "0.0f);\n"; ++ c += " float3 scale_mat_c2 = INIT_FLOAT3v3(0.0f, 0.0f, 1.0f);\n"; ++ c += " float2 middle = (max_value + min_value) * 0.5f;\n"; ++ c += " float2 rotated_middle;\n"; ++ c += " float cosnega = cos(-alpha);\n"; ++ c += " float sinnega = sin(-alpha);\n"; ++ c += " rotated_middle.x = middle.x * cosnega - middle.y * sinnega;\n"; ++ c += " rotated_middle.y = middle.x * sinnega + middle.y * cosnega;\n"; ++ c += " float3 rot_mat_c0 = INIT_FLOAT3v3(cosnega, sinnega, 0.0f);\n"; ++ c += " float3 rot_mat_c1 = INIT_FLOAT3v3(-sinnega, cosnega, 0.0f);\n"; ++ c += " float3 rot_mat_c2 = INIT_FLOAT3v3(rotated_middle.x / args.l_range * " ++ "2.0f - " ++ "1.0f, rotated_middle.y / args.l_range * 2.0f - 1.0f, 1.0f);\n"; ++ c += " float3 to_relative_c0 = INIT_FLOAT3v3(2.0f / (args.output_size_x - " ++ "1.0f), 0.0f, 0.0f);\n"; ++ c += " float3 to_relative_c1 = INIT_FLOAT3v3(0.0f, 2.0f / " ++ "(args.output_size_y - 1.0f), 0.0f);\n"; ++ c += " float3 to_relative_c2 = INIT_FLOAT3v3(-1.0f, -1.0f, 1.0f);\n"; ++ c += " float3 to_absolute_c0 = INIT_FLOAT3v3((args.input_size_x - 1.0f) / " ++ "2.0f, 0.0f, 0.0f);\n"; ++ c += " float3 to_absolute_c1 = INIT_FLOAT3v3(0.0f, (args.input_size_y - " ++ "1.0f) / 2.0f, 0.0f);\n"; ++ c += " float3 
to_absolute_c2 = INIT_FLOAT3v3((args.input_size_x - 1.0f) / " ++ "2.0f, (args.input_size_y - 1.0f) / 2.0f, 1.0f);\n"; ++ c += " float3 t0;\n"; ++ c += " float3 t1;\n"; ++ c += " float3 t2;\n"; ++ c += " // t0 = to_absolute * rotation_matrix\n"; ++ c += " MAT_MUL_3x3(t0, t1, t2, to_absolute_c0, to_absolute_c1, " ++ "to_absolute_c2, rot_mat_c0, rot_mat_c1, rot_mat_c2);\n"; ++ c += " float3 u0;\n"; ++ c += " float3 u1;\n"; ++ c += " float3 u2;\n"; ++ c += " // u0 = t0 * scale_matrix\n"; ++ c += " MAT_MUL_3x3(u0, u1, u2, t0, t1, t2, scale_mat_c0, scale_mat_c1, " ++ "scale_mat_c2);\n"; ++ c += " float3 res_c0;\n"; ++ c += " float3 res_c1;\n"; ++ c += " float3 res_c2;\n"; ++ c += " MAT_MUL_3x3(res_c0, res_c1, res_c2, u0, u1, u2, to_relative_c0, " ++ "to_relative_c1, to_relative_c2);\n"; ++ c += " FLT4 r0 = INIT_FLT4v4(res_c0.x, res_c1.x, 0.0f, res_c2.x);\n"; ++ c += " FLT4 r1 = INIT_FLT4v4(res_c0.y, res_c1.y, 0.0f, res_c2.y);\n"; ++ c += " FLT4 r2 = INIT_FLT4v4(res_c0.z, res_c1.z, res_c2.z, 0.0f);\n"; ++ c += " FLT4 r3 = INIT_FLT4v4( 0.0f, 0.0f, 0.0f, 1.0f);\n"; ++ c += " args.dst_tensor.Write(r0, 0, 0, 0);\n"; ++ c += " args.dst_tensor.Write(r1, 1, 0, 0);\n"; ++ c += " args.dst_tensor.Write(r2, 2, 0, 0);\n"; ++ c += " args.dst_tensor.Write(r3, 3, 0, 0);\n"; ++ c += "}\n"; ++ return c; ++} ++ ++std::string GetLandmarksToTransformMatrixV2KernelCode( ++ const OperationDef& op_def, ++ const LandmarksToTransformMatrixV2Attributes& attr) { ++ std::string c; ++ c += "#define MAT_MUL_3x3(R0, R1, R2, A0, A1, A2, B0, B1, B2) \\\n"; ++ c += " R0.x = A0.x * B0.x + A1.x * B0.y + A2.x * B0.z; \\\n"; ++ c += " R0.y = A0.y * B0.x + A1.y * B0.y + A2.y * B0.z; \\\n"; ++ c += " R0.z = A0.z * B0.x + A1.z * B0.y + A2.z * B0.z; \\\n"; ++ c += " R1.x = A0.x * B1.x + A1.x * B1.y + A2.x * B1.z; \\\n"; ++ c += " R1.y = A0.y * B1.x + A1.y * B1.y + A2.y * B1.z; \\\n"; ++ c += " R1.z = A0.z * B1.x + A1.z * B1.y + A2.z * B1.z; \\\n"; ++ c += " R2.x = A0.x * B2.x + A1.x * B2.y + A2.x * B2.z; \\\n"; ++ c += " R2.y = A0.y * B2.x + A1.y * B2.y + A2.y * B2.z; \\\n"; ++ c += " R2.z = A0.z * B2.x + A1.z * B2.y + A2.z * B2.z; \n"; ++ ++ c += "MAIN_FUNCTION($0) {\n"; ++ // temporary ++ c += " int dummy_var = GLOBAL_ID_0;\n"; ++ if (op_def.IsBatchSupported()) { ++ c += " int B = GLOBAL_ID_0;\n"; ++ c += " if (B >= args.dst_tensor.Batch()) return;\n"; ++ c += " args.dst_tensor.SetBatchRef(B);\n"; ++ c += " args.src_tensor.SetBatchRef(B);\n"; ++ } ++ // reads x and y coords only. 
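// Editorial note, not part of the upstream patch: this V2 variant differs from
// the V1 kernel above in two ways: landmarks are always 3-D (start = id * 3),
// and every unpacked coordinate is rescaled by args.multiplier. The transform
// assembled further down is the product of a shift to the crop center, the
// rotation by the residual angle, the crop-to-output scale, and a final shift
// by half the output size, emitted row by row as four FLT4 writes.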
++ auto read_landmark = [&](const std::string& result, const std::string& id) { ++ c += " {\n"; ++ c += " int start = " + id + " * 3; // only 3 dimensional landmarks\n"; ++ c += " int ZC = start / 4;\n"; ++ c += " float4 t_res = args.src_tensor.Read(0, 0, ZC);\n"; ++ c += " int rem = start % 4;\n"; ++ c += " if (rem == 0) {\n"; ++ c += " " + result + ".xy = t_res.xy;\n"; ++ c += " } else if (rem == 1) {\n"; ++ c += " " + result + ".xy = t_res.yz;\n"; ++ c += " } else if (rem == 2) {\n"; ++ c += " " + result + ".xy = t_res.zw;\n"; ++ c += " } else {\n"; ++ c += " float4 t_res_next = args.src_tensor.Read(0, 0, ZC + " ++ "1);\n"; ++ c += " " + result + ".xy = INIT_FLOAT2v2(t_res.w, t_res_next.x);\n"; ++ c += " }\n"; ++ c += " " + result + " *= args.multiplier;\n"; ++ c += " }\n"; ++ }; ++ c += " float2 left_landmark, right_landmark;\n"; ++ read_landmark("left_landmark", "args.left_rotation_idx"); ++ read_landmark("right_landmark", "args.right_rotation_idx"); ++ c += " float diff_y = right_landmark.y - left_landmark.y;\n"; ++ c += " float diff_x = right_landmark.x - left_landmark.x;\n"; ++ c += " float rotation = 0.0;\n"; ++ c += " if (diff_y != 0.0 && diff_x != 0.0) {" ++ " rotation = atan2(diff_y, diff_x);\n" ++ " }"; ++ c += " float r = args.target_rotation_radians - rotation;\n"; ++ c += " float cosr = cos(r);\n"; ++ c += " float sinr = sin(r);\n"; ++ c += " float2 max_value = INIT_FLOAT2v2(-100000.0f, -100000.0f);\n"; ++ c += " float2 min_value = INIT_FLOAT2v2(100000.0f, 100000.0f);\n"; ++ c += " for (int i = 0; i < args.subset_idxs_size; i++) {\n"; ++ c += " float2 p0, p1;\n"; ++ c += " int2 subset_idxs_v = args.subset_idxs.Read(i);\n"; ++ read_landmark("p0", "subset_idxs_v.x"); ++ read_landmark("p1", "subset_idxs_v.y"); ++ c += " // rotation\n"; ++ c += ++ " p0 = INIT_FLOAT2v2(p0.x*cosr - p0.y*sinr, p0.x*sinr + p0.y*cosr);\n"; ++ c += ++ " p1 = INIT_FLOAT2v2(p1.x*cosr - p1.y*sinr, p1.x*sinr + p1.y*cosr);\n"; ++ c += " max_value.x = max(max(p0.x, p1.x), max_value.x);\n"; ++ c += " max_value.y = max(max(p0.y, p1.y), max_value.y);\n"; ++ c += " min_value.x = min(min(p0.x, p1.x), min_value.x);\n"; ++ c += " min_value.y = min(min(p0.y, p1.y), min_value.y);\n"; ++ c += " }\n"; ++ c += " float crop_width = max_value.x - min_value.x;\n"; ++ c += " float crop_height = max_value.y - min_value.y;\n"; ++ c += " float2 crop_xy1 = (max_value + min_value) / 2.0f;\n"; ++ c += " float crop_x = cos(-r) * crop_xy1.x - sin(-r) * crop_xy1.y;\n"; ++ c += " float crop_y = sin(-r) * crop_xy1.x + cos(-r) * crop_xy1.y;\n"; ++ c += " float3 shift_c0 = INIT_FLOAT3v3(1.0, 0.0, 0.0);\n"; ++ c += " float3 shift_c1 = INIT_FLOAT3v3(0.0, 1.0, 0.0);\n"; ++ c += " float3 shift_c2 = INIT_FLOAT3v3(crop_x, crop_y, 1.0);\n"; ++ c += " r = -r;\n"; ++ c += " float3 rotation_c0 = INIT_FLOAT3v3(cos(r), sin(r), 0.0);\n"; ++ c += " float3 rotation_c1 = INIT_FLOAT3v3(-sin(r), cos(r), 0.0);\n"; ++ c += " float3 rotation_c2 = INIT_FLOAT3v3(0.0, 0.0, 1.0);\n"; ++ c += " float3 t0;\n"; ++ c += " float3 t1;\n"; ++ c += " float3 t2;\n"; ++ c += " MAT_MUL_3x3(t0, t1, t2, shift_c0, shift_c1, shift_c2, " ++ " rotation_c0, rotation_c1, rotation_c2);\n"; ++ c += " float cs_x = args.scale_x * crop_width / args.output_width;\n"; ++ c += " float cs_y = args.scale_y * crop_height / args.output_height;\n"; ++ c += " float3 scale_c0 = INIT_FLOAT3v3(cs_x, 0.0, 0.0);\n"; ++ c += " float3 scale_c1 = INIT_FLOAT3v3(0.0, cs_y, 0.0);\n"; ++ c += " float3 scale_c2 = INIT_FLOAT3v3(0.0, 0.0, 1.0);\n"; ++ c += " MAT_MUL_3x3(t0, t1, t2, t0, t1, t2, " ++ " 
scale_c0, scale_c1, scale_c2);\n"; ++ c += " float shift_x = -1.0 * (args.output_width / 2.0);\n"; ++ c += " float shift_y = -1.0 * (args.output_height / 2.0);\n"; ++ c += " float3 shift2_c0 = INIT_FLOAT3v3(1.0, 0.0, 0.0);\n"; ++ c += " float3 shift2_c1 = INIT_FLOAT3v3(0.0, 1.0, 0.0);\n"; ++ c += " float3 shift2_c2 = INIT_FLOAT3v3(shift_x, shift_y, 1.0);\n"; ++ c += " MAT_MUL_3x3(t0, t1, t2, t0, t1, t2, " ++ " shift2_c0, shift2_c1, shift2_c2);\n"; ++ c += " FLT4 r0 = INIT_FLT4v4(t0.x, t1.x, 0.0f, t2.x);\n"; ++ c += " FLT4 r1 = INIT_FLT4v4(t0.y, t1.y, 0.0f, t2.y);\n"; ++ c += " FLT4 r2 = INIT_FLT4v4(t0.z, t1.z, t2.z, 0.0f);\n"; ++ c += " FLT4 r3 = INIT_FLT4v4(0.0f, 0.0f, 0.0f, 1.0f);\n"; ++ c += " args.dst_tensor.Write(r0, 0, 0, 0);\n"; ++ c += " args.dst_tensor.Write(r1, 1, 0, 0);\n"; ++ c += " args.dst_tensor.Write(r2, 2, 0, 0);\n"; ++ c += " args.dst_tensor.Write(r3, 3, 0, 0);\n"; ++ c += "}\n"; ++ return c; ++} ++ ++} // namespace ++ ++absl::Status CreateLandmarksToTransformMatrixFromNode( ++ const OperationDef& op_def, const Node& node, ++ std::unique_ptr* gpu_op) { ++ auto* attr_v1 = absl::any_cast( ++ &node.operation.attributes); ++ if (attr_v1) { ++ GPUOperation operation = ++ CreateLandmarksToTransformMatrixV1(op_def, *attr_v1); ++ *gpu_op = absl::make_unique(std::move(operation)); ++ return absl::OkStatus(); ++ } ++ auto* attr_v2 = absl::any_cast( ++ &node.operation.attributes); ++ if (attr_v2) { ++ GPUOperation operation = ++ CreateLandmarksToTransformMatrixV2(op_def, *attr_v2); ++ *gpu_op = absl::make_unique(std::move(operation)); ++ return absl::OkStatus(); ++ } ++ return absl::InvalidArgumentError( ++ "Landmarks To Transform Matrix operation supports only version 1 or " ++ "2."); ++} ++ ++GPUOperation CreateLandmarksToTransformMatrixV1( ++ const OperationDef& definition, ++ const LandmarksToTransformMatrixV1Attributes& attr) { ++ std::vector data(attr.subset.size() * 2); ++ for (int i = 0; i < attr.subset.size(); ++i) { ++ data[i * 2 + 0] = attr.subset[i].x; ++ data[i * 2 + 1] = attr.subset[i].y; ++ } ++ ++ BufferDescriptor desc; ++ desc.element_type = DataType::INT32; ++ desc.element_size = 2; ++ desc.memory_type = MemoryType::GLOBAL; ++ desc.size = attr.subset.size() * sizeof(int32_t) * 2; ++ desc.data.resize(desc.size); ++ memcpy(desc.data.data(), data.data(), desc.size); ++ ++ GPUOperation result(definition); ++ result.AddSrcTensor("src_tensor", definition.src_tensors[0]); ++ result.AddDstTensor("dst_tensor", definition.dst_tensors[0]); ++ result.args_.AddFloat("l_range", attr.landmarks_range); ++ result.args_.AddFloat("bbox_size_multiplier", attr.bbox_size_multiplier); ++ result.args_.AddInt("rotations_idx_x", attr.left_rotation_idx); ++ result.args_.AddInt("rotations_idx_y", attr.right_rotation_idx); ++ result.args_.AddFloat("input_size_x", attr.input_hw.w); ++ result.args_.AddFloat("input_size_y", attr.input_hw.h); ++ result.args_.AddFloat("output_size_x", attr.output_hw.w); ++ result.args_.AddFloat("output_size_y", attr.output_hw.h); ++ result.args_.AddInt("subset_size", attr.subset.size()); ++ result.args_.AddObject("subset", ++ absl::make_unique(std::move(desc))); ++ result.code_ = GetLandmarksToTransformMatrixV1KernelCode(definition, attr); ++ result.work_group_size_ = int3(1, 1, 1); ++ result.tensor_to_grid_ = TensorToGrid::kBToX_YIs1_ZIs1; ++ ++ return result; ++} ++ ++GPUOperation CreateLandmarksToTransformMatrixV2( ++ const OperationDef& definition, ++ const LandmarksToTransformMatrixV2Attributes& attr) { ++ std::vector data(attr.subset_idxs.size() * 2); ++ for 
(int i = 0; i < attr.subset_idxs.size(); ++i) { ++ data[i * 2 + 0] = attr.subset_idxs[i].x; ++ data[i * 2 + 1] = attr.subset_idxs[i].y; ++ } ++ ++ BufferDescriptor desc; ++ desc.element_type = DataType::INT32; ++ desc.element_size = 2; ++ desc.memory_type = MemoryType::GLOBAL; ++ desc.size = attr.subset_idxs.size() * sizeof(int32_t) * 2; ++ desc.data.resize(desc.size); ++ memcpy(desc.data.data(), data.data(), desc.size); ++ ++ GPUOperation result(definition); ++ result.AddSrcTensor("src_tensor", definition.src_tensors[0]); ++ result.AddDstTensor("dst_tensor", definition.dst_tensors[0]); ++ ++ result.args_.AddInt("left_rotation_idx", attr.left_rotation_idx); ++ result.args_.AddInt("right_rotation_idx", attr.right_rotation_idx); ++ result.args_.AddFloat("target_rotation_radians", ++ attr.target_rotation_radians); ++ result.args_.AddFloat("output_height", attr.output_height); ++ result.args_.AddFloat("output_width", attr.output_width); ++ result.args_.AddFloat("scale_x", attr.scale_x); ++ result.args_.AddFloat("scale_y", attr.scale_y); ++ result.args_.AddFloat("multiplier", attr.multiplier); ++ ++ result.args_.AddInt("subset_idxs_size", attr.subset_idxs.size()); ++ result.args_.AddObject("subset_idxs", ++ absl::make_unique(std::move(desc))); ++ result.code_ = GetLandmarksToTransformMatrixV2KernelCode(definition, attr); ++ result.work_group_size_ = int3(1, 1, 1); ++ result.tensor_to_grid_ = TensorToGrid::kBToX_YIs1_ZIs1; ++ return result; ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.h b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.h +new file mode 100644 +index 00000000000..2fd523df7c7 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/landmarks_to_transform_matrix.h +@@ -0,0 +1,26 @@ ++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPELANDMARKS_TO_TRANSFORM_MATRIX_H_ ++#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPELANDMARKS_TO_TRANSFORM_MATRIX_H_ ++ ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" ++ ++namespace tflite { ++namespace gpu { ++ ++absl::Status CreateLandmarksToTransformMatrixFromNode( ++ const OperationDef& op_def, const Node& node, ++ std::unique_ptr* gpu_op); ++ ++GPUOperation CreateLandmarksToTransformMatrixV1( ++ const OperationDef& definition, ++ const LandmarksToTransformMatrixV1Attributes& attr); ++ ++GPUOperation CreateLandmarksToTransformMatrixV2( ++ const OperationDef& definition, ++ const LandmarksToTransformMatrixV2Attributes& attr); ++ ++} // namespace gpu ++} // namespace tflite ++ ++#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPELANDMARKS_TO_TRANSFORM_MATRIX_H_ +diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.cc b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.cc +new file mode 100644 +index 00000000000..999917a9251 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.cc +@@ -0,0 +1,116 @@ ++#include "tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.h" ++ ++#include ++ ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" ++ ++namespace tflite { ++namespace gpu { ++namespace { ++ 
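// Editorial sketch, not part of the upstream patch: a plain C++ reference for
// the 2-D branch of the kernel generated below. Each float4 read from
// src_tensor packs two (x, y) landmarks; every landmark is promoted to the
// homogeneous vector (x, y, 0, 1) and dotted with the first two rows of the
// 4x4 transform matrix (x_transform and y_transform). The helper name below is
// illustrative only.
inline void TransformLandmark2DReference(const float x_row[4],
                                         const float y_row[4], float x, float y,
                                         float* out_x, float* out_y) {
  // Homogeneous landmark, matching l_pair1_/l_pair2_ in the generated code.
  const float v[4] = {x, y, 0.0f, 1.0f};
  *out_x = x_row[0] * v[0] + x_row[1] * v[1] + x_row[2] * v[2] + x_row[3] * v[3];
  *out_y = y_row[0] * v[0] + y_row[1] * v[1] + y_row[2] * v[2] + y_row[3] * v[3];
}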
++std::string GetTransformLandmarksKernelCode(const OperationDef& op_def, ++ int dimension, float scale) { ++ std::string c; ++ c += "MAIN_FUNCTION($0) {\n"; ++ if (op_def.IsBatchSupported()) { ++ c += " int linear_id = GLOBAL_ID_0;\n"; ++ c += " int X = linear_id / args.dst_tensor.Batch();\n"; ++ c += " int B = linear_id % args.dst_tensor.Batch();\n"; ++ c += " args.dst_tensor.SetBatchRef(B);\n"; ++ c += " args.matrix_transform.SetBatchRef(B);\n"; ++ c += " args.src_tensor.SetBatchRef(B);\n"; ++ } else { ++ c += " int X = GLOBAL_ID_0;\n"; ++ } ++ c += " int Y = GLOBAL_ID_1;\n"; ++ c += " int Z = GLOBAL_ID_2;\n"; ++ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " ++ "Z >= args.dst_tensor.Slices()) " ++ "return;\n"; ++ c += " float4 x_transform = args.matrix_transform.Read(0, 0, 0);\n"; ++ c += " float4 y_transform = args.matrix_transform.Read(1, 0, 0);\n"; ++ if (scale != 1.0) { ++ c += " x_transform.w *= args.scale;\n"; ++ c += " y_transform.w *= args.scale;\n"; ++ } ++ c += " float4 landmks = args.src_tensor.Read(X, Y, Z);\n"; ++ c += " float4 result = INIT_FLOAT4(0.0f);\n"; ++ if (dimension == 2) { ++ c += " float4 l_pair1_ = INIT_FLOAT4v4(landmks.x, landmks.y, 0.0f, " ++ "1.0f);\n"; ++ c += " float4 l_pair2_ = INIT_FLOAT4v4(landmks.z, landmks.w, 0.0f, " ++ "1.0f);\n"; ++ c += " result.x = dot(x_transform, l_pair1_);\n"; ++ c += " result.y = dot(y_transform, l_pair1_);\n"; ++ c += " result.z = dot(x_transform, l_pair2_);\n"; ++ c += " result.w = dot(y_transform, l_pair2_);\n"; ++ } else if (dimension == 3) { ++ c += " int reminder = (Z * 4) % 3;\n"; ++ c += " if (reminder == 0) { // 0, 3, 6\n"; ++ c += " // x y z x\n"; ++ c += " float4 landmks_next = args.src_tensor.Read(X, Y, Z+1);\n"; ++ c += " float4 l_= landmks;\n"; ++ c += " l_.z = 0.0f;\n"; ++ c += " l_.w = 1.0f;\n"; ++ c += " result.x = dot(x_transform, l_);\n"; ++ c += " result.y = dot(y_transform, l_);\n"; ++ c += " result.z = landmks.z;\n"; ++ c += " result.w = dot(x_transform, INIT_FLOAT4v4(landmks.w, " ++ "landmks_next.x, " ++ "0.0f, 1.0f));\n"; ++ c += " } else if (reminder == 1) { // 1, 4, 7\n"; ++ c += " // y z x y\n"; ++ c += " float4 landmks_prev = args.src_tensor.Read(X, Y, Z-1);\n"; ++ c += " float4 l_ = INIT_FLOAT4v4(landmks.z, landmks.w, 0.0f, 1.0f);\n"; ++ c += " result.x = dot(y_transform, INIT_FLOAT4v4(landmks_prev.w, " ++ "landmks.x, " ++ "0.0f, 1.0f));\n"; ++ c += " result.y = landmks.y;\n"; ++ c += " result.z = dot(x_transform, l_);\n"; ++ c += " result.w = dot(y_transform, l_);\n"; ++ c += " } else { // reminder == 2; // 2, 5, 8\n"; ++ c += " // z, x, y, z\n"; ++ c += " float4 l_ = INIT_FLOAT4v4(landmks.y, landmks.z, 0.0f, 1.0f);\n"; ++ c += " result.x = landmks.x;\n"; ++ c += " result.y = dot(x_transform, l_);\n"; ++ c += " result.z = dot(y_transform, l_);\n"; ++ c += " result.w = landmks.w;\n"; ++ c += " }\n"; ++ } ++ c += " FLT4 res = TO_FLT4(result);\n"; ++ c += " args.dst_tensor.Write(res, X, Y, Z);\n"; ++ c += "}\n"; ++ return c; ++} ++} // namespace ++ ++absl::Status CreateTransformLandmarksFromNode( ++ const OperationDef& op_def, const Node& node, ++ std::unique_ptr* gpu_op) { ++ auto attr = ++ absl::any_cast(node.operation.attributes); ++ if (attr.version != 1) { ++ return absl::InvalidArgumentError( ++ "Transform Landmarks operation supports only version 1."); ++ } ++ GPUOperation operation = CreateTransformLandmarks(op_def, attr); ++ *gpu_op = absl::make_unique(std::move(operation)); ++ return absl::OkStatus(); ++} ++ ++GPUOperation CreateTransformLandmarks( ++ const 
OperationDef& definition, const TransformLandmarksAttributes& attr) { ++ GPUOperation op(definition); ++ op.AddSrcTensor("src_tensor", definition.src_tensors[0]); ++ op.AddSrcTensor("matrix_transform", definition.src_tensors[1]); ++ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]); ++ op.args_.AddFloat("scale", attr.scale); ++ op.code_ = ++ GetTransformLandmarksKernelCode(definition, attr.dimensions, attr.scale); ++ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; ++ return op; ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.h b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.h +new file mode 100644 +index 00000000000..5c0be19033a +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_landmarks.h +@@ -0,0 +1,21 @@ ++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPETRANSFORM_LANDMARKS_H_ ++#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPETRANSFORM_LANDMARKS_H_ ++ ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" ++ ++namespace tflite { ++namespace gpu { ++ ++absl::Status CreateTransformLandmarksFromNode( ++ const OperationDef& op_def, const Node& node, ++ std::unique_ptr* gpu_op); ++ ++GPUOperation CreateTransformLandmarks(const OperationDef& definition, ++ const TransformLandmarksAttributes& attr); ++ ++} // namespace gpu ++} // namespace tflite ++ ++#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPETRANSFORM_LANDMARKS_H_ +diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.cc b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.cc +new file mode 100644 +index 00000000000..2723216f324 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.cc +@@ -0,0 +1,133 @@ ++#include "tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.h" ++ ++#include ++ ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h" ++ ++namespace tflite { ++namespace gpu { ++namespace { ++ ++std::string AlignCornersCorrection(bool align_corners) { ++ // Align corners correction: T -> S * ( T * A ), where T is a ++ // transformation matrix, and subtruction and addition matrices are: ++ // S A ++ // 1 0 0 -0.5 1 0 0 0.5 ++ // 0 1 0 -0.5 0 1 0 0.5 ++ // 0 0 1 0 0 0 1 0 ++ // 0 0 0 1 0 0 0 1 ++ // Transformation matrix column 3 and rows 3, 4 are identity, which makes ++ // the final formula pretty simple and easy to get if doing a manual ++ // multiuplication. ++ return align_corners ? 
R"( ++ first_line.w += first_line.x * 0.5 + first_line.y * 0.5 - 0.5; ++ second_line.w += second_line.x * 0.5 + second_line.y * 0.5 - 0.5; ++ )" ++ : ""; ++} ++ ++std::string GetTransformTensorBilinearKernelCode(const OperationDef& op_def, ++ bool align_corners) { ++ std::string c; ++ c += "MAIN_FUNCTION($0) {\n"; ++ c += " int Y = GLOBAL_ID_1;\n"; ++ c += " int Z = GLOBAL_ID_2;\n"; ++ if (op_def.IsBatchSupported()) { ++ c += " int linear_id = GLOBAL_ID_0;\n"; ++ c += " int X = linear_id / args.dst_tensor.Batch();\n"; ++ c += " int B = linear_id % args.dst_tensor.Batch();\n"; ++ c += " args.dst_tensor.SetBatchRef(B);\n"; ++ c += " args.matrix_transform.SetBatchRef(B);\n"; ++ c += " args.src_tensor.SetBatchRef(B);\n"; ++ } else { ++ c += " int X = GLOBAL_ID_0;\n"; ++ } ++ c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || " ++ "Z >= args.dst_tensor.Slices()) " ++ "return;\n"; ++ c += " float4 first_line = args.matrix_transform.Read(0, 0, 0);\n"; ++ c += " float4 second_line = args.matrix_transform.Read(1, 0, 0);\n"; ++ c += AlignCornersCorrection(align_corners); ++ c += " float4 before_transform_coord_2d = INIT_FLOAT4v4(INIT_FLOAT(X), " ++ "INIT_FLOAT(Y), " ++ "0.0f, 1.0f);\n"; ++ c += " // Get transformed coordinates\n"; ++ c += ++ " float2 xy = INIT_FLOAT2v2(dot(first_line, before_transform_coord_2d), " ++ "dot(second_line, before_transform_coord_2d));\n"; ++ c += " float2 xy_floor = floor(xy);\n"; ++ c += " int4 st;\n"; ++ c += " st.xy = INIT_INT2v2(xy_floor.x, xy_floor.y);\n"; ++ c += " st.zw = INIT_INT2v2(xy_floor.x, xy_floor.y) + INIT_INT2v2(1, 1);\n"; ++ c += " // Apply interpolation if coordinate is in bounds.\n"; ++ c += " float4 result = INIT_FLOAT4(0.0f);\n"; ++ c += " float2 t = xy - xy_floor;\n"; ++ c += " if(xy.x >= 0.0 && xy.x <= INIT_FLOAT(args.src_tensor.Width() - 1) && " ++ "xy.y >= 0.0 && " ++ "xy.y <= INIT_FLOAT(args.src_tensor.Height() - 1)) {\n"; ++ c += " float4 p0 = INIT_FLOAT4(0.0f);\n"; ++ c += " float4 p1 = INIT_FLOAT4(0.0f);\n"; ++ c += " float4 p2 = INIT_FLOAT4(0.0f);\n"; ++ c += " float4 p3 = INIT_FLOAT4(0.0f);\n"; ++ const auto src_tensor_type = op_def.src_tensors[0].storage_type; ++ const bool buffer_type = src_tensor_type == TensorStorageType::BUFFER || ++ src_tensor_type == TensorStorageType::IMAGE_BUFFER; ++ auto read_src = [&](const std::string& result, const std::string& xc, ++ const std::string& yc, const std::string& zc) { ++ if (buffer_type) { ++ c += " if(" + xc + " >= 0 && " + yc + " >= 0 && " + xc + ++ " < args.src_tensor.Width() && " + yc + ++ " < args.src_tensor.Height()) {\n"; ++ c += " " + result + " = args.src_tensor.Read(" + xc + ", " + ++ yc + ", " + zc + ");\n"; ++ c += " }\n"; ++ } else { ++ c += " " + result + " = args.src_tensor.Read(" + xc + ", " + ++ yc + ", " + zc + ");\n"; ++ } ++ }; ++ read_src("p0", "st.x", "st.y", "Z"); ++ read_src("p1", "st.z", "st.y", "Z"); ++ read_src("p2", "st.x", "st.w", "Z"); ++ read_src("p3", "st.z", "st.w", "Z"); ++ c += " result = mix(mix(p0, p1, t.x), mix(p2, p3, t.x), t.y);\n"; ++ c += " }\n"; ++ c += " FLT4 res = TO_FLT4(result);\n"; ++ c += " args.dst_tensor.Write(res, X, Y, Z);\n"; ++ c += "}\n"; ++ return c; ++} ++} // namespace ++ ++absl::Status CreateTransformTensorBilinearFromNode( ++ const OperationDef& op_def, const Node& node, ++ std::unique_ptr* gpu_op) { ++ auto attr = absl::any_cast( ++ node.operation.attributes); ++ if (attr.version != 1) { ++ return absl::InvalidArgumentError( ++ "Transform Tensor Bilinear operation supports only version 1."); ++ } ++ GPUOperation 
operation = CreateTransformTensorBilinear(op_def, attr); ++ *gpu_op = absl::make_unique(std::move(operation)); ++ return absl::OkStatus(); ++} ++ ++GPUOperation CreateTransformTensorBilinear( ++ const OperationDef& definition, ++ const TransformTensorBilinearAttributes& attr) { ++ GPUOperation op(definition); ++ auto src_desc = definition.src_tensors[0]; ++ src_desc.SetAddressMode(AddressMode::kZero); ++ op.AddSrcTensor("src_tensor", src_desc); ++ op.AddSrcTensor("matrix_transform", definition.src_tensors[1]); ++ op.AddDstTensor("dst_tensor", definition.dst_tensors[0]); ++ op.code_ = ++ GetTransformTensorBilinearKernelCode(definition, attr.align_corners); ++ op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ; ++ return op; ++} ++ ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.h b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.h +new file mode 100644 +index 00000000000..0251265cdf4 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/common/tasks/mediapipe/transform_tensor_bilinear.h +@@ -0,0 +1,22 @@ ++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPETRANSFORM_TENSOR_BILINEAR_H_ ++#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPETRANSFORM_TENSOR_BILINEAR_H_ ++ ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h" ++ ++namespace tflite { ++namespace gpu { ++ ++absl::Status CreateTransformTensorBilinearFromNode( ++ const OperationDef& op_def, const Node& node, ++ std::unique_ptr* gpu_op); ++ ++GPUOperation CreateTransformTensorBilinear( ++ const OperationDef& definition, ++ const TransformTensorBilinearAttributes& attr); ++ ++} // namespace gpu ++} // namespace tflite ++ ++#endif // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_TASKS_MEDIAPIPETRANSFORM_TENSOR_BILINEAR_H_ +diff --git a/tensorflow/lite/delegates/gpu/common/transformations/BUILD b/tensorflow/lite/delegates/gpu/common/transformations/BUILD +index d26b4f807de..9596dbab7e6 100644 +--- a/tensorflow/lite/delegates/gpu/common/transformations/BUILD ++++ b/tensorflow/lite/delegates/gpu/common/transformations/BUILD +@@ -287,7 +287,7 @@ cc_library( + ":merge_padding_with", + ":remove_noop", + "//tensorflow/lite/delegates/gpu/common:model_transformer", +- ] + tf_platform_alias("custom_transformations", "//tensorflow/lite/delegates/gpu/common/"), ++ ] + ["//tensorflow/lite/delegates/gpu/common/mediapipe:custom_transformations"], + ) + + cc_library( +diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD +index b7860b44ede..30cc160d32c 100644 +--- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD ++++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD +@@ -153,10 +153,11 @@ cc_test( + + cc_library( + name = "custom_registry", +- srcs = ["custom_registry.cc"], ++ srcs = ["//tensorflow/lite/delegates/gpu/gl/kernels/mediapipe:registry.cc"], + hdrs = ["custom_registry.h"], + deps = [ + "//tensorflow/lite/delegates/gpu/gl:node_shader", ++ "//tensorflow/lite/delegates/gpu/gl/kernels/mediapipe:all_custom_ops", + "@com_google_absl//absl/container:flat_hash_map", + ], + ) +diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/BUILD +new file mode 100644 +index 00000000000..f5e696d0859 +--- /dev/null ++++ 
b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/BUILD +@@ -0,0 +1,85 @@ ++load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") ++ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++exports_files([ ++ "registry.cc", ++ "landmarks_to_transform_matrix.h", ++ "transform_landmarks.h", ++ "transform_tensor_bilinear.h", ++]) ++ ++cc_library( ++ name = "all_custom_ops", ++ hdrs = [ ++ "landmarks_to_transform_matrix.h", ++ "transform_landmarks.h", ++ "transform_tensor_bilinear.h", ++ ], ++ deps = [ ++ ":landmarks_to_transform_matrix", ++ ":transform_landmarks", ++ ":transform_tensor_bilinear", ++ "//tensorflow/lite/delegates/gpu/common:operations", ++ "//tensorflow/lite/delegates/gpu/gl:node_shader", ++ ], ++) ++ ++cc_library( ++ name = "landmarks_to_transform_matrix", ++ srcs = ["landmarks_to_transform_matrix.cc"], ++ hdrs = ["landmarks_to_transform_matrix.h"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:operations", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:types", ++ "//tensorflow/lite/delegates/gpu/common:util", ++ "//tensorflow/lite/delegates/gpu/common/mediapipe:landmarks_to_transform_matrix", ++ "//tensorflow/lite/delegates/gpu/gl:node_shader", ++ "@com_google_absl//absl/memory", ++ "@com_google_absl//absl/strings", ++ "@com_google_absl//absl/types:any", ++ ], ++) ++ ++cc_library( ++ name = "transform_tensor_bilinear", ++ srcs = ["transform_tensor_bilinear.cc"], ++ hdrs = ["transform_tensor_bilinear.h"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:operations", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:types", ++ "//tensorflow/lite/delegates/gpu/common:util", ++ "//tensorflow/lite/delegates/gpu/common/mediapipe:transform_tensor_bilinear", ++ "//tensorflow/lite/delegates/gpu/gl:node_shader", ++ "@com_google_absl//absl/memory", ++ "@com_google_absl//absl/strings", ++ "@com_google_absl//absl/types:any", ++ ], ++) ++ ++cc_library( ++ name = "transform_landmarks", ++ srcs = ["transform_landmarks.cc"], ++ hdrs = ["transform_landmarks.h"], ++ deps = [ ++ "//tensorflow/lite/delegates/gpu/common:operations", ++ "//tensorflow/lite/delegates/gpu/common:shape", ++ "//tensorflow/lite/delegates/gpu/common:status", ++ "//tensorflow/lite/delegates/gpu/common:types", ++ "//tensorflow/lite/delegates/gpu/common:util", ++ "//tensorflow/lite/delegates/gpu/common/mediapipe:transform_landmarks", ++ "//tensorflow/lite/delegates/gpu/gl:node_shader", ++ "@com_google_absl//absl/memory", ++ "@com_google_absl//absl/strings", ++ "@com_google_absl//absl/types:any", ++ ], ++) ++ ++tflite_portable_test_suite() +diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.cc +new file mode 100644 +index 00000000000..de75dd7df2e +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.cc +@@ -0,0 +1,356 @@ ++#include "tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.h" ++ ++#include ++#include ++#include ++#include ++#include ++ ++#include "absl/memory/memory.h" ++#include "absl/strings/substitute.h" ++#include "absl/types/any.h" ++#include "tensorflow/lite/delegates/gpu/common/mediapipe/landmarks_to_transform_matrix.h" ++#include 
"tensorflow/lite/delegates/gpu/common/shape.h" ++#include "tensorflow/lite/delegates/gpu/common/status.h" ++#include "tensorflow/lite/delegates/gpu/common/types.h" ++#include "tensorflow/lite/delegates/gpu/common/util.h" ++ ++namespace tflite { ++namespace gpu { ++namespace gl { ++namespace { ++ ++namespace v1 { ++ ++std::string ReadLandmark(const std::string& landmark, const std::string& idx) { ++ std::string source = R"( ++ vec4 )" + landmark + ++ R"(; ++ { ++ int z_coord = )" + ++ idx + ++ R"( * $dimensions$ / 4; ++ vec4 result = $input_data_0[0, 0, z_coord]$; ++ int rest = )" + idx + ++ R"( * $dimensions$ % 4; ++ if (rest != 0) { ++ if (rest == 1) { ++ result.x = result.y; ++ result.y = result.z; ++ } ++ if (rest == 2) { ++ result.x = result.z; ++ result.y = result.w; ++ } ++ if (rest == 3) { ++ vec4 next_after_result = $input_data_0[0, 0, z_coord + 1]$; ++ result.x = result.w; ++ result.y = next_after_result.x; ++ } ++ } ++ )" + landmark + R"( = result; ++ } ++ )"; ++ return source; ++} ++ ++bool IsSupported(const LandmarksToTransformMatrixV1Attributes& attr) { ++ return attr.dimensions == 3; ++} ++ ++absl::Status GenerateCode(const LandmarksToTransformMatrixV1Attributes& attr, ++ const NodeShader::GenerationContext& ctx, ++ GeneratedCode* generated_code) { ++ if (!IsSupported(attr)) { ++ return absl::InvalidArgumentError( ++ "This case is not supported by LandmarksToTransformMatrix v1"); ++ } ++ ++ std::vector params = { ++ {"dimensions", static_cast(attr.dimensions)}, ++ {"landmarks_range", static_cast(attr.landmarks_range)}, ++ {"left_rotation_idx", static_cast(attr.left_rotation_idx)}, ++ {"right_rotation_idx", static_cast(attr.right_rotation_idx)}, ++ {"bbox_size_multiplier", static_cast(attr.bbox_size_multiplier)}, ++ {"input_h", static_cast(attr.input_hw.h)}, ++ {"input_w", static_cast(attr.input_hw.w)}, ++ {"output_h", static_cast(attr.output_hw.h)}, ++ {"output_w", static_cast(attr.output_hw.w)}, ++ {"subset", attr.subset}, ++ {"subset_size", static_cast(attr.subset.size())}, ++ }; ++ ++ std::string source = R"( ++ )" + ReadLandmark("left_landmark", "$left_rotation_idx$") + ++ R"( ++ ++ )" + ReadLandmark("right_landmark", "$right_rotation_idx$") + ++ R"( ++ ++ float alpha = -atan(right_landmark.y - left_landmark.y, ++ right_landmark.x - left_landmark.x); ++ ++ vec4 max_value = vec4(-100000, -100000, 0.0, 0.0); ++ vec4 min_value = vec4(100000, 100000, 0.0, 0.0); ++ for (int i = 0; i < $subset_size$; i++) { ++ for (int j = 0; j < 2; j++) { ++ )" + ReadLandmark("landmark_current", "$subset$[i][j]") + ++ R"( ++ ++ vec4 rotated = vec4(landmark_current.x * cos(alpha) - ++ landmark_current.y * sin(alpha), ++ landmark_current.x * sin(alpha) + ++ landmark_current.y * cos(alpha), ++ 0.0, 0.0); ++ // both by x and y ++ max_value = vec4(max(max_value.x, rotated.x), ++ max(max_value.y, rotated.y), ++ 0.0, 0.0); ++ min_value = vec4(min(min_value.x, rotated.x), ++ min(min_value.y, rotated.y), ++ 0.0, 0.0); ++ } ++ } ++ ++ vec4 bbox_size = max_value - min_value; ++ bbox_size *= $bbox_size_multiplier$; ++ ++ mat3 scale_matrix = ++ mat3(bbox_size.x / float($landmarks_range$), 0.0, 0.0, // first column ++ 0.0, bbox_size.y / float($landmarks_range$), 0.0, // second column ++ 0.0, 0.0, 1.0); // third column ++ ++ vec4 middle = (max_value + min_value) / 2.0; ++ ++ vec4 rotated_middle = ++ vec4(middle.x * cos(-alpha) - middle.y * sin(-alpha), ++ middle.x * sin(-alpha) + middle.y * cos(-alpha), 0.0, 0.0); ++ ++ mat3 rotation_matrix = ++ mat3(cos(-alpha), sin(-alpha), 0, // first column ++ 
-sin(-alpha), cos(-alpha), 0, // second column ++ // third column ++ (rotated_middle.x / float($landmarks_range$)) * 2.0 - 1.0, ++ (rotated_middle.y / float($landmarks_range$)) * 2.0 - 1.0, 1); ++ ++ mat3 to_relative = ++ mat3(2.0 / (float($output_w$) - 1.0), 0.0, 0.0, // first column ++ 0.0, 2.0 / (float($output_h$) - 1.0), 0.0, // second column ++ -1.0, -1.0, 1.0); // third column ++ ++ mat3 to_absolute = ++ mat3((float($input_w$) - 1.0) / 2.0, 0.0, 0.0, // first column ++ 0.0, (float($input_h$) - 1.0) / 2.0, 0.0, // second column ++ // third column ++ (float($input_w$) - 1.0) / 2.0, (float($input_h$) - 1.0)/2.0, 1.0); ++ ++ // Transformstion Matrix ++ mat3 tm = to_absolute * rotation_matrix * scale_matrix * to_relative; ++ ++ // Inverse Transformation Matrix ++ $output_data_0[0, 0, 0] = vec4(tm[0][0], tm[1][0], 0.0, tm[2][0])$; ++ $output_data_0[1, 0, 0] = vec4(tm[0][1], tm[1][1], 0.0, tm[2][1])$; ++ $output_data_0[2, 0, 0] = vec4(tm[0][2], tm[1][2], tm[2][2], 0.0)$; ++ $output_data_0[3, 0, 0] = vec4( 0, 0, 0, 1.0)$; ++ )"; ++ ++ *generated_code = { ++ /*parameters=*/params, ++ /*objects=*/{}, ++ /*shared_variables=*/{}, ++ /*workload=*/uint3(1, 1, 1), ++ /*workgroup=*/uint3(1, 1, 1), ++ /*source_code=*/std::move(source), ++ /*input=*/IOStructure::ONLY_DEFINITIONS, ++ /*output=*/IOStructure::ONLY_DEFINITIONS, ++ }; ++ return absl::OkStatus(); ++} ++ ++} // namespace v1 ++ ++namespace v2 { ++ ++std::string ReadLandmark(const std::string& landmark, const std::string& idx) { ++ std::string source = R"( ++ vec4 )" + landmark + ++ R"(; ++ { ++ int z_coord = )" + ++ idx + ++ R"( * $dimensions$ / 4; ++ vec4 result = $input_data_0[0, 0, z_coord]$; ++ int rest = )" + idx + ++ R"( * $dimensions$ % 4; ++ if (rest != 0) { ++ if (rest == 1) { ++ result.x = result.y; ++ result.y = result.z; ++ } ++ if (rest == 2) { ++ result.x = result.z; ++ result.y = result.w; ++ } ++ if (rest == 3) { ++ vec4 next_after_result = $input_data_0[0, 0, z_coord + 1]$; ++ result.x = result.w; ++ result.y = next_after_result.x; ++ } ++ } ++ result *= $multiplier$; ++ )" + landmark + R"( = result; ++ } )"; ++ return source; ++} ++ ++static bool IsSupported(const NodeShader::GenerationContext& ctx) { ++ return ctx.input_shapes.size() == 1 && ctx.input_shapes[0][1] == 1 && ++ ctx.input_shapes[0][2] == 1 && ctx.input_shapes[0][3] % 3 == 0; ++} ++ ++absl::Status GenerateCode(const LandmarksToTransformMatrixV2Attributes& attr, ++ const NodeShader::GenerationContext& ctx, ++ GeneratedCode* generated_code) { ++ if (!IsSupported(ctx)) { ++ return absl::InvalidArgumentError( ++ "This case is not supported by LandmarksToTransformMatrixV2"); ++ } ++ ++ std::vector params = { ++ {"dimensions", static_cast(3)}, ++ {"scale_x", static_cast(attr.scale_x)}, ++ {"scale_y", static_cast(attr.scale_y)}, ++ {"left_rotation_idx", static_cast(attr.left_rotation_idx)}, ++ {"right_rotation_idx", static_cast(attr.right_rotation_idx)}, ++ {"target_rotation_radians", ++ static_cast(attr.target_rotation_radians)}, ++ {"output_width", static_cast(attr.output_width)}, ++ {"output_height", static_cast(attr.output_height)}, ++ {"subset_idxs", attr.subset_idxs}, ++ {"subset_idxs_size", static_cast(attr.subset_idxs.size())}, ++ {"multiplier", static_cast(attr.multiplier)}, ++ }; ++ ++ std::string source = R"( ++ )" + ReadLandmark("left_landmark", "$left_rotation_idx$") + ++ R"( ++ )" + ReadLandmark("right_landmark", "$right_rotation_idx$") + ++ R"( ++ ++ float diff_y = right_landmark.y - left_landmark.y; ++ float diff_x = right_landmark.x - left_landmark.x; 
++ float rotation = 0.0; ++ if (diff_y != 0.0 && diff_x != 0.0) rotation = atan(diff_y, diff_x); ++ float r = $target_rotation_radians$ - rotation; ++ ++ vec4 max_value = vec4(-100000, -100000, 0.0, 0.0); ++ vec4 min_value = vec4(100000, 100000, 0.0, 0.0); ++ for (int i = 0; i < $subset_idxs_size$; i++) { ++ for (int j = 0; j < 2; j++) { ++ )" + ReadLandmark("landmark_current", "$subset_idxs$[i][j]") + ++ R"( ++ vec4 rotated = vec4(landmark_current.x * cos(r) - ++ landmark_current.y * sin(r), ++ landmark_current.x * sin(r) + ++ landmark_current.y * cos(r), ++ 0.0, 0.0); ++ // both by x and y ++ max_value = vec4(max(max_value.x, rotated.x), ++ max(max_value.y, rotated.y), ++ 0.0, 0.0); ++ min_value = vec4(min(min_value.x, rotated.x), ++ min(min_value.y, rotated.y), ++ 0.0, 0.0); ++ } ++ } ++ ++ float crop_width = max_value.x - min_value.x; ++ float crop_height = max_value.y - min_value.y; ++ ++ vec4 crop_xy1 = (max_value + min_value) / vec4(2.0); ++ ++ float crop_x = cos(-r) * crop_xy1.x - sin(-r) * crop_xy1.y; ++ float crop_y = sin(-r) * crop_xy1.x + cos(-r) * crop_xy1.y; ++ ++ ++ mat4 t = mat4(1.0, 0.0, 0.0, 0.0, // first column ++ 0.0, 1.0, 0.0, 0.0, // second column ++ 0.0, 0.0, 1.0, 0.0, // third column ++ 0.0, 0.0, 0.0, 1.0); // forth column ++ ++ mat4 t_shift = mat4(1.0, 0.0, 0.0, 0.0, // first column ++ 0.0, 1.0, 0.0, 0.0, // second column ++ 0.0, 0.0, 1.0, 0.0, // third column ++ crop_x, crop_y, 0.0, 1.0); // forth column ++ t *= t_shift; ++ ++ r = -r; ++ ++ mat4 t_rotation = mat4(cos(r), sin(r), 0.0, 0.0, // first column ++ -sin(r), cos(r), 0.0, 0.0, // second column ++ 0.0, 0.0, 1.0, 0.0, // third column ++ 0.0, 0.0, 0.0, 1.0); // forth column ++ ++ t *= t_rotation; ++ // cropped scale for x and y ++ float cs_x = $scale_x$ * crop_width / $output_width$; ++ float cs_y = $scale_y$ * crop_height / $output_height$; ++ mat4 t_scale = mat4(cs_x, 0.0, 0.0, 0.0, // first column ++ 0.0, cs_y, 0.0, 0.0, // second column ++ 0.0, 0.0, 1.0, 0.0, // third column ++ 0.0, 0.0, 0.0, 1.0); // forth column ++ t *= t_scale; ++ float shift_x = -1.0 * ($output_width$ / 2.0); ++ float shift_y = -1.0 * ($output_height$ / 2.0); ++ mat4 t_shift2 = mat4(1.0, 0.0, 0.0, 0.0, // first column ++ 0.0, 1.0, 0.0, 0.0, // second column ++ 0.0, 0.0, 1.0, 0.0, // third column ++ shift_x, shift_y, 0.0, 1.0); // forth column ++ t *= t_shift2; ++ // Inverse Transformation Matrix ++ $output_data_0[0, 0, 0] = vec4(t[0][0], t[1][0], t[2][0], t[3][0])$; ++ $output_data_0[1, 0, 0] = vec4(t[0][1], t[1][1], t[2][1], t[3][1])$; ++ $output_data_0[2, 0, 0] = vec4(t[0][2], t[1][2], t[2][2], t[3][2])$; ++ $output_data_0[3, 0, 0] = vec4(t[0][3], t[1][3], t[2][3], t[3][3])$; ++ )"; ++ ++ *generated_code = { ++ /*parameters=*/params, ++ /*objects=*/{}, ++ /*shared_variables=*/{}, ++ /*workload=*/uint3(1, 1, 1), ++ /*workgroup=*/uint3(1, 1, 1), ++ /*source_code=*/std::move(source), ++ /*input=*/IOStructure::ONLY_DEFINITIONS, ++ /*output=*/IOStructure::ONLY_DEFINITIONS, ++ }; ++ return absl::OkStatus(); ++} ++ ++} // namespace v2 ++ ++class LandmarksToTransformMatrix : public NodeShader { ++ public: ++ absl::Status GenerateCode(const GenerationContext& ctx, ++ GeneratedCode* generated_code) const final { ++ auto* attr_v1 = ++ absl::any_cast(&ctx.op_attr); ++ if (attr_v1) return v1::GenerateCode(*attr_v1, ctx, generated_code); ++ ++ auto* attr_v2 = ++ absl::any_cast(&ctx.op_attr); ++ if (attr_v2) return v2::GenerateCode(*attr_v2, ctx, generated_code); ++ ++ return absl::InvalidArgumentError("Incorrect attributes' type."); ++ } ++}; ++ 
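// Editorial note, not part of the upstream patch: the dispatch above relies on
// the pointer overload of absl::any_cast, which returns nullptr instead of
// throwing when ctx.op_attr does not hold the requested attribute type. That is
// what lets GenerateCode probe for the V1 attributes first and quietly fall
// back to the V2 attributes before reporting an error.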
++} // namespace ++ ++std::unique_ptr NewLandmarksToTransformMatrixNodeShader() { ++ return absl::make_unique(); ++} ++ ++} // namespace gl ++} // namespace gpu ++} // namespace tflite +diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.cc.orig b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.cc.orig +new file mode 100644 +index 00000000000..3e884b643a5 +--- /dev/null ++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.cc.orig +@@ -0,0 +1,356 @@ ++#include "mediapipe/util/tflite/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.h" ++ ++#include ++#include ++#include ++#include ++#include ++ ++#include "third_party/absl/memory/memory.h" ++#include "third_party/absl/strings/substitute.h" ++#include "third_party/absl/types/any.h" ++#include "mediapipe/util/tflite/gpu/common/mediapipe/landmarks_to_transform_matrix.h" ++#include "third_party/tensorflow/lite/delegates/gpu/common/shape.h" ++#include "third_party/tensorflow/lite/delegates/gpu/common/status.h" ++#include "third_party/tensorflow/lite/delegates/gpu/common/types.h" ++#include "third_party/tensorflow/lite/delegates/gpu/common/util.h" ++ ++namespace tflite { ++namespace gpu { ++namespace gl { ++namespace { ++ ++namespace v1 { ++ ++std::string ReadLandmark(const std::string& landmark, const std::string& idx) { ++ std::string source = R"( ++ vec4 )" + landmark + ++ R"(; ++ { ++ int z_coord = )" + ++ idx + ++ R"( * $dimensions$ / 4; ++ vec4 result = $input_data_0[0, 0, z_coord]$; ++ int rest = )" + idx + ++ R"( * $dimensions$ % 4; ++ if (rest != 0) { ++ if (rest == 1) { ++ result.x = result.y; ++ result.y = result.z; ++ } ++ if (rest == 2) { ++ result.x = result.z; ++ result.y = result.w; ++ } ++ if (rest == 3) { ++ vec4 next_after_result = $input_data_0[0, 0, z_coord + 1]$; ++ result.x = result.w; ++ result.y = next_after_result.x; ++ } ++ } ++ )" + landmark + R"( = result; ++ } ++ )"; ++ return source; ++} ++ ++bool IsSupported(const LandmarksToTransformMatrixV1Attributes& attr) { ++ return attr.dimensions == 3; ++} ++ ++absl::Status GenerateCode(const LandmarksToTransformMatrixV1Attributes& attr, ++ const NodeShader::GenerationContext& ctx, ++ GeneratedCode* generated_code) { ++ if (!IsSupported(attr)) { ++ return absl::InvalidArgumentError( ++ "This case is not supported by LandmarksToTransformMatrix v1"); ++ } ++ ++ std::vector params = { ++ {"dimensions", static_cast(attr.dimensions)}, ++ {"landmarks_range", static_cast(attr.landmarks_range)}, ++ {"left_rotation_idx", static_cast(attr.left_rotation_idx)}, ++ {"right_rotation_idx", static_cast(attr.right_rotation_idx)}, ++ {"bbox_size_multiplier", static_cast(attr.bbox_size_multiplier)}, ++ {"input_h", static_cast(attr.input_hw.h)}, ++ {"input_w", static_cast(attr.input_hw.w)}, ++ {"output_h", static_cast(attr.output_hw.h)}, ++ {"output_w", static_cast(attr.output_hw.w)}, ++ {"subset", attr.subset}, ++ {"subset_size", static_cast(attr.subset.size())}, ++ }; ++ ++ std::string source = R"( ++ )" + ReadLandmark("left_landmark", "$left_rotation_idx$") + ++ R"( ++ ++ )" + ReadLandmark("right_landmark", "$right_rotation_idx$") + ++ R"( ++ ++ float alpha = -atan(right_landmark.y - left_landmark.y, ++ right_landmark.x - left_landmark.x); ++ ++ vec4 max_value = vec4(-100000, -100000, 0.0, 0.0); ++ vec4 min_value = vec4(100000, 100000, 0.0, 0.0); ++ for (int i = 0; i < $subset_size$; i++) { ++ for (int j = 0; j < 2; j++) { ++ )" + ReadLandmark("landmark_current", 
"$subset$[i][j]") + ++ R"( ++ ++ vec4 rotated = vec4(landmark_current.x * cos(alpha) - ++ landmark_current.y * sin(alpha), ++ landmark_current.x * sin(alpha) + ++ landmark_current.y * cos(alpha), ++ 0.0, 0.0); ++ // both by x and y ++ max_value = vec4(max(max_value.x, rotated.x), ++ max(max_value.y, rotated.y), ++ 0.0, 0.0); ++ min_value = vec4(min(min_value.x, rotated.x), ++ min(min_value.y, rotated.y), ++ 0.0, 0.0); ++ } ++ } ++ ++ vec4 bbox_size = max_value - min_value; ++ bbox_size *= $bbox_size_multiplier$; ++ ++ mat3 scale_matrix = ++ mat3(bbox_size.x / float($landmarks_range$), 0.0, 0.0, // first column ++ 0.0, bbox_size.y / float($landmarks_range$), 0.0, // second column ++ 0.0, 0.0, 1.0); // third column ++ ++ vec4 middle = (max_value + min_value) / 2.0; ++ ++ vec4 rotated_middle = ++ vec4(middle.x * cos(-alpha) - middle.y * sin(-alpha), ++ middle.x * sin(-alpha) + middle.y * cos(-alpha), 0.0, 0.0); ++ ++ mat3 rotation_matrix = ++ mat3(cos(-alpha), sin(-alpha), 0, // first column ++ -sin(-alpha), cos(-alpha), 0, // second column ++ // third column ++ (rotated_middle.x / float($landmarks_range$)) * 2.0 - 1.0, ++ (rotated_middle.y / float($landmarks_range$)) * 2.0 - 1.0, 1); ++ ++ mat3 to_relative = ++ mat3(2.0 / (float($output_w$) - 1.0), 0.0, 0.0, // first column ++ 0.0, 2.0 / (float($output_h$) - 1.0), 0.0, // second column ++ -1.0, -1.0, 1.0); // third column ++ ++ mat3 to_absolute = ++ mat3((float($input_w$) - 1.0) / 2.0, 0.0, 0.0, // first column ++ 0.0, (float($input_h$) - 1.0) / 2.0, 0.0, // second column ++ // third column ++ (float($input_w$) - 1.0) / 2.0, (float($input_h$) - 1.0)/2.0, 1.0); ++ ++ // Transformstion Matrix ++ mat3 tm = to_absolute * rotation_matrix * scale_matrix * to_relative; ++ ++ // Inverse Transformation Matrix ++ $output_data_0[0, 0, 0] = vec4(tm[0][0], tm[1][0], 0.0, tm[2][0])$; ++ $output_data_0[1, 0, 0] = vec4(tm[0][1], tm[1][1], 0.0, tm[2][1])$; ++ $output_data_0[2, 0, 0] = vec4(tm[0][2], tm[1][2], tm[2][2], 0.0)$; ++ $output_data_0[3, 0, 0] = vec4( 0, 0, 0, 1.0)$; ++ )"; ++ ++ *generated_code = { ++ /*parameters=*/params, ++ /*objects=*/{}, ++ /*shared_variables=*/{}, ++ /*workload=*/uint3(1, 1, 1), ++ /*workgroup=*/uint3(1, 1, 1), ++ /*source_code=*/std::move(source), ++ /*input=*/IOStructure::ONLY_DEFINITIONS, ++ /*output=*/IOStructure::ONLY_DEFINITIONS, ++ }; ++ return absl::OkStatus(); ++} ++ ++} // namespace v1 ++ ++namespace v2 { ++ ++std::string ReadLandmark(const std::string& landmark, const std::string& idx) { ++ std::string source = R"( ++ vec4 )" + landmark + ++ R"(; ++ { ++ int z_coord = )" + ++ idx + ++ R"( * $dimensions$ / 4; ++ vec4 result = $input_data_0[0, 0, z_coord]$; ++ int rest = )" + idx + ++ R"( * $dimensions$ % 4; ++ if (rest != 0) { ++ if (rest == 1) { ++ result.x = result.y; ++ result.y = result.z; ++ } ++ if (rest == 2) { ++ result.x = result.z; ++ result.y = result.w; ++ } ++ if (rest == 3) { ++ vec4 next_after_result = $input_data_0[0, 0, z_coord + 1]$; ++ result.x = result.w; ++ result.y = next_after_result.x; ++ } ++ } ++ result *= $multiplier$; ++ )" + landmark + R"( = result; ++ } )"; ++ return source; ++} ++ ++static bool IsSupported(const NodeShader::GenerationContext& ctx) { ++ return ctx.input_shapes.size() == 1 && ctx.input_shapes[0][1] == 1 && ++ ctx.input_shapes[0][2] == 1 && ctx.input_shapes[0][3] % 3 == 0; ++} ++ ++absl::Status GenerateCode(const LandmarksToTransformMatrixV2Attributes& attr, ++ const NodeShader::GenerationContext& ctx, ++ GeneratedCode* generated_code) { ++ if (!IsSupported(ctx)) { ++ 
++    return absl::InvalidArgumentError(
++        "This case is not supported by LandmarksToTransformMatrixV2");
++  }
++
++  std::vector<Variable> params = {
++      {"dimensions", static_cast<int>(3)},
++      {"scale_x", static_cast<float>(attr.scale_x)},
++      {"scale_y", static_cast<float>(attr.scale_y)},
++      {"left_rotation_idx", static_cast<int>(attr.left_rotation_idx)},
++      {"right_rotation_idx", static_cast<int>(attr.right_rotation_idx)},
++      {"target_rotation_radians",
++       static_cast<float>(attr.target_rotation_radians)},
++      {"output_width", static_cast<float>(attr.output_width)},
++      {"output_height", static_cast<float>(attr.output_height)},
++      {"subset_idxs", attr.subset_idxs},
++      {"subset_idxs_size", static_cast<int>(attr.subset_idxs.size())},
++      {"multiplier", static_cast<float>(attr.multiplier)},
++  };
++
++  std::string source = R"(
++    )" + ReadLandmark("left_landmark", "$left_rotation_idx$") +
++                       R"(
++    )" + ReadLandmark("right_landmark", "$right_rotation_idx$") +
++                       R"(
++
++      float diff_y = right_landmark.y - left_landmark.y;
++      float diff_x = right_landmark.x - left_landmark.x;
++      float rotation = 0.0;
++      if (diff_y != 0.0 && diff_x != 0.0) rotation = atan(diff_y, diff_x);
++      float r = $target_rotation_radians$ - rotation;
++
++      vec4 max_value = vec4(-100000, -100000, 0.0, 0.0);
++      vec4 min_value = vec4(100000, 100000, 0.0, 0.0);
++      for (int i = 0; i < $subset_idxs_size$; i++) {
++        for (int j = 0; j < 2; j++) {
++          )" + ReadLandmark("landmark_current", "$subset_idxs$[i][j]") +
++                       R"(
++          vec4 rotated = vec4(landmark_current.x * cos(r) -
++                              landmark_current.y * sin(r),
++                              landmark_current.x * sin(r) +
++                              landmark_current.y * cos(r),
++                              0.0, 0.0);
++          // both by x and y
++          max_value = vec4(max(max_value.x, rotated.x),
++                           max(max_value.y, rotated.y),
++                           0.0, 0.0);
++          min_value = vec4(min(min_value.x, rotated.x),
++                           min(min_value.y, rotated.y),
++                           0.0, 0.0);
++        }
++      }
++
++      float crop_width = max_value.x - min_value.x;
++      float crop_height = max_value.y - min_value.y;
++
++      vec4 crop_xy1 = (max_value + min_value) / vec4(2.0);
++
++      float crop_x = cos(-r) * crop_xy1.x - sin(-r) * crop_xy1.y;
++      float crop_y = sin(-r) * crop_xy1.x + cos(-r) * crop_xy1.y;
++
++
++      mat4 t = mat4(1.0, 0.0, 0.0, 0.0,   // first column
++                    0.0, 1.0, 0.0, 0.0,   // second column
++                    0.0, 0.0, 1.0, 0.0,   // third column
++                    0.0, 0.0, 0.0, 1.0);  // fourth column
++
++      mat4 t_shift = mat4(1.0, 0.0, 0.0, 0.0,         // first column
++                          0.0, 1.0, 0.0, 0.0,         // second column
++                          0.0, 0.0, 1.0, 0.0,         // third column
++                          crop_x, crop_y, 0.0, 1.0);  // fourth column
++      t *= t_shift;
++
++      r = -r;
++
++      mat4 t_rotation = mat4(cos(r), sin(r), 0.0, 0.0,   // first column
++                             -sin(r), cos(r), 0.0, 0.0,  // second column
++                             0.0, 0.0, 1.0, 0.0,         // third column
++                             0.0, 0.0, 0.0, 1.0);        // fourth column
++
++      t *= t_rotation;
++      // cropped scale for x and y
++      float cs_x = $scale_x$ * crop_width / $output_width$;
++      float cs_y = $scale_y$ * crop_height / $output_height$;
++      mat4 t_scale = mat4(cs_x, 0.0, 0.0, 0.0,  // first column
++                          0.0, cs_y, 0.0, 0.0,  // second column
++                          0.0, 0.0, 1.0, 0.0,   // third column
++                          0.0, 0.0, 0.0, 1.0);  // fourth column
++      t *= t_scale;
++      float shift_x = -1.0 * ($output_width$ / 2.0);
++      float shift_y = -1.0 * ($output_height$ / 2.0);
++      mat4 t_shift2 = mat4(1.0, 0.0, 0.0, 0.0,             // first column
++                           0.0, 1.0, 0.0, 0.0,             // second column
++                           0.0, 0.0, 1.0, 0.0,             // third column
++                           shift_x, shift_y, 0.0, 1.0);    // fourth column
++      t *= t_shift2;
++      // Inverse Transformation Matrix
++      $output_data_0[0, 0, 0] = vec4(t[0][0], t[1][0], t[2][0], t[3][0])$;
++      $output_data_0[1, 0, 0] = vec4(t[0][1], t[1][1], t[2][1], t[3][1])$;
++      $output_data_0[2, 0, 0] = vec4(t[0][2], t[1][2], t[2][2], t[3][2])$;
++      $output_data_0[3, 0, 0] = vec4(t[0][3], t[1][3], t[2][3], t[3][3])$;
++  )";
++
++  *generated_code = {
++      /*parameters=*/params,
++      /*objects=*/{},
++      /*shared_variables=*/{},
++      /*workload=*/uint3(1, 1, 1),
++      /*workgroup=*/uint3(1, 1, 1),
++      /*source_code=*/std::move(source),
++      /*input=*/IOStructure::ONLY_DEFINITIONS,
++      /*output=*/IOStructure::ONLY_DEFINITIONS,
++  };
++  return absl::OkStatus();
++}
++
++}  // namespace v2
++
++class LandmarksToTransformMatrix : public NodeShader {
++ public:
++  absl::Status GenerateCode(const GenerationContext& ctx,
++                            GeneratedCode* generated_code) const final {
++    auto* attr_v1 =
++        absl::any_cast<LandmarksToTransformMatrixV1Attributes>(&ctx.op_attr);
++    if (attr_v1) return v1::GenerateCode(*attr_v1, ctx, generated_code);
++
++    auto* attr_v2 =
++        absl::any_cast<LandmarksToTransformMatrixV2Attributes>(&ctx.op_attr);
++    if (attr_v2) return v2::GenerateCode(*attr_v2, ctx, generated_code);
++
++    return absl::InvalidArgumentError("Incorrect attributes' type.");
++  }
++};
++
++}  // namespace
++
++std::unique_ptr<NodeShader> NewLandmarksToTransformMatrixNodeShader() {
++  return absl::make_unique<LandmarksToTransformMatrix>();
++}
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
+diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.h b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.h
+new file mode 100644
+index 00000000000..d3949050578
+--- /dev/null
++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.h
+@@ -0,0 +1,19 @@
++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_LANDMARKS_TO_TRANSFORM_MATRIX_H_
++#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_LANDMARKS_TO_TRANSFORM_MATRIX_H_
++
++#include <memory>
++
++#include "tensorflow/lite/delegates/gpu/common/operations.h"
++#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
++
++namespace tflite {
++namespace gpu {
++namespace gl {
++
++std::unique_ptr<NodeShader> NewLandmarksToTransformMatrixNodeShader();
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
++
++#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_LANDMARKS_TO_TRANSFORM_MATRIX_H_
+diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/registry.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/registry.cc
+new file mode 100644
+index 00000000000..3ef02a248c3
+--- /dev/null
++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/registry.cc
+@@ -0,0 +1,28 @@
++#include <memory>
++#include <string>
++#include <vector>
++
++#include "absl/container/flat_hash_map.h"
++#include "tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h"
++#include "tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/landmarks_to_transform_matrix.h"
++#include "tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.h"
++#include "tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.h"
++
++namespace tflite {
++namespace gpu {
++namespace gl {
++
++void RegisterCustomOps(
++    absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>*
++        shaders) {
++  (*shaders)["landmarks_to_transform_matrix"].push_back(
++      NewLandmarksToTransformMatrixNodeShader());
++  (*shaders)["transform_landmarks"].push_back(
++      NewTransformLandmarksNodeShader());
++  (*shaders)["transform_tensor_bilinear"].push_back(
++      NewTransformTensorBilinearNodeShader());
++}
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
+diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.cc
+new file mode 100644
+index 00000000000..980e2aa99e6
+--- /dev/null
++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.cc
+@@ -0,0 +1,123 @@
++#include "tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.h"
++
++#include <algorithm>
++#include <memory>
++#include <string>
++#include <utility>
++#include <vector>
++
++#include "absl/memory/memory.h"
++#include "absl/strings/substitute.h"
++#include "absl/types/any.h"
++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h"
++#include "tensorflow/lite/delegates/gpu/common/shape.h"
++#include "tensorflow/lite/delegates/gpu/common/status.h"
++#include "tensorflow/lite/delegates/gpu/common/types.h"
++#include "tensorflow/lite/delegates/gpu/common/util.h"
++
++namespace tflite {
++namespace gpu {
++namespace gl {
++namespace {
++
++class TransformLandmarks : public NodeShader {
++ public:
++  absl::Status GenerateCode(const GenerationContext& ctx,
++                            GeneratedCode* generated_code) const final {
++    if (!IsSupported(ctx)) {
++      return absl::InvalidArgumentError(
++          "This case is not supported by TransformLandmarks");
++    }
++
++    const auto& attr =
++        absl::any_cast<const TransformLandmarksAttributes&>(ctx.op_attr);
++
++    // For TransformLandmarks v2 the scale parameter is set to 1 when the
++    // operation is parsed.
++    std::vector<Variable> params;
++    if (attr.scale != 1) {
++      params.push_back({"scale", static_cast<float>(attr.scale)});
++    }
++    std::string source = R"(
++        vec4 x_transform = $input_data_1[0, 0, 0]$;
++        vec4 y_transform = $input_data_1[1, 0, 0]$; )";
++    if (attr.scale != 1) {
++      source += R"(
++        x_transform.w *= $scale$;
++        y_transform.w *= $scale$;
++      )";
++    }
++    source += R"(
++        vec4 landmks = $input_data_0[gid.x, gid.y, gid.z]$;
++        vec4 transformed = vec4(0.0);
++    )";
++    switch (attr.dimensions) {
++      case 2:
++        source += R"(
++        // x y x y
++        vec4 l_pair1_ = vec4(landmks.x, landmks.y, 0.0, 1.0);
++        vec4 l_pair2_ = vec4(landmks.z, landmks.w, 0.0, 1.0);
++        transformed = vec4(dot(x_transform, l_pair1_), dot(y_transform, l_pair1_),
++                           dot(x_transform, l_pair2_), dot(y_transform, l_pair2_));
++
++        value_0 = transformed;
++        )";
++        break;
++      case 3:
++        source += R"(
++        if ((gid.z * 4) % 3 == 0) {  // 0, 3, 6
++          // x y z x
++          vec4 landmks_next = $input_data_0[gid.x, gid.y, gid.z + 1]$;
++          vec4 l_= landmks;
++          l_.z = 0.0;
++          l_.w = 1.0;
++          transformed = vec4(dot(x_transform, l_),
++                             dot(y_transform, l_),
++                             landmks.z, dot(x_transform, vec4(landmks.w, landmks_next.x, 0.0, 1.0)));
++        } else if ((gid.z * 4) % 3 == 1) {  // 1, 4, 7
++          // y z x y
++          vec4 landmks_prev = $input_data_0[gid.x, gid.y, gid.z - 1]$;
++          vec4 l_ = vec4(landmks.z, landmks.w, 0.0, 1.0);
++          transformed = vec4(dot(y_transform, vec4(landmks_prev.w, landmks.x, 0.0, 1.0)), landmks.y,
++                             dot(x_transform, l_), dot(y_transform, l_));
++        } else if ((gid.z * 4) % 3 == 2) {  // 2, 5, 8
++          // z, x, y, z
++          vec4 l_ = vec4(landmks.y, landmks.z, 0.0, 1.0);
++          transformed = vec4(landmks.x, dot(x_transform, l_),
++                             dot(y_transform, l_), landmks.w);
++        }
++        value_0 = transformed;
++        )";
++        break;
++    }
++
++    *generated_code = {
++        /*parameters=*/params,
++        /*objects=*/{},
++        /*shared_variables=*/{},
++        /*workload=*/uint3(),
++        /*workgroup=*/uint3(),
++        /*source_code=*/std::move(source),
++        /*input=*/IOStructure::ONLY_DEFINITIONS,
++        /*output=*/IOStructure::AUTO,
++    };
++    return absl::OkStatus();
++  }
++
++ private:
++  static bool IsSupported(const GenerationContext& ctx) {
++    const auto& attr =
++        absl::any_cast<const TransformLandmarksAttributes&>(ctx.op_attr);
++    return (attr.dimensions == 2 || attr.dimensions == 3) && attr.version == 1;
++  }
++};
++
++}  // namespace
++
++std::unique_ptr<NodeShader> NewTransformLandmarksNodeShader() {
++  return absl::make_unique<TransformLandmarks>();
++}
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
+diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.h b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.h
+new file mode 100644
+index 00000000000..cfb656675e4
+--- /dev/null
++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_landmarks.h
+@@ -0,0 +1,19 @@
++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_TRANSFORM_LANDMARKS_H_
++#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_TRANSFORM_LANDMARKS_H_
++
++#include <memory>
++
++#include "tensorflow/lite/delegates/gpu/common/operations.h"
++#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
++
++namespace tflite {
++namespace gpu {
++namespace gl {
++
++std::unique_ptr<NodeShader> NewTransformLandmarksNodeShader();
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
++
++#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_TRANSFORM_LANDMARKS_H_
+diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.cc b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.cc
+new file mode 100644
+index 00000000000..8013b9b3505
+--- /dev/null
++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.cc
+@@ -0,0 +1,169 @@
++#include "tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.h"
++
++#include <algorithm>
++#include <memory>
++#include <string>
++#include <utility>
++#include <vector>
++
++#include "absl/memory/memory.h"
++#include "absl/strings/substitute.h"
++#include "absl/types/any.h"
++#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_tensor_bilinear.h"
++#include "tensorflow/lite/delegates/gpu/common/shape.h"
++#include "tensorflow/lite/delegates/gpu/common/status.h"
++#include "tensorflow/lite/delegates/gpu/common/types.h"
++#include "tensorflow/lite/delegates/gpu/common/util.h"
++
++namespace tflite {
++namespace gpu {
++namespace gl {
++namespace {
++
++class TransformTensorBilinear : public NodeShader {
++ public:
++  absl::Status GenerateCode(const GenerationContext& ctx,
++                            GeneratedCode* generated_code) const final {
++    if (!IsSupported(ctx)) {
++      return absl::InvalidArgumentError(
++          "This case is not supported by TransformTensorBilinear.");
++    }
++
++    std::vector<Variable> params = {
++        {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
++        {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])}};
++
++    // Only bilinear transformation is supported right now.
++    std::string source = R"(
++        vec4 first_line = $input_data_1[0, 0, 0]$;
++        vec4 second_line = $input_data_1[1, 0, 0]$;
++      )" + AlignCornersCorrection(ctx) +
++                         R"(
++        vec4 before_transform_coord_2d = vec4(gid.x, gid.y, 0.0, 1.0);
++
++        // Get transformed coordinates
++        vec2 xy = vec2(dot(first_line, before_transform_coord_2d),
++                       dot(second_line, before_transform_coord_2d));
++
++        // Get coordinates of corners to interpolate from.
++        int x1 = int(floor(xy.x));  // x2 is x1 + 1
++        int y1 = int(floor(xy.y));  // y2 is y1 + 1
++
++        // Apply interpolation if coordinate is in bounds.
++        vec4 result = vec4(0.0);
++
++        if(xy.x >= 0.0 && xy.x <= float($input_data_0_w$ -1) &&
++           xy.y >= 0.0 && xy.y <= float($input_data_0_h$ -1)) {
++
++          // Corners position:
++          // q_11 --- q_21
++          // ----     ----
++          // q_12 --- q_22
++)";
++    source += SampleFromInput0("q_11", "x1", "y1") +
++              SampleFromInput0("q_12", "x1", "y1 + 1") +
++              SampleFromInput0("q_21", "x1 + 1", "y1") +
++              SampleFromInput0("q_22", "x1 + 1", "y1 + 1") + R"(
++
++          float right_contrib = xy.x - float(x1);
++          float lower_contrib = xy.y - float(y1);
++
++          vec4 upper = (1.0 - right_contrib) * q_11 + right_contrib * q_21;
++          vec4 lower = (1.0 - right_contrib) * q_12 + right_contrib * q_22;
++
++          result = lower_contrib * lower + (1.0 - lower_contrib) * upper;
++
++        }
++        value_0 = result;
++      )";
++
++    *generated_code = {
++        /*parameters=*/params,
++        /*objects=*/{},
++        /*shared_variables=*/{},
++        /*workload=*/uint3(),
++        /*workgroup=*/uint3(),
++        /*source_code=*/std::move(source),
++        /*input=*/IOStructure::ONLY_DEFINITIONS,
++        /*output=*/IOStructure::AUTO,
++    };
++    return absl::OkStatus();
++  }
++
++ private:
++  std::string SampleFromInput0(absl::string_view variable,
++                               absl::string_view x_coord,
++                               absl::string_view y_coord) const {
++    // This function generates code, which samples data from the first input
++    // tensor and checks the coordinates' bounds:
++    //
++    //   vec4 q = vec4(0.0);
++    //   [0, H)
++    //   if (x >= 0 && x < $input_data_0_w$ && y >= 0 && y < $input_data_0_h$) {
++    //     q = $input_data_0[x, y, gid.z]$;
++    //   }
++
++    // Create zero initialized variable on stack
++    std::string result =
++        absl::Substitute("          vec4 $0 = vec4(0.0);\n", variable);
++    // If coordinates are not out of bounds, load the value from input_data_0.
++    absl::SubstituteAndAppend(
++        &result,
++        "          if ($0 >= 0 && $1 < $$input_data_0_w$$ && "
++        "$2 >= 0 && $3 < $$input_data_0_h$$) {\n",
++        x_coord, x_coord, y_coord, y_coord);
++    absl::SubstituteAndAppend(
++        &result,
++        "            $0 = $$input_data_0[$1, $2, gid.z]$$;\n          }\n\n",
++        variable, x_coord, y_coord);
++    return result;
++  }
++
++  std::string AlignCornersCorrection(const GenerationContext& ctx) const {
++    const auto& attr =
++        absl::any_cast<const TransformTensorBilinearAttributes&>(ctx.op_attr);
++    // Align corners correction: T -> S * ( T * A ), where T is a
++    // transformation matrix, and subtraction and addition matrices are:
++    //        S               A
++    //    1 0 0 -0.5      1 0 0 0.5
++    //    0 1 0 -0.5      0 1 0 0.5
++    //    0 0 1  0        0 0 1 0
++    //    0 0 0  1        0 0 0 1
++    // Transformation matrix column 3 and rows 3, 4 are identity, which makes
++    // the final formula pretty simple and easy to get if doing a manual
++    // multiplication.
++    if (attr.align_corners) {
++      return R"(
++        first_line.w += first_line.x * 0.5 + first_line.y * 0.5 - 0.5;
++        second_line.w += second_line.x * 0.5 + second_line.y * 0.5 - 0.5;
++      )";
++    } else {
++      return "";
++    }
++  }
++
++  static bool IsSupported(const GenerationContext& ctx) {
++    // if version 2 - align corners is turned on.
++    // both versions expect transformation matrix as 1x1x1x16
++    if (ctx.input_shapes.size() != 2) return false;
++
++    if (ctx.input_shapes[1][0] != 1 || ctx.input_shapes[1][1] != 1 ||
++        ctx.input_shapes[1][2] != 4 || ctx.input_shapes[1][3] != 4)
++      return false;
++
++    const auto& attr =
++        absl::any_cast<const TransformTensorBilinearAttributes&>(ctx.op_attr);
++    return attr.output_size.h > 0 && attr.output_size.w > 0 &&
++           attr.version == 1;
++  }
++};
++
++}  // namespace
++
++std::unique_ptr<NodeShader> NewTransformTensorBilinearNodeShader() {
++  return absl::make_unique<TransformTensorBilinear>();
++}
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
+diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.h b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.h
+new file mode 100644
+index 00000000000..c62387a4b96
+--- /dev/null
++++ b/tensorflow/lite/delegates/gpu/gl/kernels/mediapipe/transform_tensor_bilinear.h
+@@ -0,0 +1,19 @@
++#ifndef TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_TRANSFORM_TENSOR_BILINEAR_H_
++#define TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_TRANSFORM_TENSOR_BILINEAR_H_
++
++#include <memory>
++
++#include "tensorflow/lite/delegates/gpu/common/operations.h"
++#include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
++
++namespace tflite {
++namespace gpu {
++namespace gl {
++
++std::unique_ptr<NodeShader> NewTransformTensorBilinearNodeShader();
++
++}  // namespace gl
++}  // namespace gpu
++}  // namespace tflite
++
++#endif  // TENSORFLOW_LITE_DELEGATES_GPU_GL_KERNELS_MEDIAPIPE_TRANSFORM_TENSOR_BILINEAR_H_