From ed1275b673a1b534874bfe1a974ba8b8e5ae96b0 Mon Sep 17 00:00:00 2001 From: Maksym Walczak Date: Mon, 3 Jan 2022 15:42:27 +0100 Subject: [PATCH] Add pose tracking subproject --- mediapipe/pose_tracking_dll/BUILD | 59 ++++++ mediapipe/pose_tracking_dll/README.md | 30 +++ mediapipe/pose_tracking_dll/pose_tracking.cpp | 179 ++++++++++++++++++ mediapipe/pose_tracking_dll/pose_tracking.h | 99 ++++++++++ .../pose_tracking_dll/windows_dll_library.bzl | 62 ++++++ 5 files changed, 429 insertions(+) create mode 100644 mediapipe/pose_tracking_dll/BUILD create mode 100644 mediapipe/pose_tracking_dll/README.md create mode 100644 mediapipe/pose_tracking_dll/pose_tracking.cpp create mode 100644 mediapipe/pose_tracking_dll/pose_tracking.h create mode 100644 mediapipe/pose_tracking_dll/windows_dll_library.bzl diff --git a/mediapipe/pose_tracking_dll/BUILD b/mediapipe/pose_tracking_dll/BUILD new file mode 100644 index 000000000..98b5f9dc9 --- /dev/null +++ b/mediapipe/pose_tracking_dll/BUILD @@ -0,0 +1,59 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("windows_dll_library.bzl", "windows_dll_library") +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/examples:__subpackages__"]) + +# Define the shared library +windows_dll_library( + name = "pose_tracking_lib", + srcs = ["pose_tracking.cpp"], + hdrs = ["pose_tracking.h"], + # Define COMPILING_DLL to export symbols during compiling the DLL. + copts = ["-DCOMPILING_DLL"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_video", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + + "//mediapipe/calculators/core:constant_side_packet_calculator", + "//mediapipe/calculators/core:packet_presence_calculator", + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/calculators/tflite:tflite_model_calculator", + "//mediapipe/calculators/util:local_file_contents_calculator", + "//mediapipe/graphs/pose_tracking:pose_tracking_cpu_deps", + ] +) + +# **Implicitly link to face_mesh_lib.dll** +cc_binary( + name = "pose_tracking_cpu", + deps = [ + "//mediapipe/examples/desktop:demo_run_graph_main", + "//mediapipe/graphs/pose_tracking:pose_tracking_cpu_deps", + ":pose_tracking_lib" + ], +) diff --git a/mediapipe/pose_tracking_dll/README.md b/mediapipe/pose_tracking_dll/README.md new file mode 100644 index 000000000..a183c4f2b --- /dev/null +++ b/mediapipe/pose_tracking_dll/README.md @@ -0,0 +1,30 @@ +## Description +The pose_tracking_dll module allows for building a dll library that can be used with any C++ project. All the dependencies such as tensorflow are built statically into the dll. + +Currently, the following features are supported: +- Segmenting the person(s) of interest +- Segmenting the skeleton(s) +- Accessing the 3D coordinates of each node of the skeleton by name (using enum) + +## Prerequisites +Follow the guidelines on the official Mediapipe website: https://google.github.io/mediapipe/getting_started/install.html#installing-on-windows + +IMPORTANT: The tutorial does not specify which version of Bazel to install. Install Bazel version 3.7.2 + +## How to build +Assuming you're in the root of the repository: + +cd mediapipe + +bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 pose_tracking_dll:pose_tracking_cpu + +The results will be stored in bazel-bin\mediapipe\pose_tracking_dll folder. + +## How to use +Go to bazel-bin\mediapipe\pose_tracking_dll + +Link pose_tracking_cpu.lib and pose_tracking_lib.dll.if.lib statically in your project. + +Make sure the opencv_world3410.dll and pose_tracking_lib.dll are accessible in your working directory. + +Use mediapipe\pose_tracking_dll\pose_tracking.h header file to access the methods of the library. diff --git a/mediapipe/pose_tracking_dll/pose_tracking.cpp b/mediapipe/pose_tracking_dll/pose_tracking.cpp new file mode 100644 index 000000000..65f9f619f --- /dev/null +++ b/mediapipe/pose_tracking_dll/pose_tracking.cpp @@ -0,0 +1,179 @@ +#include +#include + +#include "pose_tracking.h" + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_video_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" + +class PoseTrackingImpl { +public: + PoseTrackingImpl(const std::string& calculatorGraphConfigFile) { + auto status = initialize(calculatorGraphConfigFile); + if (!status.ok()) { + LOG(WARNING) << "Warning: " << status; + } + } + + absl::Status initialize(const std::string& calculatorGraphConfigFile) { + std::string graphContents; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents( + calculatorGraphConfigFile, + &graphContents)); + + mediapipe::CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + graphContents); + + MP_RETURN_IF_ERROR(graph.Initialize(config)); + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, + graph.AddOutputStreamPoller(kOutputSegmentationStream)); + + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller landmarksPoller, + graph.AddOutputStreamPoller(kOutpuLandmarksStream)); + + ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller posePresencePoller, + graph.AddOutputStreamPoller(kOutpuPosePresenceStream)); + + + maskPollerPtr = std::make_unique(std::move(poller)); + + landmarksPollerPtr = std::make_unique( + std::move(landmarksPoller)); + + posePresencePollerPtr = std::make_unique( + std::move(posePresencePoller)); + + MP_RETURN_IF_ERROR(graph.StartRun({})); + } + + bool processFrame(const cv::Mat& inputRGB8Bit) { + // Wrap Mat into an ImageFrame. + auto inputFrame = absl::make_unique( + mediapipe::ImageFormat::SRGB, inputRGB8Bit.cols, inputRGB8Bit.rows, + mediapipe::ImageFrame::kDefaultAlignmentBoundary); + cv::Mat inputFrameMat = mediapipe::formats::MatView(inputFrame.get()); + inputRGB8Bit.copyTo(inputFrameMat); + + // Send image packet into the graph. + size_t frameTimestampUs = + (double)cv::getTickCount() / (double)cv::getTickFrequency() * 1e6; + auto status = graph.AddPacketToInputStream( + kInputStream, mediapipe::Adopt(inputFrame.release()) + .At(mediapipe::Timestamp(frameTimestampUs))); + + if (!status.ok()) { + LOG(WARNING) << "Graph execution failed: " << status; + return false; + } + + mediapipe::Packet posePresencePacket; + if (!posePresencePollerPtr || !posePresencePollerPtr->Next(&posePresencePacket)) return false; + auto landmarksDetected = posePresencePacket.Get(); + + if (!landmarksDetected) { + return false; + } + + // Get the graph result packet, or stop if that fails. + mediapipe::Packet maskPacket; + if (!maskPollerPtr || !maskPollerPtr->Next(&maskPacket)) return false; + auto& outputFrame = maskPacket.Get(); + + // Get pose landmarks. + if (!landmarksPollerPtr || + !landmarksPollerPtr->Next(&poseLandmarksPacket)) { + return false; + } + + // Convert back to opencv for display or saving. + auto mask = mediapipe::formats::MatView(&outputFrame); + segmentedMask = mask.clone(); + + absl::Status landmarksStatus = detectLandmarksWithStatus(poseLandmarks); + + return landmarksStatus.ok(); + } + + absl::Status detectLandmarksWithStatus( + nimagna::cv_wrapper::Point3f* poseLandmarks) { + + if (poseLandmarksPacket.IsEmpty()) { + return absl::CancelledError("Pose landmarks packet is empty."); + } + + auto retrievedLandmarks = + poseLandmarksPacket + .Get<::mediapipe::NormalizedLandmarkList>(); + + // Convert landmarks to cv::Point3f**. + const auto landmarksCount = retrievedLandmarks.landmark_size(); + + for (int j = 0; j < landmarksCount; ++j) { + const auto& landmark = retrievedLandmarks.landmark(j); + poseLandmarks[j].x = landmark.x(); + poseLandmarks[j].y = landmark.y(); + poseLandmarks[j].z = landmark.z(); + } + + return absl::OkStatus(); + } + + nimagna::cv_wrapper::Point3f* lastDetectedLandmarks() { + return poseLandmarks; + } + + cv::Mat lastSegmentedFrame() { + return segmentedMask; + } + + static constexpr size_t kLandmarksCount = 33u; + +private: + mediapipe::Packet poseLandmarksPacket; + cv::Mat segmentedMask; + nimagna::cv_wrapper::Point3f poseLandmarks[kLandmarksCount]; + std::unique_ptr posePresencePollerPtr; + std::unique_ptr maskPollerPtr; + std::unique_ptr landmarksPollerPtr; + mediapipe::CalculatorGraph graph; + const char* kInputStream = "input_video"; + const char* kOutputSegmentationStream = "segmentation_mask"; + const char* kOutpuLandmarksStream = "pose_landmarks"; + const char* kOutpuPosePresenceStream = "pose_presence"; +}; + +namespace nimagna { + PoseTracking::PoseTracking(const char* calculatorGraphConfigFile) { + myInstance = new PoseTrackingImpl(calculatorGraphConfigFile); + } + + bool PoseTracking::processFrame(const cv_wrapper::Mat& inputRGB8Bit) { + auto* instance = static_cast(myInstance); + const auto frame = cv::Mat(inputRGB8Bit.rows, inputRGB8Bit.cols, CV_8UC3, inputRGB8Bit.data); + return instance->processFrame(frame); + } + + cv_wrapper::Point3f* PoseTracking::lastDetectedLandmarks() { + auto* instance = static_cast(myInstance); + return instance->lastDetectedLandmarks(); + } + + cv_wrapper::Mat PoseTracking::lastSegmentedFrame() { + auto* instance = static_cast(myInstance); + const cv::Mat result = instance->lastSegmentedFrame(); + + return cv_wrapper::Mat(result.rows, result.cols, result.data); + } + +} diff --git a/mediapipe/pose_tracking_dll/pose_tracking.h b/mediapipe/pose_tracking_dll/pose_tracking.h new file mode 100644 index 000000000..34161506c --- /dev/null +++ b/mediapipe/pose_tracking_dll/pose_tracking.h @@ -0,0 +1,99 @@ +#ifndef POSE_TRACKING_LIBRARY_H +#define POSE_TRACKING_LIBRARY_H + +#ifdef COMPILING_DLL +#define DLLEXPORT __declspec(dllexport) +#else +#define DLLEXPORT __declspec(dllimport) +#endif + +namespace nimagna { + namespace cv_wrapper { + struct Point2f { + float x = 0; + float y = 0; + + Point2f() = default; + Point2f(float x, float y) : x(x), y(y) {} + }; + struct Point3f { + float x = 0; + float y = 0; + float z = 0; + + Point3f() = default; + Point3f(float x, float y, float z) : x(x), y(y), z(z) {} + }; + + struct Rect { + int x = 0; + int y = 0; + int width = 0; + int height = 0; + + Rect() = default; + Rect(int x, int y, int width, int height) : x(x), y(y), width(width), height(height) {} + }; + + struct Mat { + int rows = 0; + int cols = 0; + unsigned char* data = 0; + + Mat(int rows, int cols, unsigned char* data) : rows(rows), cols(cols), data(data) {} + }; + } + + class DLLEXPORT PoseTracking { + public: + static constexpr size_t landmarksCount = 33u; + enum LandmarkNames { + NOSE = 0, + LEFT_EYE_INNER, + LEFT_EYE, + LEFT_EYE_OUTER, + RIGHT_EYE_INNER, + RIGHT_EYE, + RIGHT_EYE_OUTER, + LEFT_EAR, + RIGHT_EAR, + MOUTH_LEFT, + MOUTH_RIGHT, + LEFT_SHOULDER, + RIGHT_SHOULDER, + LEFT_ELBOW, + RIGHT_ELBOW, + LEFT_WRIST, + RIGHT_WRIST, + LEFT_PINKY, + RIGHT_PINKY, + LEFT_INDEX, + RIGHT_INDEX, + LEFT_THUMB, + RIGHT_THUMB, + LEFT_HIP, + RIGHT_HIP, + LEFT_KNEE, + RIGHT_KNEE, + LEFT_ANKLE, + RIGHT_ANKLE, + LEFT_HEEL, + RIGHT_HEEL, + LEFT_FOOT_INDEX, + RIGHT_FOOT_INDEX, + COUNT = landmarksCount + }; + + PoseTracking(const char* calculatorGraphConfigFile); + ~PoseTracking() { delete myInstance; } + + bool processFrame(const cv_wrapper::Mat& inputRGB8Bit); + cv_wrapper::Mat lastSegmentedFrame(); + cv_wrapper::Point3f* lastDetectedLandmarks(); + + private: + void* myInstance; + }; +} + +#endif \ No newline at end of file diff --git a/mediapipe/pose_tracking_dll/windows_dll_library.bzl b/mediapipe/pose_tracking_dll/windows_dll_library.bzl new file mode 100644 index 000000000..69c243d60 --- /dev/null +++ b/mediapipe/pose_tracking_dll/windows_dll_library.bzl @@ -0,0 +1,62 @@ +""" +This is a simple windows_dll_library rule for builing a DLL Windows +that can be depended on by other cc rules. +Example useage: + windows_dll_library( + name = "hellolib", + srcs = [ + "hello-library.cpp", + ], + hdrs = ["hello-library.h"], + # Define COMPILING_DLL to export symbols during compiling the DLL. + copts = ["/DCOMPILING_DLL"], + ) +""" + +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_import", "cc_library") + +def windows_dll_library( + name, + srcs = [], + deps = [], + hdrs = [], + visibility = None, + **kwargs): + """A simple windows_dll_library rule for builing a DLL Windows.""" + dll_name = name + ".dll" + import_lib_name = name + "_import_lib" + import_target_name = name + "_dll_import" + + # Build the shared library + cc_binary( + name = dll_name, + srcs = srcs + hdrs, + deps = deps, + linkshared = 1, + **kwargs + ) + + # Get the import library for the dll + native.filegroup( + name = import_lib_name, + srcs = [":" + dll_name], + output_group = "interface_library", + ) + + # Because we cannot directly depend on cc_binary from other cc rules in deps attribute, + # we use cc_import as a bridge to depend on the dll. + cc_import( + name = import_target_name, + interface_library = ":" + import_lib_name, + shared_library = ":" + dll_name, + ) + + # Create a new cc_library to also include the headers needed for the shared library + cc_library( + name = name, + hdrs = hdrs, + visibility = visibility, + deps = deps + [ + ":" + import_target_name, + ], + ) \ No newline at end of file