Merge pull request #1 from NimagnaAG/pose_tracking_dll

Add pose tracking subproject
This commit is contained in:
MaksymAtNimagna 2022-01-06 12:24:55 +01:00 committed by GitHub
commit 00d593120e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 421 additions and 0 deletions

View File

@ -0,0 +1,59 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("windows_dll_library.bzl", "windows_dll_library")
licenses(["notice"])
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
# Define the shared library
windows_dll_library(
    name = "pose_tracking_lib",
    srcs = ["pose_tracking.cpp"],
    hdrs = ["pose_tracking.h"],
    # COMPILING_DLL switches DLLEXPORT in pose_tracking.h to dllexport, so the
    # public symbols are exported while the DLL itself is being compiled.
    copts = ["-DCOMPILING_DLL"],
    deps = [
        # MediaPipe framework, data formats and portability helpers.
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:file_helpers",
        "//mediapipe/framework/port:opencv_highgui",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/framework/port:opencv_video",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        # Abseil flags.
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/flags:parse",
        # Calculators linked into the DLL.
        "//mediapipe/calculators/core:constant_side_packet_calculator",
        "//mediapipe/calculators/core:packet_presence_calculator",
        "//mediapipe/calculators/core:flow_limiter_calculator",
        "//mediapipe/calculators/tflite:tflite_model_calculator",
        "//mediapipe/calculators/util:local_file_contents_calculator",
        "//mediapipe/graphs/pose_tracking:pose_tracking_cpu_deps",
    ],
)
# **Implicitly link to pose_tracking_lib.dll**
# Demo executable: the generic graph runner plus the pose tracking graph
# dependencies, linked against the pose tracking DLL declared in this package.
cc_binary(
    name = "pose_tracking_cpu",
    deps = [
        "//mediapipe/examples/desktop:demo_run_graph_main",
        "//mediapipe/graphs/pose_tracking:pose_tracking_cpu_deps",
        ":pose_tracking_lib",
    ],
)

View File

@ -0,0 +1,36 @@
## Description
The pose_tracking_dll module builds a MediaPipe-based pose tracking DLL that can be used from any C++ project. All dependencies, such as TensorFlow, are linked statically into the DLL.
Currently, the following features are supported:
- Segmenting the person(s) of interest
- Segmenting the skeleton(s)
- Accessing the 3D coordinates of each node of the skeleton
## Prerequisites
Follow the guidelines on the official Mediapipe website: https://google.github.io/mediapipe/getting_started/install.html#installing-on-windows
IMPORTANT: The tutorial does not specify which version of Bazel to install. Install Bazel version 3.7.2. The OpenCV version used by default in mediapipe is 3.4.10.
If you are using a different OpenCV version, adapt the `OPENCV_VERSION` variable in the file `mediapipe/external/opencv_<platform>.BUILD` to the one installed in the system (https://github.com/google/mediapipe/issues/1926#issuecomment-825874197).
## How to build
Assuming you're in the root of the repository:
```
cd mediapipe
bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 --action_env PYTHON_BIN_PATH=<path to the python executable described using forward slashes ("/")> pose_tracking_dll:pose_tracking_cpu
```
Alternatively, `dbg` can be used in place of `opt` to build the library with debug symbols in Visual Studio PDB format.
The results will be stored in the bazel-bin\mediapipe\pose_tracking_dll folder.
## How to use
Go to bazel-bin\mediapipe\pose_tracking_dll
Link pose_tracking_cpu.lib and pose_tracking_lib.dll.if.lib statically in your project.
Make sure that opencv_world3410.dll and pose_tracking_lib.dll are accessible from your working directory.
Use the mediapipe\pose_tracking_dll\pose_tracking.h header file to access the methods of the library.

View File

@ -0,0 +1,163 @@
#include "pose_tracking.h"
#include <cstdlib>
#include <string>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/opencv_video_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
class PoseTrackingImpl {
public:
PoseTrackingImpl(const std::string& calculatorGraphConfigFile) {
auto status = initialize(calculatorGraphConfigFile);
if (!status.ok()) {
LOG(WARNING) << "Warning: " << status;
}
}
absl::Status initialize(const std::string& calculatorGraphConfigFile) {
std::string graphContents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(calculatorGraphConfigFile, &graphContents));
mediapipe::CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graphContents);
MP_RETURN_IF_ERROR(graph.Initialize(config));
ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller,
graph.AddOutputStreamPoller(kOutputSegmentationStream));
ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller landmarksPoller,
graph.AddOutputStreamPoller(kOutpuLandmarksStream));
ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller posePresencePoller,
graph.AddOutputStreamPoller(kOutpuPosePresenceStream));
maskPollerPtr = std::make_unique<mediapipe::OutputStreamPoller>(std::move(poller));
landmarksPollerPtr =
std::make_unique<mediapipe::OutputStreamPoller>(std::move(landmarksPoller));
posePresencePollerPtr =
std::make_unique<mediapipe::OutputStreamPoller>(std::move(posePresencePoller));
MP_RETURN_IF_ERROR(graph.StartRun({}));
}
bool processFrame(const cv::Mat& inputRGB8Bit) {
// Wrap Mat into an ImageFrame.
auto inputFrame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, inputRGB8Bit.cols, inputRGB8Bit.rows,
mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat inputFrameMat = mediapipe::formats::MatView(inputFrame.get());
inputRGB8Bit.copyTo(inputFrameMat);
// Send image packet into the graph.
size_t frameTimestampUs =
static_cast<double>(cv::getTickCount()) / static_cast<double>(cv::getTickFrequency()) * 1e6;
auto status = graph.AddPacketToInputStream(
kInputStream,
mediapipe::Adopt(inputFrame.release()).At(mediapipe::Timestamp(frameTimestampUs)));
if (!status.ok()) {
LOG(WARNING) << "Graph execution failed: " << status;
return false;
}
mediapipe::Packet posePresencePacket;
if (!posePresencePollerPtr || !posePresencePollerPtr->Next(&posePresencePacket)) return false;
auto landmarksDetected = posePresencePacket.Get<bool>();
if (!landmarksDetected) {
return false;
}
// Get the graph result packet, or stop if that fails.
mediapipe::Packet maskPacket;
if (!maskPollerPtr || !maskPollerPtr->Next(&maskPacket)) return false;
auto& outputFrame = maskPacket.Get<mediapipe::ImageFrame>();
// Get pose landmarks.
if (!landmarksPollerPtr || !landmarksPollerPtr->Next(&poseLandmarksPacket)) {
return false;
}
// Convert back to opencv for display or saving.
auto mask = mediapipe::formats::MatView(&outputFrame);
segmentedMask = mask.clone();
absl::Status landmarksStatus = detectLandmarksWithStatus(poseLandmarks);
return landmarksStatus.ok();
}
absl::Status detectLandmarksWithStatus(nimagna::cv_wrapper::Point3f* poseLandmarks) {
if (poseLandmarksPacket.IsEmpty()) {
return absl::CancelledError("Pose landmarks packet is empty.");
}
auto retrievedLandmarks = poseLandmarksPacket.Get<::mediapipe::NormalizedLandmarkList>();
// Convert landmarks to cv::Point3f**.
const auto landmarksCount = retrievedLandmarks.landmark_size();
for (int j = 0; j < landmarksCount; ++j) {
const auto& landmark = retrievedLandmarks.landmark(j);
poseLandmarks[j].x = landmark.x();
poseLandmarks[j].y = landmark.y();
poseLandmarks[j].z = landmark.z();
}
return absl::OkStatus();
}
nimagna::cv_wrapper::Point3f* lastDetectedLandmarks() { return poseLandmarks; }
cv::Mat lastSegmentedFrame() { return segmentedMask; }
static constexpr size_t kLandmarksCount = 33u;
private:
mediapipe::Packet poseLandmarksPacket;
cv::Mat segmentedMask;
nimagna::cv_wrapper::Point3f poseLandmarks[kLandmarksCount];
std::unique_ptr<mediapipe::OutputStreamPoller> posePresencePollerPtr;
std::unique_ptr<mediapipe::OutputStreamPoller> maskPollerPtr;
std::unique_ptr<mediapipe::OutputStreamPoller> landmarksPollerPtr;
mediapipe::CalculatorGraph graph;
const char* kInputStream = "input_video";
const char* kOutputSegmentationStream = "segmentation_mask";
const char* kOutpuLandmarksStream = "pose_landmarks";
const char* kOutpuPosePresenceStream = "pose_presence";
};
namespace nimagna {
// Creates the implementation object for the graph config at
// `calculatorGraphConfigFile`. A null path is treated as an empty path, because
// constructing a std::string from a null pointer is undefined behaviour.
PoseTracking::PoseTracking(const char* calculatorGraphConfigFile)
    : mImplementation(new PoseTrackingImpl(
          std::string(calculatorGraphConfigFile != nullptr ? calculatorGraphConfigFile : ""))) {}
// Wraps the caller's pixel buffer in a cv::Mat header (interpreted as 8-bit,
// 3 channels) and forwards it to the implementation.
//
// Returns false without touching the graph when the wrapper has no data or a
// non-positive size; otherwise returns the implementation's result.
bool PoseTracking::processFrame(const cv_wrapper::Mat& inputRGB8Bit) {
  if (inputRGB8Bit.data == nullptr || inputRGB8Bit.rows <= 0 || inputRGB8Bit.cols <= 0) {
    return false;
  }
  const cv::Mat frame(inputRGB8Bit.rows, inputRGB8Bit.cols, CV_8UC3, inputRGB8Bit.data);
  return mImplementation->processFrame(frame);
}
// Forwards to the implementation and hands back its landmark array. The
// storage belongs to the implementation object, not to the caller.
cv_wrapper::Point3f* PoseTracking::lastDetectedLandmarks() {
  cv_wrapper::Point3f* const landmarks = mImplementation->lastDetectedLandmarks();
  return landmarks;
}
// Exposes the implementation's cached mask through the OpenCV-free wrapper
// type. Only the dimensions and the raw data pointer are copied, so the
// returned wrapper points into a buffer held by the implementation.
cv_wrapper::Mat PoseTracking::lastSegmentedFrame() {
  const cv::Mat cachedMask = mImplementation->lastSegmentedFrame();
  cv_wrapper::Mat wrapped(cachedMask.rows, cachedMask.cols, cachedMask.data);
  return wrapped;
}
} // namespace nimagna

View File

@ -0,0 +1,101 @@
#ifndef POSE_TRACKING_LIBRARY_H
#define POSE_TRACKING_LIBRARY_H
#ifdef COMPILING_DLL
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT __declspec(dllimport)
#endif
class PoseTrackingImpl;
namespace nimagna {
namespace cv_wrapper {
// Plain 2D point with float coordinates; both default to zero.
struct Point2f {
  float x = 0.0f;
  float y = 0.0f;

  Point2f() = default;
  Point2f(float x, float y) : x{x}, y{y} {}
};
// Plain 3D point with float coordinates; all three default to zero.
struct Point3f {
  float x = 0.0f;
  float y = 0.0f;
  float z = 0.0f;

  Point3f() = default;
  Point3f(float x, float y, float z) : x{x}, y{y}, z{z} {}
};
// Integer rectangle described by an (x, y) position and a width/height;
// every field defaults to zero.
struct Rect {
  int x{0};
  int y{0};
  int width{0};
  int height{0};

  Rect() = default;
  Rect(int x, int y, int width, int height) : x{x}, y{y}, width{width}, height{height} {}
};
// Lightweight image descriptor: dimensions plus a raw pointer to the pixel
// bytes. The struct only stores the pointer; it neither allocates nor frees
// the buffer it refers to.
struct Mat {
  int rows{0};
  int cols{0};
  unsigned char* data{nullptr};

  Mat(int rows, int cols, unsigned char* data) : rows{rows}, cols{cols}, data{data} {}
};
} // namespace cv_wrapper
class DLLEXPORT PoseTracking {
public:
static constexpr size_t landmarksCount = 33u;
enum LandmarkNames {
NOSE = 0,
LEFT_EYE_INNER,
LEFT_EYE,
LEFT_EYE_OUTER,
RIGHT_EYE_INNER,
RIGHT_EYE,
RIGHT_EYE_OUTER,
LEFT_EAR,
RIGHT_EAR,
MOUTH_LEFT,
MOUTH_RIGHT,
LEFT_SHOULDER,
RIGHT_SHOULDER,
LEFT_ELBOW,
RIGHT_ELBOW,
LEFT_WRIST,
RIGHT_WRIST,
LEFT_PINKY,
RIGHT_PINKY,
LEFT_INDEX,
RIGHT_INDEX,
LEFT_THUMB,
RIGHT_THUMB,
LEFT_HIP,
RIGHT_HIP,
LEFT_KNEE,
RIGHT_KNEE,
LEFT_ANKLE,
RIGHT_ANKLE,
LEFT_HEEL,
RIGHT_HEEL,
LEFT_FOOT_INDEX,
RIGHT_FOOT_INDEX,
COUNT = landmarksCount
};
PoseTracking(const char* calculatorGraphConfigFile);
~PoseTracking() { delete mImplementation; }
bool processFrame(const cv_wrapper::Mat& inputRGB8Bit);
cv_wrapper::Mat lastSegmentedFrame();
cv_wrapper::Point3f* lastDetectedLandmarks();
private:
PoseTrackingImpl* mImplementation;
};
} // namespace nimagna
#endif

View File

@ -0,0 +1,62 @@
"""
This is a simple windows_dll_library rule for building a DLL on Windows
that can be depended on by other cc rules.
Example usage:
windows_dll_library(
name = "hellolib",
srcs = [
"hello-library.cpp",
],
hdrs = ["hello-library.h"],
# Define COMPILING_DLL to export symbols during compiling the DLL.
copts = ["/DCOMPILING_DLL"],
)
"""
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_import", "cc_library")
def windows_dll_library(
        name,
        srcs = [],
        deps = [],
        hdrs = [],
        visibility = None,
        **kwargs):
    """Builds a Windows DLL and exposes it as a cc_library.

    Declares four targets: the DLL itself (`<name>.dll`), a filegroup holding
    its interface library (`<name>_import_lib`), a cc_import that bundles both
    (`<name>_dll_import`), and a cc_library called `<name>` that other cc rules
    can list in their deps.

    Args:
      name: Name of the resulting cc_library; also the prefix of the helper targets.
      srcs: Source files compiled into the DLL.
      deps: Dependencies of both the DLL and the resulting cc_library.
      hdrs: Public headers; compiled with the DLL and exported by the cc_library.
      visibility: Visibility of the resulting cc_library.
      **kwargs: Extra attributes forwarded to the cc_binary that builds the DLL.
    """
    shared_lib = "%s.dll" % name
    interface_files = "%s_import_lib" % name
    import_bridge = "%s_dll_import" % name

    # The shared library itself.
    cc_binary(
        name = shared_lib,
        srcs = srcs + hdrs,
        deps = deps,
        linkshared = 1,
        **kwargs
    )

    # The import (interface) library emitted alongside the DLL.
    native.filegroup(
        name = interface_files,
        srcs = [":" + shared_lib],
        output_group = "interface_library",
    )

    # A cc_binary cannot appear in the deps of other cc rules, so a cc_import
    # acts as the bridge to the DLL.
    cc_import(
        name = import_bridge,
        interface_library = ":" + interface_files,
        shared_library = ":" + shared_lib,
    )

    # The target users depend on: the headers plus the imported DLL.
    cc_library(
        name = name,
        hdrs = hdrs,
        visibility = visibility,
        deps = deps + [":" + import_bridge],
    )