Compare commits

..

No commits in common. "master" and "release" have entirely different histories.

91 changed files with 1338 additions and 6354 deletions

View File

@ -301,11 +301,9 @@ cc_test(
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator",
"//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator",
"//mediapipe/gpu:multi_pool",
"//third_party:opencv", "//third_party:opencv",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
], ],

View File

@ -656,15 +656,6 @@ absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) {
input.format()); input.format());
gpu_helper_.BindFramebuffer(dst); gpu_helper_.BindFramebuffer(dst);
if (scale_mode_ == mediapipe::ScaleMode::FIT) {
// In kFit scale mode, the rendered quad does not fill the whole
// framebuffer, so clear it beforehand.
glClearColor(padding_color_[0] / 255.0f, padding_color_[1] / 255.0f,
padding_color_[2] / 255.0f, 1.0f);
glClear(GL_COLOR_BUFFER_BIT);
}
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(src1.target(), src1.name()); glBindTexture(src1.target(), src1.name());

View File

@ -46,14 +46,13 @@ message ImageTransformationCalculatorOptions {
optional bool flip_horizontally = 5 [default = false]; optional bool flip_horizontally = 5 [default = false];
// Scale mode. // Scale mode.
optional ScaleMode.Mode scale_mode = 6; optional ScaleMode.Mode scale_mode = 6;
// Padding type. This option is only used when the scale mode is FIT. If set // Padding type. This option is only used when the scale mode is FIT.
// to true (default), a constant border is added with color specified by // Default is to use BORDER_CONSTANT. If set to false, it will use
// padding_color. If set to false, a border is added by replicating edge // BORDER_REPLICATE instead.
// pixels (only supported for CPU).
optional bool constant_padding = 7 [default = true]; optional bool constant_padding = 7 [default = true];
// The color for the padding. This option is only used when the scale mode is // The color for the padding. This option is only used when the scale mode is
// FIT. Default is black. // FIT. Default is black. This is for CPU only.
optional Color padding_color = 8; optional Color padding_color = 8;
// Interpolation method to use. Note that on CPU when LINEAR is specified, // Interpolation method to use. Note that on CPU when LINEAR is specified,

View File

@ -1,11 +1,9 @@
#include <algorithm>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/log/absl_check.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -18,14 +16,10 @@
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/gpu/multi_pool.h"
#include "testing/base/public/gmock.h" #include "testing/base/public/gmock.h"
#include "testing/base/public/googletest.h" #include "testing/base/public/googletest.h"
#include "testing/base/public/gunit.h"
#include "third_party/OpenCV/core.hpp" // IWYU pragma: keep #include "third_party/OpenCV/core.hpp" // IWYU pragma: keep
#include "third_party/OpenCV/core/base.hpp"
#include "third_party/OpenCV/core/mat.hpp" #include "third_party/OpenCV/core/mat.hpp"
#include "third_party/OpenCV/core/types.hpp"
namespace mediapipe { namespace mediapipe {
@ -82,12 +76,11 @@ TEST(ImageTransformationCalculatorTest, NearestNeighborResizing) {
->Tag("OUTPUT_DIMENSIONS") ->Tag("OUTPUT_DIMENSIONS")
.packets.push_back(input_output_dim_packet.At(Timestamp(0))); .packets.push_back(input_output_dim_packet.At(Timestamp(0)));
ABSL_QCHECK_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs(); const auto& outputs = runner.Outputs();
ABSL_QCHECK_EQ(outputs.NumEntries(), 1); ASSERT_EQ(outputs.NumEntries(), 1);
const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets; const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets;
ABSL_QCHECK_EQ(packets.size(), 1); ASSERT_EQ(packets.size(), 1);
const auto& result = packets[0].Get<ImageFrame>(); const auto& result = packets[0].Get<ImageFrame>();
ASSERT_EQ(output_dim.first, result.Width()); ASSERT_EQ(output_dim.first, result.Width());
ASSERT_EQ(output_dim.second, result.Height()); ASSERT_EQ(output_dim.second, result.Height());
@ -144,12 +137,11 @@ TEST(ImageTransformationCalculatorTest,
->Tag("OUTPUT_DIMENSIONS") ->Tag("OUTPUT_DIMENSIONS")
.packets.push_back(input_output_dim_packet.At(Timestamp(0))); .packets.push_back(input_output_dim_packet.At(Timestamp(0)));
ABSL_QCHECK_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs(); const auto& outputs = runner.Outputs();
ABSL_QCHECK_EQ(outputs.NumEntries(), 1); ASSERT_EQ(outputs.NumEntries(), 1);
const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets; const std::vector<Packet>& packets = outputs.Tag("IMAGE").packets;
ABSL_QCHECK_EQ(packets.size(), 1); ASSERT_EQ(packets.size(), 1);
const auto& result = packets[0].Get<ImageFrame>(); const auto& result = packets[0].Get<ImageFrame>();
ASSERT_EQ(output_dim.first, result.Width()); ASSERT_EQ(output_dim.first, result.Width());
ASSERT_EQ(output_dim.second, result.Height()); ASSERT_EQ(output_dim.second, result.Height());
@ -215,17 +207,17 @@ TEST(ImageTransformationCalculatorTest, NearestNeighborResizingGpu) {
tool::AddVectorSink("output_image", &graph_config, &output_image_packets); tool::AddVectorSink("output_image", &graph_config, &output_image_packets);
CalculatorGraph graph(graph_config); CalculatorGraph graph(graph_config);
ABSL_QCHECK_OK(graph.StartRun({})); MP_ASSERT_OK(graph.StartRun({}));
ABSL_QCHECK_OK(graph.AddPacketToInputStream( MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_image", "input_image",
MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0)))); MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0))));
ABSL_QCHECK_OK(graph.AddPacketToInputStream( MP_ASSERT_OK(graph.AddPacketToInputStream(
"image_size", "image_size",
MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0)))); MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0))));
ABSL_QCHECK_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
ABSL_QCHECK_EQ(output_image_packets.size(), 1); ASSERT_THAT(output_image_packets, testing::SizeIs(1));
const auto& output_image = output_image_packets[0].Get<ImageFrame>(); const auto& output_image = output_image_packets[0].Get<ImageFrame>();
ASSERT_EQ(output_dim.first, output_image.Width()); ASSERT_EQ(output_dim.first, output_image.Width());
@ -295,16 +287,16 @@ TEST(ImageTransformationCalculatorTest,
tool::AddVectorSink("output_image", &graph_config, &output_image_packets); tool::AddVectorSink("output_image", &graph_config, &output_image_packets);
CalculatorGraph graph(graph_config); CalculatorGraph graph(graph_config);
ABSL_QCHECK_OK(graph.StartRun({})); MP_ASSERT_OK(graph.StartRun({}));
ABSL_QCHECK_OK(graph.AddPacketToInputStream( MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_image", input_image_packet.At(Timestamp(0)))); "input_image", input_image_packet.At(Timestamp(0))));
ABSL_QCHECK_OK(graph.AddPacketToInputStream( MP_ASSERT_OK(graph.AddPacketToInputStream(
"image_size", "image_size",
MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0)))); MakePacket<std::pair<int, int>>(output_dim).At(Timestamp(0))));
ABSL_QCHECK_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
ABSL_QCHECK_EQ(output_image_packets.size(), 1); ASSERT_THAT(output_image_packets, testing::SizeIs(1));
const auto& output_image = output_image_packets[0].Get<ImageFrame>(); const auto& output_image = output_image_packets[0].Get<ImageFrame>();
ASSERT_EQ(output_dim.first, output_image.Width()); ASSERT_EQ(output_dim.first, output_image.Width());
@ -319,112 +311,5 @@ TEST(ImageTransformationCalculatorTest,
} }
} }
TEST(ImageTransformationCalculatorTest, FitScalingClearsBackground) {
// Regression test for not clearing the background in FIT scaling mode.
// First scale an all-red (=r) image from 8x4 to 8x4, so it's a plain copy:
// rrrrrrrr
// rrrrrrrr
// rrrrrrrr
// rrrrrrrr
// Then scale an all-blue image from 4x4 to 8x4 in FIT mode. This should
// introduce dark yellow (=y) letterboxes left and right due to padding_color:
// yybbbbyy
// yybbbbyy
// yybbbbyy
// yybbbbyy
// We make sure that the all-red buffer gets reused. Without clearing the
// background, the blue (=b) image will have red letterboxes:
// rrbbbbrr
// rrbbbbrr
// rrbbbbrr
// rrbbbbrr
constexpr int kSmall = 4, kLarge = 8;
ImageFrame input_image_red(ImageFormat::SRGBA, kLarge, kSmall);
cv::Mat input_image_red_mat = formats::MatView(&input_image_red);
input_image_red_mat = cv::Scalar(255, 0, 0, 255);
ImageFrame input_image_blue(ImageFormat::SRGBA, kSmall, kSmall);
cv::Mat input_image_blue_mat = formats::MatView(&input_image_blue);
input_image_blue_mat = cv::Scalar(0, 0, 255, 255);
Packet input_image_red_packet =
MakePacket<ImageFrame>(std::move(input_image_red));
Packet input_image_blue_packet =
MakePacket<ImageFrame>(std::move(input_image_blue));
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
R"pb(
input_stream: "input_image"
output_stream: "output_image"
node {
calculator: "ImageFrameToGpuBufferCalculator"
input_stream: "input_image"
output_stream: "input_image_gpu"
}
node {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE_GPU:input_image_gpu"
output_stream: "IMAGE_GPU:output_image_gpu"
options: {
[mediapipe.ImageTransformationCalculatorOptions.ext]: {
scale_mode: FIT
output_width: $0,
output_height: $1,
padding_color: { red: 128, green: 128, blue: 0 }
}
}
}
node {
calculator: "GpuBufferToImageFrameCalculator"
input_stream: "output_image_gpu"
output_stream: "output_image"
})pb",
kLarge, kSmall));
std::vector<Packet> output_image_packets;
tool::AddVectorSink("output_image", &graph_config, &output_image_packets);
CalculatorGraph graph(graph_config);
ABSL_QCHECK_OK(graph.StartRun({}));
// Send the red image multiple times to cause the GPU pool to actually use
// a pool.
int num_red_packets =
std::max(kDefaultMultiPoolOptions.min_requests_before_pool, 1);
for (int n = 0; n < num_red_packets; ++n) {
ABSL_QCHECK_OK(graph.AddPacketToInputStream(
"input_image", input_image_red_packet.At(Timestamp(n))));
}
ABSL_QCHECK_OK(graph.AddPacketToInputStream(
"input_image", input_image_blue_packet.At(Timestamp(num_red_packets))));
ABSL_QCHECK_OK(graph.WaitUntilIdle());
ABSL_QCHECK_EQ(output_image_packets.size(), num_red_packets + 1);
const auto& output_image_red = output_image_packets[0].Get<ImageFrame>();
const auto& output_image_blue =
output_image_packets[num_red_packets].Get<ImageFrame>();
ABSL_QCHECK_EQ(output_image_red.Width(), kLarge);
ABSL_QCHECK_EQ(output_image_red.Height(), kSmall);
ABSL_QCHECK_EQ(output_image_blue.Width(), kLarge);
ABSL_QCHECK_EQ(output_image_blue.Height(), kSmall);
cv::Mat output_image_blue_mat = formats::MatView(&output_image_blue);
ImageFrame expected_image_blue(ImageFormat::SRGBA, kLarge, kSmall);
cv::Mat expected_image_blue_mat = formats::MatView(&expected_image_blue);
expected_image_blue_mat = cv::Scalar(128, 128, 0, 255);
cv::Rect rect((kLarge - kSmall) / 2, 0, kSmall, kSmall);
cv::rectangle(expected_image_blue_mat, rect, cv::Scalar(0, 0, 255, 255),
cv::FILLED);
EXPECT_EQ(cv::sum(cv::sum(output_image_blue_mat != expected_image_blue_mat)),
cv::Scalar(0));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -43,9 +43,9 @@ class KinematicPathSolver {
initialized_(false), initialized_(false),
pixels_per_degree_(pixels_per_degree) {} pixels_per_degree_(pixels_per_degree) {}
// Add an observation (detection) at a position and time. // Add an observation (detection) at a position and time.
absl::Status AddObservation(int position, const uint64_t time_us); absl::Status AddObservation(int position, const uint64 time_us);
// Get the predicted position at a time. // Get the predicted position at a time.
absl::Status UpdatePrediction(const int64_t time_us); absl::Status UpdatePrediction(const int64 time_us);
// Get the state at a time, as an int. // Get the state at a time, as an int.
absl::Status GetState(int* position); absl::Status GetState(int* position);
// Get the state at a time, as a float. // Get the state at a time, as a float.
@ -63,7 +63,7 @@ class KinematicPathSolver {
bool IsMotionTooSmall(double delta_degs); bool IsMotionTooSmall(double delta_degs);
// Check if a position measurement will cause the camera to be in motion // Check if a position measurement will cause the camera to be in motion
// without updating the internal state. // without updating the internal state.
absl::Status PredictMotionState(int position, const uint64_t time_us, absl::Status PredictMotionState(int position, const uint64 time_us,
bool* state); bool* state);
// Clear any history buffer of positions that are used when // Clear any history buffer of positions that are used when
// filtering_time_window_us is set to a non-zero value. // filtering_time_window_us is set to a non-zero value.
@ -85,9 +85,9 @@ class KinematicPathSolver {
double current_position_px_; double current_position_px_;
double prior_position_px_; double prior_position_px_;
double current_velocity_deg_per_s_; double current_velocity_deg_per_s_;
uint64_t current_time_ = 0; uint64 current_time_ = 0;
// History of observations (second) and their time (first). // History of observations (second) and their time (first).
std::deque<std::pair<uint64_t, int>> raw_positions_at_time_; std::deque<std::pair<uint64, int>> raw_positions_at_time_;
// Current target position. // Current target position.
double target_position_px_; double target_position_px_;
// Defines if the camera is moving to a target (true) or reached a target // Defines if the camera is moving to a target (true) or reached a target

View File

@ -67,7 +67,7 @@ class SceneCameraMotionAnalyzer {
const KeyFrameCropOptions& key_frame_crop_options, const KeyFrameCropOptions& key_frame_crop_options,
const std::vector<KeyFrameCropResult>& key_frame_crop_results, const std::vector<KeyFrameCropResult>& key_frame_crop_results,
const int scene_frame_width, const int scene_frame_height, const int scene_frame_width, const int scene_frame_height,
const std::vector<int64_t>& scene_frame_timestamps, const std::vector<int64>& scene_frame_timestamps,
const bool has_solid_color_background, const bool has_solid_color_background,
SceneKeyFrameCropSummary* scene_summary, SceneKeyFrameCropSummary* scene_summary,
std::vector<FocusPointFrame>* focus_point_frames, std::vector<FocusPointFrame>* focus_point_frames,
@ -78,7 +78,7 @@ class SceneCameraMotionAnalyzer {
// crop window in SceneKeyFrameCropSummary in the case of steady motion. // crop window in SceneKeyFrameCropSummary in the case of steady motion.
absl::Status DecideCameraMotionType( absl::Status DecideCameraMotionType(
const KeyFrameCropOptions& key_frame_crop_options, const KeyFrameCropOptions& key_frame_crop_options,
const double scene_span_sec, const int64_t end_time_us, const double scene_span_sec, const int64 end_time_us,
SceneKeyFrameCropSummary* scene_summary, SceneKeyFrameCropSummary* scene_summary,
SceneCameraMotion* scene_camera_motion) const; SceneCameraMotion* scene_camera_motion) const;
@ -87,7 +87,7 @@ class SceneCameraMotionAnalyzer {
absl::Status PopulateFocusPointFrames( absl::Status PopulateFocusPointFrames(
const SceneKeyFrameCropSummary& scene_summary, const SceneKeyFrameCropSummary& scene_summary,
const SceneCameraMotion& scene_camera_motion, const SceneCameraMotion& scene_camera_motion,
const std::vector<int64_t>& scene_frame_timestamps, const std::vector<int64>& scene_frame_timestamps,
std::vector<FocusPointFrame>* focus_point_frames) const; std::vector<FocusPointFrame>* focus_point_frames) const;
private: private:
@ -118,7 +118,7 @@ class SceneCameraMotionAnalyzer {
absl::Status PopulateFocusPointFramesForTracking( absl::Status PopulateFocusPointFramesForTracking(
const SceneKeyFrameCropSummary& scene_summary, const SceneKeyFrameCropSummary& scene_summary,
const FocusPointFrameType focus_point_frame_type, const FocusPointFrameType focus_point_frame_type,
const std::vector<int64_t>& scene_frame_timestamps, const std::vector<int64>& scene_frame_timestamps,
std::vector<FocusPointFrame>* focus_point_frames) const; std::vector<FocusPointFrame>* focus_point_frames) const;
// Decide to use steady motion. // Decide to use steady motion.
@ -142,7 +142,7 @@ class SceneCameraMotionAnalyzer {
// Last position // Last position
SceneCameraMotion last_scene_with_salient_region_; SceneCameraMotion last_scene_with_salient_region_;
int64_t time_since_last_salient_region_us_; int64 time_since_last_salient_region_us_;
// Scene has solid color background. // Scene has solid color background.
bool has_solid_color_background_; bool has_solid_color_background_;

View File

@ -62,7 +62,7 @@ class SceneCropper {
// TODO: split this function into two separate functions. // TODO: split this function into two separate functions.
absl::Status CropFrames( absl::Status CropFrames(
const SceneKeyFrameCropSummary& scene_summary, const SceneKeyFrameCropSummary& scene_summary,
const std::vector<int64_t>& scene_timestamps, const std::vector<int64>& scene_timestamps,
const std::vector<bool>& is_key_frames, const std::vector<bool>& is_key_frames,
const std::vector<cv::Mat>& scene_frames_or_empty, const std::vector<cv::Mat>& scene_frames_or_empty,
const std::vector<FocusPointFrame>& focus_point_frames, const std::vector<FocusPointFrame>& focus_point_frames,
@ -73,7 +73,7 @@ class SceneCropper {
absl::Status ProcessKinematicPathSolver( absl::Status ProcessKinematicPathSolver(
const SceneKeyFrameCropSummary& scene_summary, const SceneKeyFrameCropSummary& scene_summary,
const std::vector<int64_t>& scene_timestamps, const std::vector<int64>& scene_timestamps,
const std::vector<bool>& is_key_frames, const std::vector<bool>& is_key_frames,
const std::vector<FocusPointFrame>& focus_point_frames, const std::vector<FocusPointFrame>& focus_point_frames,
const bool continue_last_scene, std::vector<cv::Mat>* all_xforms); const bool continue_last_scene, std::vector<cv::Mat>* all_xforms);

View File

@ -29,7 +29,7 @@ namespace autoflip {
// Packs detected features and timestamp (ms) into a KeyFrameInfo object. Scales // Packs detected features and timestamp (ms) into a KeyFrameInfo object. Scales
// features back to the original frame size if features have been detected on a // features back to the original frame size if features have been detected on a
// different frame size. // different frame size.
absl::Status PackKeyFrameInfo(const int64_t frame_timestamp_ms, absl::Status PackKeyFrameInfo(const int64 frame_timestamp_ms,
const DetectionSet& detections, const DetectionSet& detections,
const int original_frame_width, const int original_frame_width,
const int original_frame_height, const int original_frame_height,
@ -71,7 +71,7 @@ absl::Status ComputeSceneStaticBordersSize(
// interpolation functions in Lab space using input timestamps. // interpolation functions in Lab space using input timestamps.
absl::Status FindSolidBackgroundColor( absl::Status FindSolidBackgroundColor(
const std::vector<StaticFeatures>& static_features, const std::vector<StaticFeatures>& static_features,
const std::vector<int64_t>& static_features_timestamps, const std::vector<int64>& static_features_timestamps,
const double min_fraction_solid_background_color, const double min_fraction_solid_background_color,
bool* has_solid_background, bool* has_solid_background,
PiecewiseLinearFunction* background_color_l_function, PiecewiseLinearFunction* background_color_l_function,

View File

@ -155,27 +155,6 @@ cc_library(
], ],
) )
cc_library(
name = "hardware_buffer",
srcs = ["hardware_buffer_android.cc"],
hdrs = ["hardware_buffer.h"],
linkopts = select({
"//conditions:default": [],
# Option for vendor binaries to avoid linking libandroid.so.
"//mediapipe/framework:android_no_jni": [],
"//mediapipe:android": ["-landroid"],
":android_link_native_window": [
"-lnativewindow", # Provides <android/hardware_buffer.h> to vendor binaries on Android API >= 26.
],
}),
visibility = ["//visibility:private"],
deps = [
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/log:absl_check",
],
)
cc_library( cc_library(
name = "image_frame", name = "image_frame",
srcs = ["image_frame.cc"], srcs = ["image_frame.cc"],
@ -514,6 +493,10 @@ cc_library(
"//conditions:default": [], "//conditions:default": [],
# Option for vendor binaries to avoid linking libandroid.so. # Option for vendor binaries to avoid linking libandroid.so.
"//mediapipe/framework:android_no_jni": [], "//mediapipe/framework:android_no_jni": [],
"//mediapipe:android": ["-landroid"],
":android_link_native_window": [
"-lnativewindow", # Provides <android/hardware_buffer.h> to vendor binaries on Android API >= 26.
],
}), }),
deps = [ deps = [
"//mediapipe/framework:port", "//mediapipe/framework:port",
@ -528,16 +511,9 @@ cc_library(
"//mediapipe/gpu:gl_base", "//mediapipe/gpu:gl_base",
"//mediapipe/gpu:gl_context", "//mediapipe/gpu:gl_context",
], ],
"//mediapipe:android": [ }) +
":hardware_buffer", select({
"//mediapipe/gpu:gl_base", "//conditions:default": [],
"//mediapipe/gpu:gl_context",
],
":android_link_native_window": [
":hardware_buffer",
"//mediapipe/gpu:gl_base",
"//mediapipe/gpu:gl_context",
],
}), }),
) )

View File

@ -17,25 +17,10 @@
#include <stdio.h> #include <stdio.h>
#include <utility>
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
namespace mediapipe { namespace mediapipe {
DeletingFile::DeletingFile(DeletingFile&& other)
: path_(std::move(other.path_)),
delete_on_destruction_(other.delete_on_destruction_) {
other.delete_on_destruction_ = false;
}
DeletingFile& DeletingFile::operator=(DeletingFile&& other) {
path_ = std::move(other.path_);
delete_on_destruction_ = other.delete_on_destruction_;
other.delete_on_destruction_ = false;
return *this;
}
DeletingFile::DeletingFile(const std::string& path, bool delete_on_destruction) DeletingFile::DeletingFile(const std::string& path, bool delete_on_destruction)
: path_(path), delete_on_destruction_(delete_on_destruction) {} : path_(path), delete_on_destruction_(delete_on_destruction) {}

View File

@ -28,11 +28,6 @@ class DeletingFile {
DeletingFile(const DeletingFile&) = delete; DeletingFile(const DeletingFile&) = delete;
DeletingFile& operator=(const DeletingFile&) = delete; DeletingFile& operator=(const DeletingFile&) = delete;
// DeletingFile is movable. The moved-from object remains in valid but
// unspecified state and will not perform any operations on destruction.
DeletingFile(DeletingFile&& other);
DeletingFile& operator=(DeletingFile&& other);
// Provide the path to the file and whether the file should be deleted // Provide the path to the file and whether the file should be deleted
// when this object is destroyed. // when this object is destroyed.
DeletingFile(const std::string& path, bool delete_on_destruction); DeletingFile(const std::string& path, bool delete_on_destruction);

View File

@ -16,11 +16,11 @@ limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_ #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_ #define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
#include <cstdint>
#include <vector> #include <vector>
#include "absl/log/absl_check.h" #include "absl/log/absl_check.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe { namespace mediapipe {
@ -76,13 +76,13 @@ class FrameBuffer {
// Plane encapsulates buffer and stride information. // Plane encapsulates buffer and stride information.
struct Plane { struct Plane {
Plane(uint8_t* buffer, Stride stride) : buffer_(buffer), stride_(stride) {} Plane(uint8* buffer, Stride stride) : buffer_(buffer), stride_(stride) {}
const uint8_t* buffer() const { return buffer_; } const uint8* buffer() const { return buffer_; }
uint8_t* mutable_buffer() { return buffer_; } uint8* mutable_buffer() { return buffer_; }
Stride stride() const { return stride_; } Stride stride() const { return stride_; }
private: private:
uint8_t* buffer_; uint8* buffer_;
Stride stride_; Stride stride_;
}; };
@ -121,9 +121,9 @@ class FrameBuffer {
// YUV data structure. // YUV data structure.
struct YuvData { struct YuvData {
const uint8_t* y_buffer; const uint8* y_buffer;
const uint8_t* u_buffer; const uint8* u_buffer;
const uint8_t* v_buffer; const uint8* v_buffer;
// Y buffer row stride in bytes. // Y buffer row stride in bytes.
int y_row_stride; int y_row_stride;
// U/V buffer row stride in bytes. // U/V buffer row stride in bytes.

View File

@ -1,167 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_HARDWARE_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_HARDWARE_BUFFER_H_
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
typedef struct AHardwareBuffer AHardwareBuffer;
namespace mediapipe {
struct HardwareBufferSpec {
// Buffer pixel formats. See NDK's hardware_buffer.h for descriptions.
enum {
// This must be kept in sync with NDK's hardware_buffer.h
AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM = 0x01,
AHARDWAREBUFFER_FORMAT_R8G8B8_UNORM = 0x03,
AHARDWAREBUFFER_FORMAT_R16G16B16A16_FLOAT = 0x16,
AHARDWAREBUFFER_FORMAT_BLOB = 0x21,
AHARDWAREBUFFER_FORMAT_R8_UNORM = 0x38,
};
// Buffer usage descriptions. See NDK's hardware_buffer.h for descriptions.
enum {
// This must be kept in sync with NDK's hardware_buffer.h
AHARDWAREBUFFER_USAGE_CPU_READ_NEVER = 0x0UL,
AHARDWAREBUFFER_USAGE_CPU_READ_RARELY = 0x2UL,
AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN = 0x3UL,
AHARDWAREBUFFER_USAGE_CPU_WRITE_NEVER = UINT64_C(0) << 4,
AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY = UINT64_C(2) << 4,
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN = UINT64_C(3) << 4,
AHARDWAREBUFFER_USAGE_GPU_SAMPLED_IMAGE = UINT64_C(1) << 8,
AHARDWAREBUFFER_USAGE_GPU_FRAMEBUFFER = UINT64_C(1) << 9,
AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER = UINT64_C(1) << 24,
};
// Hashing required to use HardwareBufferSpec as key in buffer pools. See
// absl::Hash for details.
template <typename H>
friend H AbslHashValue(H h, const HardwareBufferSpec& spec) {
return H::combine(std::move(h), spec.width, spec.height, spec.layers,
spec.format, spec.usage);
}
uint32_t width = 0;
uint32_t height = 0;
uint32_t layers = 0;
uint32_t format = 0;
uint64_t usage = 0;
};
// Equality operators
inline bool operator==(const HardwareBufferSpec& lhs,
const HardwareBufferSpec& rhs) {
return lhs.width == rhs.width && lhs.height == rhs.height &&
lhs.layers == rhs.layers && lhs.format == rhs.format &&
lhs.usage == rhs.usage;
}
inline bool operator!=(const HardwareBufferSpec& lhs,
const HardwareBufferSpec& rhs) {
return !operator==(lhs, rhs);
}
// For internal use only. Thinly wraps the Android NDK AHardwareBuffer.
class HardwareBuffer {
public:
// Constructs a HardwareBuffer instance from a newly allocated Android NDK
// AHardwareBuffer.
static absl::StatusOr<HardwareBuffer> Create(const HardwareBufferSpec& spec);
// Destructs the HardwareBuffer, releasing the AHardwareBuffer.
~HardwareBuffer();
// Support HardwareBuffer moves.
HardwareBuffer(HardwareBuffer&& other);
// Delete assignment and copy constructors.
HardwareBuffer(HardwareBuffer& other) = delete;
HardwareBuffer(const HardwareBuffer& other) = delete;
HardwareBuffer& operator=(const HardwareBuffer&) = delete;
// Returns true if AHWB is supported.
static bool IsSupported();
// Lock the hardware buffer for the given usage flags. fence_file_descriptor
// specifies a fence file descriptor on which to wait before locking the
// buffer. Returns raw memory address if lock is successful, nullptr
// otherwise.
ABSL_MUST_USE_RESULT absl::StatusOr<void*> Lock(
uint64_t usage, std::optional<int> fence_file_descriptor = std::nullopt);
// Unlocks the hardware buffer synchronously. This method blocks until
// unlocking is complete.
absl::Status Unlock();
// Unlocks the hardware buffer asynchronously. It returns a file_descriptor
// which can be used as a fence that is signaled once unlocking is complete.
absl::StatusOr<int> UnlockAsync();
// Returns the underlying raw AHardwareBuffer pointer to be used directly with
// AHardwareBuffer APIs.
AHardwareBuffer* GetAHardwareBuffer() const { return ahw_buffer_; }
// Returns whether this HardwareBuffer contains a valid AHardwareBuffer.
bool IsValid() const { return ahw_buffer_ != nullptr; }
// Returns whether this HardwareBuffer is locked.
bool IsLocked() const { return is_locked_; }
// Releases the AHardwareBuffer.
void Reset();
// Ahwb's are aligned to an implementation specific cacheline size.
uint32_t GetAlignedWidth() const;
// Returns buffer spec.
const HardwareBufferSpec& spec() const { return spec_; }
private:
// Allocates an AHardwareBuffer instance;
static absl::StatusOr<AHardwareBuffer*> AllocateAHardwareBuffer(
const HardwareBufferSpec& spec);
// Constructs a HardwareBuffer instance from an already aquired
// AHardwareBuffer instance and its spec.
HardwareBuffer(const HardwareBufferSpec& spec, AHardwareBuffer* ahwb);
// Unlocks the hardware buffer. If fence_file_descriptor_ptr is not nullptr,
// the function won't block and instead fence_file_descriptor_ptr will be set
// to a file descriptor to become signaled once unlocking is complete.
absl::Status UnlockInternal(int* fence_file_descriptor_ptr);
// Releases ahw_buffer_ AHardwareBuffer instance;
absl::Status ReleaseAHardwareBuffer();
// Buffer spec.
HardwareBufferSpec spec_ = {};
// Android NDK AHardwareBuffer.
AHardwareBuffer* ahw_buffer_ = nullptr;
// Indicates if AHardwareBuffer is locked for reading or writing.
bool is_locked_ = false;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_AHWB_BUFFER_H_

View File

@ -1,152 +0,0 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if !defined(MEDIAPIPE_NO_JNI) && \
(__ANDROID_API__ >= 26 || \
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>
#include <memory>
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/hardware_buffer.h"
#include "mediapipe/framework/port/ret_check.h"
namespace mediapipe {
HardwareBuffer::HardwareBuffer(HardwareBuffer &&other) {
spec_ = std::exchange(other.spec_, {});
ahw_buffer_ = std::exchange(other.ahw_buffer_, nullptr);
is_locked_ = std::exchange(other.is_locked_, false);
}
HardwareBuffer::HardwareBuffer(const HardwareBufferSpec &spec,
AHardwareBuffer *ahwb)
: spec_(spec), ahw_buffer_(ahwb), is_locked_(false) {}
HardwareBuffer::~HardwareBuffer() { Reset(); }
absl::StatusOr<HardwareBuffer> HardwareBuffer::Create(
const HardwareBufferSpec &spec) {
MP_ASSIGN_OR_RETURN(AHardwareBuffer * ahwb, AllocateAHardwareBuffer(spec));
return HardwareBuffer(spec, ahwb);
}
bool HardwareBuffer::IsSupported() {
if (__builtin_available(android 26, *)) {
return true;
}
return false;
}
absl::StatusOr<AHardwareBuffer *> HardwareBuffer::AllocateAHardwareBuffer(
const HardwareBufferSpec &spec) {
RET_CHECK(IsSupported()) << "AndroidHWBuffers not supported";
AHardwareBuffer *output = nullptr;
int error = 0;
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc = {
.width = spec.width,
.height = spec.height,
.layers = spec.layers,
.format = spec.format,
.usage = spec.usage,
};
error = AHardwareBuffer_allocate(&desc, &output);
}
RET_CHECK(!error && output != nullptr) << "AHardwareBuffer_allocate failed";
return output;
}
absl::Status HardwareBuffer::ReleaseAHardwareBuffer() {
if (ahw_buffer_ == nullptr) {
return absl::OkStatus();
}
if (is_locked_) {
MP_RETURN_IF_ERROR(Unlock());
}
if (__builtin_available(android 26, *)) {
AHardwareBuffer_release(ahw_buffer_);
}
spec_ = {};
ahw_buffer_ = nullptr;
return absl::OkStatus();
}
absl::StatusOr<void *> HardwareBuffer::Lock(
uint64_t usage, std::optional<int> fence_file_descriptor) {
RET_CHECK(ahw_buffer_ != nullptr) << "Hardware Buffer not allocated";
RET_CHECK(!is_locked_) << "Hardware Buffer already locked";
void *mem = nullptr;
if (__builtin_available(android 26, *)) {
const int error = AHardwareBuffer_lock(
ahw_buffer_, usage,
fence_file_descriptor.has_value() ? *fence_file_descriptor : -1,
nullptr, &mem);
RET_CHECK(error == 0) << "Hardware Buffer lock failed. Error: " << error;
}
is_locked_ = true;
return mem;
}
absl::Status HardwareBuffer::Unlock() {
return UnlockInternal(/*fence_file_descriptor=*/nullptr);
}
absl::StatusOr<int> HardwareBuffer::UnlockAsync() {
int fence_file_descriptor = -1;
MP_RETURN_IF_ERROR(UnlockInternal(&fence_file_descriptor));
return fence_file_descriptor;
}
absl::Status HardwareBuffer::UnlockInternal(int *fence_file_descriptor) {
RET_CHECK(ahw_buffer_ != nullptr) << "Hardware Buffer not allocated";
if (!is_locked_) {
return absl::OkStatus();
}
if (__builtin_available(android 26, *)) {
const int error =
AHardwareBuffer_unlock(ahw_buffer_, fence_file_descriptor);
RET_CHECK(error == 0) << "Hardware Buffer unlock failed. error: " << error;
}
is_locked_ = false;
return absl::OkStatus();
}
uint32_t HardwareBuffer::GetAlignedWidth() const {
if (__builtin_available(android 26, *)) {
ABSL_CHECK(ahw_buffer_ != nullptr) << "Hardware Buffer not allocated";
AHardwareBuffer_Desc desc = {};
AHardwareBuffer_describe(ahw_buffer_, &desc);
ABSL_CHECK_GT(desc.stride, 0);
return desc.stride;
}
return 0;
}
void HardwareBuffer::Reset() {
const auto success = ReleaseAHardwareBuffer();
if (!success.ok()) {
ABSL_LOG(DFATAL) << "Failed to release AHardwareBuffer: " << success;
}
}
} // namespace mediapipe
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__>= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -1,131 +0,0 @@
#include "mediapipe/framework/formats/hardware_buffer.h"
#include <android/hardware_buffer.h>
#include <memory>
#include "base/logging.h"
#include "mediapipe/framework/port/status_macros.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
namespace {
HardwareBufferSpec GetTestHardwareBufferSpec(uint32_t size_bytes) {
return {.width = size_bytes,
.height = 1,
.layers = 1,
.format = HardwareBufferSpec::AHARDWAREBUFFER_FORMAT_BLOB,
.usage = HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY |
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN |
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER};
}
TEST(HardwareBufferTest, ShouldConstructValidAHardwareBuffer) {
MP_ASSERT_OK_AND_ASSIGN(
HardwareBuffer hardware_buffer,
HardwareBuffer::Create(GetTestHardwareBufferSpec(/*size_bytes=*/123)));
EXPECT_NE(hardware_buffer.GetAHardwareBuffer(), nullptr);
EXPECT_TRUE(hardware_buffer.IsValid());
}
TEST(HardwareBufferTest, ShouldResetValidAHardwareBuffer) {
MP_ASSERT_OK_AND_ASSIGN(
HardwareBuffer hardware_buffer,
HardwareBuffer::Create(GetTestHardwareBufferSpec(/*size_bytes=*/123)));
EXPECT_TRUE(hardware_buffer.IsValid());
EXPECT_NE(*hardware_buffer.Lock(
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY),
nullptr);
EXPECT_TRUE(hardware_buffer.IsLocked());
hardware_buffer.Reset();
EXPECT_FALSE(hardware_buffer.IsValid());
EXPECT_FALSE(hardware_buffer.IsLocked());
}
TEST(HardwareBufferTest, ShouldAllocateRequestedBufferSize) {
constexpr int kBufferSize = 123;
const HardwareBufferSpec spec = GetTestHardwareBufferSpec(kBufferSize);
MP_ASSERT_OK_AND_ASSIGN(HardwareBuffer hardware_buffer,
HardwareBuffer::Create(spec));
EXPECT_TRUE(hardware_buffer.IsValid());
if (__builtin_available(android 26, *)) {
AHardwareBuffer_Desc desc;
AHardwareBuffer_describe(hardware_buffer.GetAHardwareBuffer(), &desc);
EXPECT_EQ(desc.width, spec.width);
EXPECT_EQ(desc.height, spec.height);
EXPECT_EQ(desc.layers, spec.layers);
EXPECT_EQ(desc.format, spec.format);
EXPECT_EQ(desc.usage, spec.usage);
}
EXPECT_EQ(hardware_buffer.spec().width, spec.width);
EXPECT_EQ(hardware_buffer.spec().height, spec.height);
EXPECT_EQ(hardware_buffer.spec().layers, spec.layers);
EXPECT_EQ(hardware_buffer.spec().format, spec.format);
EXPECT_EQ(hardware_buffer.spec().usage, spec.usage);
}
TEST(HardwareBufferTest, ShouldSupportMoveConstructor) {
constexpr int kBufferSize = 123;
const auto spec = GetTestHardwareBufferSpec(kBufferSize);
MP_ASSERT_OK_AND_ASSIGN(HardwareBuffer hardware_buffer_a,
HardwareBuffer::Create(spec));
EXPECT_TRUE(hardware_buffer_a.IsValid());
void* const ahardware_buffer_ptr_a = hardware_buffer_a.GetAHardwareBuffer();
EXPECT_NE(ahardware_buffer_ptr_a, nullptr);
EXPECT_FALSE(hardware_buffer_a.IsLocked());
MP_ASSERT_OK_AND_ASSIGN(
void* const hardware_buffer_a_locked_ptr,
hardware_buffer_a.Lock(
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY));
EXPECT_NE(hardware_buffer_a_locked_ptr, nullptr);
EXPECT_TRUE(hardware_buffer_a.IsLocked());
HardwareBuffer hardware_buffer_b(std::move(hardware_buffer_a));
EXPECT_FALSE(hardware_buffer_a.IsValid());
EXPECT_FALSE(hardware_buffer_a.IsLocked());
void* const ahardware_buffer_ptr_b = hardware_buffer_b.GetAHardwareBuffer();
EXPECT_EQ(ahardware_buffer_ptr_a, ahardware_buffer_ptr_b);
EXPECT_TRUE(hardware_buffer_b.IsValid());
EXPECT_TRUE(hardware_buffer_b.IsLocked());
EXPECT_EQ(hardware_buffer_a.spec(), HardwareBufferSpec());
EXPECT_EQ(hardware_buffer_b.spec(), spec);
MP_ASSERT_OK(hardware_buffer_b.Unlock());
}
TEST(HardwareBufferTest, ShouldSupportReadWrite) {
constexpr std::string_view kTestString = "TestString";
constexpr int kBufferSize = kTestString.size();
MP_ASSERT_OK_AND_ASSIGN(
HardwareBuffer hardware_buffer,
HardwareBuffer::Create(GetTestHardwareBufferSpec(kBufferSize)));
// Write test string.
MP_ASSERT_OK_AND_ASSIGN(
void* const write_ptr,
hardware_buffer.Lock(
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY));
memcpy(write_ptr, kTestString.data(), kBufferSize);
MP_ASSERT_OK(hardware_buffer.Unlock());
// Read test string.
MP_ASSERT_OK_AND_ASSIGN(
void* const read_ptr,
hardware_buffer.Lock(
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_READ_RARELY));
EXPECT_EQ(memcmp(read_ptr, kTestString.data(), kBufferSize), 0);
MP_ASSERT_OK(hardware_buffer.Unlock());
}
} // namespace
} // namespace mediapipe

View File

@ -15,7 +15,6 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_ #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_ #define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_
#include <cstdint>
#include <utility> #include <utility>
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
@ -208,7 +207,7 @@ inline void Image::UnlockPixels() const {}
// Image buf = ... // Image buf = ...
// { // {
// PixelLock lock(&buf); // PixelLock lock(&buf);
// uint8_t* buf_ptr = lock.Pixels(); // uint8* buf_ptr = lock.Pixels();
// ... use buf_ptr to access pixel data ... // ... use buf_ptr to access pixel data ...
// ... lock released automatically at end of scope ... // ... lock released automatically at end of scope ...
// } // }
@ -229,7 +228,7 @@ class PixelReadLock {
} }
PixelReadLock(const PixelReadLock&) = delete; PixelReadLock(const PixelReadLock&) = delete;
const uint8_t* Pixels() const { const uint8* Pixels() const {
if (frame_) return frame_->PixelData(); if (frame_) return frame_->PixelData();
return nullptr; return nullptr;
} }
@ -255,7 +254,7 @@ class PixelWriteLock {
} }
PixelWriteLock(const PixelWriteLock&) = delete; PixelWriteLock(const PixelWriteLock&) = delete;
uint8_t* Pixels() { uint8* Pixels() {
if (frame_) return frame_->MutablePixelData(); if (frame_) return frame_->MutablePixelData();
return nullptr; return nullptr;
} }

View File

@ -35,13 +35,13 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_FRAME_H_ #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_FRAME_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_FRAME_H_ #define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_FRAME_H_
#include <cstdint>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/tool/type_util.h" #include "mediapipe/framework/tool/type_util.h"
#define IMAGE_FRAME_RAW_IMAGE MEDIAPIPE_HAS_RTTI #define IMAGE_FRAME_RAW_IMAGE MEDIAPIPE_HAS_RTTI
@ -63,7 +63,7 @@ namespace mediapipe {
// stored with row padding for alignment purposes. // stored with row padding for alignment purposes.
class ImageFrame { class ImageFrame {
public: public:
typedef std::function<void(uint8_t*)> Deleter; typedef std::function<void(uint8*)> Deleter;
// This class offers a few standard delete functions and retains // This class offers a few standard delete functions and retains
// compatibility with the previous API. // compatibility with the previous API.
@ -78,13 +78,13 @@ class ImageFrame {
// Use a default alignment boundary of 16 because Intel SSE2 instructions may // Use a default alignment boundary of 16 because Intel SSE2 instructions may
// incur performance penalty when accessing data not aligned on a 16-byte // incur performance penalty when accessing data not aligned on a 16-byte
// boundary. FFmpeg requires at least this level of alignment. // boundary. FFmpeg requires at least this level of alignment.
static const uint32_t kDefaultAlignmentBoundary = 16; static const uint32 kDefaultAlignmentBoundary = 16;
// If the pixel data of an ImageFrame will be passed to an OpenGL function // If the pixel data of an ImageFrame will be passed to an OpenGL function
// such as glTexImage2D() or glReadPixels(), use a four-byte alignment // such as glTexImage2D() or glReadPixels(), use a four-byte alignment
// boundary because that is the initial value of the OpenGL GL_PACK_ALIGNMENT // boundary because that is the initial value of the OpenGL GL_PACK_ALIGNMENT
// and GL_UNPACK_ALIGNMENT parameters. // and GL_UNPACK_ALIGNMENT parameters.
static const uint32_t kGlDefaultAlignmentBoundary = 4; static const uint32 kGlDefaultAlignmentBoundary = 4;
// Returns number of channels for an ImageFormat. // Returns number of channels for an ImageFormat.
static int NumberOfChannelsForFormat(ImageFormat::Format format); static int NumberOfChannelsForFormat(ImageFormat::Format format);
@ -104,7 +104,7 @@ class ImageFrame {
// must be a power of 2 (the number 1 is valid, and means the data will // must be a power of 2 (the number 1 is valid, and means the data will
// be stored contiguously). // be stored contiguously).
ImageFrame(ImageFormat::Format format, int width, int height, ImageFrame(ImageFormat::Format format, int width, int height,
uint32_t alignment_boundary); uint32 alignment_boundary);
// Same as above, but use kDefaultAlignmentBoundary for alignment_boundary. // Same as above, but use kDefaultAlignmentBoundary for alignment_boundary.
ImageFrame(ImageFormat::Format format, int width, int height); ImageFrame(ImageFormat::Format format, int width, int height);
@ -115,8 +115,8 @@ class ImageFrame {
// width*num_channels*depth. Both width_step and depth are in units // width*num_channels*depth. Both width_step and depth are in units
// of bytes. // of bytes.
ImageFrame(ImageFormat::Format format, int width, int height, int width_step, ImageFrame(ImageFormat::Format format, int width, int height, int width_step,
uint8_t* pixel_data, uint8* pixel_data,
Deleter deleter = std::default_delete<uint8_t[]>()); Deleter deleter = std::default_delete<uint8[]>());
ImageFrame(ImageFrame&& move_from); ImageFrame(ImageFrame&& move_from);
ImageFrame& operator=(ImageFrame&& move_from); ImageFrame& operator=(ImageFrame&& move_from);
@ -142,7 +142,7 @@ class ImageFrame {
// alignment_boundary. If IsAligned(16) is true then so are // alignment_boundary. If IsAligned(16) is true then so are
// IsAligned(8), IsAligned(4), IsAligned(2), and IsAligned(1). // IsAligned(8), IsAligned(4), IsAligned(2), and IsAligned(1).
// alignment_boundary must be 1 or a power of 2. // alignment_boundary must be 1 or a power of 2.
bool IsAligned(uint32_t alignment_boundary) const; bool IsAligned(uint32 alignment_boundary) const;
// Returns the image / video format. // Returns the image / video format.
ImageFormat::Format Format() const { return format_; } ImageFormat::Format Format() const { return format_; }
@ -167,13 +167,13 @@ class ImageFrame {
// Reset the current image frame and copy the data from image_frame into // Reset the current image frame and copy the data from image_frame into
// this image frame. The alignment_boundary must be given (and won't // this image frame. The alignment_boundary must be given (and won't
// necessarily match the alignment_boundary of the input image_frame). // necessarily match the alignment_boundary of the input image_frame).
void CopyFrom(const ImageFrame& image_frame, uint32_t alignment_boundary); void CopyFrom(const ImageFrame& image_frame, uint32 alignment_boundary);
// Get a mutable pointer to the underlying image data. The ImageFrame // Get a mutable pointer to the underlying image data. The ImageFrame
// retains ownership. // retains ownership.
uint8_t* MutablePixelData() { return pixel_data_.get(); } uint8* MutablePixelData() { return pixel_data_.get(); }
// Get a const pointer to the underlying image data. // Get a const pointer to the underlying image data.
const uint8_t* PixelData() const { return pixel_data_.get(); } const uint8* PixelData() const { return pixel_data_.get(); }
// Returns the total size of the pixel data. // Returns the total size of the pixel data.
int PixelDataSize() const { return Height() * WidthStep(); } int PixelDataSize() const { return Height() * WidthStep(); }
@ -187,41 +187,41 @@ class ImageFrame {
// ImageFrame takes ownership of pixel_data. See the Constructor // ImageFrame takes ownership of pixel_data. See the Constructor
// with the same arguments for details. // with the same arguments for details.
void AdoptPixelData(ImageFormat::Format format, int width, int height, void AdoptPixelData(ImageFormat::Format format, int width, int height,
int width_step, uint8_t* pixel_data, int width_step, uint8* pixel_data,
Deleter deleter = std::default_delete<uint8_t[]>()); Deleter deleter = std::default_delete<uint8[]>());
// Resets the ImageFrame and makes it a copy of the provided pixel // Resets the ImageFrame and makes it a copy of the provided pixel
// data, which is assumed to be stored contiguously. The ImageFrame // data, which is assumed to be stored contiguously. The ImageFrame
// will use the given alignment_boundary. // will use the given alignment_boundary.
void CopyPixelData(ImageFormat::Format format, int width, int height, void CopyPixelData(ImageFormat::Format format, int width, int height,
const uint8_t* pixel_data, uint32_t alignment_boundary); const uint8* pixel_data, uint32 alignment_boundary);
// Resets the ImageFrame and makes it a copy of the provided pixel // Resets the ImageFrame and makes it a copy of the provided pixel
// data, with given width_step. The ImageFrame // data, with given width_step. The ImageFrame
// will use the given alignment_boundary. // will use the given alignment_boundary.
void CopyPixelData(ImageFormat::Format format, int width, int height, void CopyPixelData(ImageFormat::Format format, int width, int height,
int width_step, const uint8_t* pixel_data, int width_step, const uint8* pixel_data,
uint32_t alignment_boundary); uint32 alignment_boundary);
// Allocates a frame of the specified format, width, height, and alignment, // Allocates a frame of the specified format, width, height, and alignment,
// without clearing any current pixel data. See the constructor with the same // without clearing any current pixel data. See the constructor with the same
// argument list. // argument list.
void Reset(ImageFormat::Format format, int width, int height, void Reset(ImageFormat::Format format, int width, int height,
uint32_t alignment_boundary); uint32 alignment_boundary);
// Relinquishes ownership of the pixel data. Notice that the unique_ptr // Relinquishes ownership of the pixel data. Notice that the unique_ptr
// uses a non-standard deleter. // uses a non-standard deleter.
std::unique_ptr<uint8_t[], Deleter> Release(); std::unique_ptr<uint8[], Deleter> Release();
// Copy the 8-bit ImageFrame into a contiguous, pre-allocated buffer. Note // Copy the 8-bit ImageFrame into a contiguous, pre-allocated buffer. Note
// that ImageFrame does not necessarily store its data contiguously (i.e. do // that ImageFrame does not necessarily store its data contiguously (i.e. do
// not use copy_n to move image data). // not use copy_n to move image data).
void CopyToBuffer(uint8_t* buffer, int buffer_size) const; void CopyToBuffer(uint8* buffer, int buffer_size) const;
// A version of CopyToBuffer for 16-bit pixel data. Note that buffer_size // A version of CopyToBuffer for 16-bit pixel data. Note that buffer_size
// stores the number of 16-bit elements in the buffer, not the number of // stores the number of 16-bit elements in the buffer, not the number of
// bytes. // bytes.
void CopyToBuffer(uint16_t* buffer, int buffer_size) const; void CopyToBuffer(uint16* buffer, int buffer_size) const;
// A version of CopyToBuffer for float pixel data. Note that buffer_size // A version of CopyToBuffer for float pixel data. Note that buffer_size
// stores the number of float elements in the buffer, not the number of // stores the number of float elements in the buffer, not the number of
@ -233,12 +233,12 @@ class ImageFrame {
private: private:
// Returns true if alignment_number is 1 or a power of 2. // Returns true if alignment_number is 1 or a power of 2.
static bool IsValidAlignmentNumber(uint32_t alignment_boundary); static bool IsValidAlignmentNumber(uint32 alignment_boundary);
// The internal implementation of copying data from the provided pixel data. // The internal implementation of copying data from the provided pixel data.
// If width_step is 0, then calculates width_step assuming no padding. // If width_step is 0, then calculates width_step assuming no padding.
void InternalCopyFrom(int width, int height, int width_step, int channel_size, void InternalCopyFrom(int width, int height, int width_step, int channel_size,
const uint8_t* pixel_data); const uint8* pixel_data);
// The internal implementation of copying data to the provided buffer. // The internal implementation of copying data to the provided buffer.
// If width_step is 0, then calculates width_step assuming no padding. // If width_step is 0, then calculates width_step assuming no padding.
@ -249,7 +249,7 @@ class ImageFrame {
int height_; int height_;
int width_step_; int width_step_;
std::unique_ptr<uint8_t[], Deleter> pixel_data_; std::unique_ptr<uint8[], Deleter> pixel_data_;
}; };
} // namespace mediapipe } // namespace mediapipe

View File

@ -16,7 +16,6 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <cstdint>
#include <memory> #include <memory>
#include "absl/log/absl_check.h" #include "absl/log/absl_check.h"

View File

@ -215,7 +215,7 @@ Location CreateCvMaskLocation(const cv::Mat_<T>& mask) {
return Location(location_data); return Location(location_data);
} }
template Location CreateCvMaskLocation(const cv::Mat_<uint8_t>& mask); template Location CreateCvMaskLocation(const cv::Mat_<uint8>& mask);
template Location CreateCvMaskLocation(const cv::Mat_<float>& mask); template Location CreateCvMaskLocation(const cv::Mat_<float>& mask);
} // namespace mediapipe } // namespace mediapipe

View File

@ -24,9 +24,6 @@
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include "mediapipe/framework/formats/hardware_buffer.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
#import <Metal/Metal.h> #import <Metal/Metal.h>
@ -342,14 +339,6 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
<< "Tensor conversion between different GPU backing formats is not " << "Tensor conversion between different GPU backing formats is not "
"supported yet."; "supported yet.";
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_)); auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
if ((valid_ & kValidOpenGlBuffer) && gl_context_ != nullptr &&
!gl_context_->IsCurrent() && GlContext::IsAnyContextCurrent()) {
ABSL_LOG_FIRST_N(WARNING, 1)
<< "Tensor::GetOpenGlBufferReadView is not executed on the same GL "
"context where GL buffer was created. Note that Tensor has "
"limited synchronization support when sharing OpenGl objects "
"between multiple OpenGL contexts.";
}
AllocateOpenGlBuffer(); AllocateOpenGlBuffer();
if (!(valid_ & kValidOpenGlBuffer)) { if (!(valid_ & kValidOpenGlBuffer)) {
// If the call succeeds then AHWB -> SSBO are synchronized so any usage of // If the call succeeds then AHWB -> SSBO are synchronized so any usage of
@ -367,13 +356,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
} }
return {opengl_buffer_, std::move(lock), return {opengl_buffer_, std::move(lock),
#ifdef MEDIAPIPE_TENSOR_USE_AHWB #ifdef MEDIAPIPE_TENSOR_USE_AHWB
// ssbo_read_ is passed to be populated on OpenGlBufferView &ssbo_read_
// destruction in order to perform delayed resources releasing (see
// tensor_ahwb.cc/DelayedReleaser) only when AHWB is in use.
//
// Not passing for the case when AHWB is not in use to avoid creation
// of unnecessary sync object and memory leak.
use_ahwb_ ? &ssbo_read_ : nullptr
#else #else
nullptr nullptr
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
@ -384,14 +367,6 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView(
uint64_t source_location_hash) const { uint64_t source_location_hash) const {
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_)); auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
TrackAhwbUsage(source_location_hash); TrackAhwbUsage(source_location_hash);
if ((valid_ & kValidOpenGlBuffer) && gl_context_ != nullptr &&
!gl_context_->IsCurrent() && GlContext::IsAnyContextCurrent()) {
ABSL_LOG_FIRST_N(WARNING, 1)
<< "Tensor::GetOpenGlBufferWriteView is not executed on the same GL "
"context where GL buffer was created. Note that Tensor has "
"limited synchronization support when sharing OpenGl objects "
"between multiple OpenGL contexts.";
}
AllocateOpenGlBuffer(); AllocateOpenGlBuffer();
valid_ = kValidOpenGlBuffer; valid_ = kValidOpenGlBuffer;
return {opengl_buffer_, std::move(lock), nullptr}; return {opengl_buffer_, std::move(lock), nullptr};
@ -561,8 +536,9 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
void* ptr = MapAhwbToCpuRead(); void* ptr = MapAhwbToCpuRead();
if (ptr) { if (ptr) {
valid_ |= kValidCpu; valid_ |= kValidCpu;
return {ptr, std::move(lock), [ahwb = ahwb_.get()] { return {ptr, std::move(lock), [ahwb = ahwb_] {
ABSL_CHECK_OK(ahwb->Unlock()) << "Unlock failed."; auto error = AHardwareBuffer_unlock(ahwb, nullptr);
ABSL_CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
}}; }};
} }
} }
@ -644,11 +620,9 @@ Tensor::CpuWriteView Tensor::GetCpuWriteView(
if (__builtin_available(android 26, *)) { if (__builtin_available(android 26, *)) {
void* ptr = MapAhwbToCpuWrite(); void* ptr = MapAhwbToCpuWrite();
if (ptr) { if (ptr) {
return {ptr, std::move(lock), return {ptr, std::move(lock), [ahwb = ahwb_, fence_fd = &fence_fd_] {
[ahwb = ahwb_.get(), fence_fd = &fence_fd_] { auto error = AHardwareBuffer_unlock(ahwb, fence_fd);
auto fence_fd_status = ahwb->UnlockAsync(); ABSL_CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
ABSL_CHECK_OK(fence_fd_status) << "Unlock failed.";
*fence_fd = fence_fd_status.value();
}}; }};
} }
} }

View File

@ -44,8 +44,7 @@
#ifdef MEDIAPIPE_TENSOR_USE_AHWB #ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <EGL/egl.h> #include <EGL/egl.h>
#include <EGL/eglext.h> #include <EGL/eglext.h>
#include <android/hardware_buffer.h>
#include "mediapipe/framework/formats/hardware_buffer.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
@ -196,11 +195,9 @@ class Tensor {
using FinishingFunc = std::function<bool(bool)>; using FinishingFunc = std::function<bool(bool)>;
class AHardwareBufferView : public View { class AHardwareBufferView : public View {
public: public:
AHardwareBuffer* handle() const { AHardwareBuffer* handle() const { return handle_; }
return hardware_buffer_->GetAHardwareBuffer();
}
AHardwareBufferView(AHardwareBufferView&& src) : View(std::move(src)) { AHardwareBufferView(AHardwareBufferView&& src) : View(std::move(src)) {
hardware_buffer_ = std::move(src.hardware_buffer_); handle_ = std::exchange(src.handle_, nullptr);
file_descriptor_ = src.file_descriptor_; file_descriptor_ = src.file_descriptor_;
fence_fd_ = std::exchange(src.fence_fd_, nullptr); fence_fd_ = std::exchange(src.fence_fd_, nullptr);
ahwb_written_ = std::exchange(src.ahwb_written_, nullptr); ahwb_written_ = std::exchange(src.ahwb_written_, nullptr);
@ -225,17 +222,17 @@ class Tensor {
protected: protected:
friend class Tensor; friend class Tensor;
AHardwareBufferView(HardwareBuffer* hardware_buffer, int file_descriptor, AHardwareBufferView(AHardwareBuffer* handle, int file_descriptor,
int* fence_fd, FinishingFunc* ahwb_written, int* fence_fd, FinishingFunc* ahwb_written,
std::function<void()>* release_callback, std::function<void()>* release_callback,
std::unique_ptr<absl::MutexLock>&& lock) std::unique_ptr<absl::MutexLock>&& lock)
: View(std::move(lock)), : View(std::move(lock)),
hardware_buffer_(hardware_buffer), handle_(handle),
file_descriptor_(file_descriptor), file_descriptor_(file_descriptor),
fence_fd_(fence_fd), fence_fd_(fence_fd),
ahwb_written_(ahwb_written), ahwb_written_(ahwb_written),
release_callback_(release_callback) {} release_callback_(release_callback) {}
HardwareBuffer* hardware_buffer_; AHardwareBuffer* handle_;
int file_descriptor_; int file_descriptor_;
// The view sets some Tensor's fields. The view is released prior to tensor. // The view sets some Tensor's fields. The view is released prior to tensor.
int* fence_fd_; int* fence_fd_;
@ -288,22 +285,18 @@ class Tensor {
class OpenGlBufferView : public View { class OpenGlBufferView : public View {
public: public:
GLuint name() const { return name_; } GLuint name() const { return name_; }
OpenGlBufferView(OpenGlBufferView&& src) : View(std::move(src)) { OpenGlBufferView(OpenGlBufferView&& src) : View(std::move(src)) {
name_ = std::exchange(src.name_, GL_INVALID_INDEX); name_ = std::exchange(src.name_, GL_INVALID_INDEX);
ssbo_read_ = std::exchange(src.ssbo_read_, nullptr); ssbo_read_ = std::exchange(src.ssbo_read_, nullptr);
} }
~OpenGlBufferView() { ~OpenGlBufferView() {
if (ssbo_read_) { if (ssbo_read_) {
// TODO: update tensor to properly handle cases when
// multiple views were requested multiple sync fence may be needed.
*ssbo_read_ = glFenceSync(GL_SYNC_GPU_COMMANDS_COMPLETE, 0); *ssbo_read_ = glFenceSync(GL_SYNC_GPU_COMMANDS_COMPLETE, 0);
} }
} }
protected: protected:
friend class Tensor; friend class Tensor;
OpenGlBufferView(GLuint name, std::unique_ptr<absl::MutexLock>&& lock, OpenGlBufferView(GLuint name, std::unique_ptr<absl::MutexLock>&& lock,
GLsync* ssbo_read) GLsync* ssbo_read)
: View(std::move(lock)), name_(name), ssbo_read_(ssbo_read) {} : View(std::move(lock)), name_(name), ssbo_read_(ssbo_read) {}
@ -391,7 +384,7 @@ class Tensor {
mutable std::unique_ptr<MtlResources> mtl_resources_; mutable std::unique_ptr<MtlResources> mtl_resources_;
#ifdef MEDIAPIPE_TENSOR_USE_AHWB #ifdef MEDIAPIPE_TENSOR_USE_AHWB
mutable std::unique_ptr<HardwareBuffer> ahwb_; mutable AHardwareBuffer* ahwb_ = nullptr;
// Signals when GPU finished writing into SSBO so AHWB can be used then. Or // Signals when GPU finished writing into SSBO so AHWB can be used then. Or
// signals when writing into AHWB has been finished so GPU can read from SSBO. // signals when writing into AHWB has been finished so GPU can read from SSBO.
// Sync and FD are bound together. // Sync and FD are bound together.

View File

@ -10,7 +10,7 @@
#include "absl/log/absl_check.h" #include "absl/log/absl_check.h"
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/formats/hardware_buffer.h" #include "mediapipe/framework/port.h"
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
@ -97,7 +97,7 @@ class DelayedReleaser {
DelayedReleaser(DelayedReleaser&&) = delete; DelayedReleaser(DelayedReleaser&&) = delete;
DelayedReleaser& operator=(DelayedReleaser&&) = delete; DelayedReleaser& operator=(DelayedReleaser&&) = delete;
static void Add(std::unique_ptr<HardwareBuffer> ahwb, GLuint opengl_buffer, static void Add(AHardwareBuffer* ahwb, GLuint opengl_buffer,
EGLSyncKHR ssbo_sync, GLsync ssbo_read, EGLSyncKHR ssbo_sync, GLsync ssbo_read,
Tensor::FinishingFunc&& ahwb_written, Tensor::FinishingFunc&& ahwb_written,
std::shared_ptr<mediapipe::GlContext> gl_context, std::shared_ptr<mediapipe::GlContext> gl_context,
@ -115,8 +115,8 @@ class DelayedReleaser {
// Using `new` to access a non-public constructor. // Using `new` to access a non-public constructor.
to_release_local.emplace_back(absl::WrapUnique(new DelayedReleaser( to_release_local.emplace_back(absl::WrapUnique(new DelayedReleaser(
std::move(ahwb), opengl_buffer, ssbo_sync, ssbo_read, ahwb, opengl_buffer, ssbo_sync, ssbo_read, std::move(ahwb_written),
std::move(ahwb_written), gl_context, std::move(callback)))); gl_context, std::move(callback))));
for (auto it = to_release_local.begin(); it != to_release_local.end();) { for (auto it = to_release_local.begin(); it != to_release_local.end();) {
if ((*it)->IsSignaled()) { if ((*it)->IsSignaled()) {
it = to_release_local.erase(it); it = to_release_local.erase(it);
@ -136,6 +136,9 @@ class DelayedReleaser {
~DelayedReleaser() { ~DelayedReleaser() {
if (release_callback_) release_callback_(); if (release_callback_) release_callback_();
if (__builtin_available(android 26, *)) {
AHardwareBuffer_release(ahwb_);
}
} }
bool IsSignaled() { bool IsSignaled() {
@ -178,7 +181,7 @@ class DelayedReleaser {
} }
protected: protected:
std::unique_ptr<HardwareBuffer> ahwb_; AHardwareBuffer* ahwb_;
GLuint opengl_buffer_; GLuint opengl_buffer_;
// TODO: use wrapper instead. // TODO: use wrapper instead.
EGLSyncKHR fence_sync_; EGLSyncKHR fence_sync_;
@ -189,12 +192,12 @@ class DelayedReleaser {
std::function<void()> release_callback_; std::function<void()> release_callback_;
static inline std::deque<std::unique_ptr<DelayedReleaser>> to_release_; static inline std::deque<std::unique_ptr<DelayedReleaser>> to_release_;
DelayedReleaser(std::unique_ptr<HardwareBuffer> ahwb, GLuint opengl_buffer, DelayedReleaser(AHardwareBuffer* ahwb, GLuint opengl_buffer,
EGLSyncKHR fence_sync, GLsync ssbo_read, EGLSyncKHR fence_sync, GLsync ssbo_read,
Tensor::FinishingFunc&& ahwb_written, Tensor::FinishingFunc&& ahwb_written,
std::shared_ptr<mediapipe::GlContext> gl_context, std::shared_ptr<mediapipe::GlContext> gl_context,
std::function<void()>&& callback) std::function<void()>&& callback)
: ahwb_(std::move(ahwb)), : ahwb_(ahwb),
opengl_buffer_(opengl_buffer), opengl_buffer_(opengl_buffer),
fence_sync_(fence_sync), fence_sync_(fence_sync),
ssbo_read_(ssbo_read), ssbo_read_(ssbo_read),
@ -211,7 +214,7 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
ABSL_CHECK(!(valid_ & kValidOpenGlTexture2d)) ABSL_CHECK(!(valid_ & kValidOpenGlTexture2d))
<< "Tensor conversion between OpenGL texture and AHardwareBuffer is not " << "Tensor conversion between OpenGL texture and AHardwareBuffer is not "
"supported."; "supported.";
bool transfer = ahwb_ == nullptr; bool transfer = !ahwb_;
ABSL_CHECK(AllocateAHardwareBuffer()) ABSL_CHECK(AllocateAHardwareBuffer())
<< "AHardwareBuffer is not supported on the target system."; << "AHardwareBuffer is not supported on the target system.";
valid_ |= kValidAHardwareBuffer; valid_ |= kValidAHardwareBuffer;
@ -220,10 +223,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
} else { } else {
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd(); if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
} }
return {ahwb_.get(), ssbo_written_, return {ahwb_,
ssbo_written_,
&fence_fd_, // The FD is created for SSBO -> AHWB synchronization. &fence_fd_, // The FD is created for SSBO -> AHWB synchronization.
&ahwb_written_, // Filled by SetReadingFinishedFunc. &ahwb_written_, // Filled by SetReadingFinishedFunc.
&release_callback_, std::move(lock)}; &release_callback_,
std::move(lock)};
} }
void Tensor::CreateEglSyncAndFd() const { void Tensor::CreateEglSyncAndFd() const {
@ -253,11 +258,12 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
ABSL_CHECK(AllocateAHardwareBuffer(size_alignment)) ABSL_CHECK(AllocateAHardwareBuffer(size_alignment))
<< "AHardwareBuffer is not supported on the target system."; << "AHardwareBuffer is not supported on the target system.";
valid_ = kValidAHardwareBuffer; valid_ = kValidAHardwareBuffer;
return {ahwb_.get(), return {ahwb_,
/*ssbo_written=*/-1, /*ssbo_written=*/-1,
&fence_fd_, // For SetWritingFinishedFD. &fence_fd_, // For SetWritingFinishedFD.
&ahwb_written_, // Filled by SetReadingFinishedFunc. &ahwb_written_,
&release_callback_, std::move(lock)}; &release_callback_,
std::move(lock)};
} }
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const { bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
@ -270,43 +276,40 @@ bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
} }
use_ahwb_ = true; use_ahwb_ = true;
if (__builtin_available(android 26, *)) {
if (ahwb_ == nullptr) { if (ahwb_ == nullptr) {
HardwareBufferSpec spec = {}; AHardwareBuffer_Desc desc = {};
if (size_alignment == 0) { if (size_alignment == 0) {
spec.width = bytes(); desc.width = bytes();
} else { } else {
// We expect allocations to be page-aligned, implicitly satisfying any // We expect allocations to be page-aligned, implicitly satisfying any
// requirements from Edge TPU. No need to add a check for this, // requirements from Edge TPU. No need to add a check for this,
// since Edge TPU will check for us. // since Edge TPU will check for us.
spec.width = AlignedToPowerOf2(bytes(), size_alignment); desc.width = AlignedToPowerOf2(bytes(), size_alignment);
} }
spec.height = 1; desc.height = 1;
spec.layers = 1; desc.layers = 1;
spec.format = HardwareBufferSpec::AHARDWAREBUFFER_FORMAT_BLOB; desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
spec.usage = HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN | desc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN |
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
HardwareBufferSpec::AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER; AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER;
auto new_ahwb = HardwareBuffer::Create(spec); return AHardwareBuffer_allocate(&desc, &ahwb_) == 0;
if (!new_ahwb.ok()) {
ABSL_LOG(ERROR) << "Allocation of NDK Hardware Buffer failed: "
<< new_ahwb.status();
return false;
}
ahwb_ = std::make_unique<HardwareBuffer>(std::move(*new_ahwb));
} }
return true; return true;
}
return false;
} }
bool Tensor::AllocateAhwbMapToSsbo() const { bool Tensor::AllocateAhwbMapToSsbo() const {
if (__builtin_available(android 26, *)) { if (__builtin_available(android 26, *)) {
if (AllocateAHardwareBuffer()) { if (AllocateAHardwareBuffer()) {
if (MapAHardwareBufferToGlBuffer(ahwb_->GetAHardwareBuffer(), bytes()) if (MapAHardwareBufferToGlBuffer(ahwb_, bytes()).ok()) {
.ok()) {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
return true; return true;
} }
// Unable to make OpenGL <-> AHWB binding. Use regular SSBO instead. // Unable to make OpenGL <-> AHWB binding. Use regular SSBO instead.
ahwb_.reset(); AHardwareBuffer_release(ahwb_);
ahwb_ = nullptr;
} }
} }
return false; return false;
@ -314,11 +317,14 @@ bool Tensor::AllocateAhwbMapToSsbo() const {
// Moves Cpu/Ssbo resource under the Ahwb backed memory. // Moves Cpu/Ssbo resource under the Ahwb backed memory.
void Tensor::MoveCpuOrSsboToAhwb() const { void Tensor::MoveCpuOrSsboToAhwb() const {
auto dest = void* dest = nullptr;
ahwb_->Lock(HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY); if (__builtin_available(android 26, *)) {
ABSL_CHECK_OK(dest) << "Lock of AHWB failed"; auto error = AHardwareBuffer_lock(
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
ABSL_CHECK(error == 0) << "AHardwareBuffer_lock " << error;
}
if (valid_ & kValidCpu) { if (valid_ & kValidCpu) {
std::memcpy(*dest, cpu_buffer_, bytes()); std::memcpy(dest, cpu_buffer_, bytes());
// Free CPU memory because next time AHWB is mapped instead. // Free CPU memory because next time AHWB is mapped instead.
free(cpu_buffer_); free(cpu_buffer_);
cpu_buffer_ = nullptr; cpu_buffer_ = nullptr;
@ -328,7 +334,7 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
GL_MAP_READ_BIT); GL_MAP_READ_BIT);
std::memcpy(*dest, src, bytes()); std::memcpy(dest, src, bytes());
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
glDeleteBuffers(1, &opengl_buffer_); glDeleteBuffers(1, &opengl_buffer_);
}); });
@ -341,7 +347,10 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
ABSL_LOG(FATAL) << "Can't convert tensor with mask " << valid_ ABSL_LOG(FATAL) << "Can't convert tensor with mask " << valid_
<< " into AHWB."; << " into AHWB.";
} }
ABSL_CHECK_OK(ahwb_->Unlock()) << "Unlock of AHWB failed"; if (__builtin_available(android 26, *)) {
auto error = AHardwareBuffer_unlock(ahwb_, nullptr);
ABSL_CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
}
} }
// SSBO is created on top of AHWB. A fence is inserted into the GPU queue before // SSBO is created on top of AHWB. A fence is inserted into the GPU queue before
@ -394,52 +403,59 @@ void Tensor::ReleaseAhwbStuff() {
if (ahwb_) { if (ahwb_) {
if (ssbo_read_ != 0 || fence_sync_ != EGL_NO_SYNC_KHR || ahwb_written_) { if (ssbo_read_ != 0 || fence_sync_ != EGL_NO_SYNC_KHR || ahwb_written_) {
if (ssbo_written_ != -1) close(ssbo_written_); if (ssbo_written_ != -1) close(ssbo_written_);
DelayedReleaser::Add(std::move(ahwb_), opengl_buffer_, fence_sync_, DelayedReleaser::Add(ahwb_, opengl_buffer_, fence_sync_, ssbo_read_,
ssbo_read_, std::move(ahwb_written_), gl_context_, std::move(ahwb_written_), gl_context_,
std::move(release_callback_)); std::move(release_callback_));
opengl_buffer_ = GL_INVALID_INDEX; opengl_buffer_ = GL_INVALID_INDEX;
} else { } else {
if (release_callback_) release_callback_(); if (release_callback_) release_callback_();
ahwb_.reset(); AHardwareBuffer_release(ahwb_);
} }
} }
} }
} }
void* Tensor::MapAhwbToCpuRead() const { void* Tensor::MapAhwbToCpuRead() const {
if (ahwb_ != nullptr) { if (__builtin_available(android 26, *)) {
if (ahwb_) {
if (!(valid_ & kValidCpu)) { if (!(valid_ & kValidCpu)) {
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) { if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method. // EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
gl_context_->Run([]() { glFinish(); }); gl_context_->Run([]() { glFinish(); });
} else if (valid_ & kValidAHardwareBuffer) { } else if (valid_ & kValidAHardwareBuffer) {
ABSL_CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the " ABSL_CHECK(ahwb_written_)
<< "Ahwb-to-Cpu synchronization requires the "
"completion function to be set"; "completion function to be set";
ABSL_CHECK(ahwb_written_(true)) ABSL_CHECK(ahwb_written_(true))
<< "An error oqcured while waiting for the buffer to be written"; << "An error oqcured while waiting for the buffer to be written";
} }
} }
auto ptr = void* ptr;
ahwb_->Lock(HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, auto error =
ssbo_written_); AHardwareBuffer_lock(ahwb_, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
ABSL_CHECK_OK(ptr) << "Lock of AHWB failed"; ssbo_written_, nullptr, &ptr);
ABSL_CHECK(error == 0) << "AHardwareBuffer_lock " << error;
close(ssbo_written_); close(ssbo_written_);
ssbo_written_ = -1; ssbo_written_ = -1;
return *ptr; return ptr;
}
} }
return nullptr; return nullptr;
} }
void* Tensor::MapAhwbToCpuWrite() const { void* Tensor::MapAhwbToCpuWrite() const {
if (ahwb_ != nullptr) { if (__builtin_available(android 26, *)) {
if (ahwb_) {
// TODO: If previously acquired view is GPU write view then need // TODO: If previously acquired view is GPU write view then need
// to be sure that writing is finished. That's a warning: two consequent // to be sure that writing is finished. That's a warning: two consequent
// write views should be interleaved with read view. // write views should be interleaved with read view.
auto locked_ptr = void* ptr;
ahwb_->Lock(HardwareBufferSpec::AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN); auto error = AHardwareBuffer_lock(
ABSL_CHECK_OK(locked_ptr) << "Lock of AHWB failed"; ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, -1, nullptr, &ptr);
return *locked_ptr; ABSL_CHECK(error == 0) << "AHardwareBuffer_lock " << error;
return ptr;
}
} }
return nullptr; return nullptr;
} }

View File

@ -6,7 +6,6 @@
#include <cstdint> #include <cstdint>
#include "absl/algorithm/container.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/formats/tensor/views/data_types.h" #include "mediapipe/framework/formats/tensor/views/data_types.h"
#include "mediapipe/gpu/gpu_test_base.h" #include "mediapipe/gpu/gpu_test_base.h"
@ -19,7 +18,7 @@
// Then the test requests the CPU view and compares the values. // Then the test requests the CPU view and compares the values.
// Float32 and Float16 tests are there. // Float32 and Float16 tests are there.
namespace mediapipe { namespace {
using mediapipe::Float16; using mediapipe::Float16;
using mediapipe::Tensor; using mediapipe::Tensor;
@ -28,16 +27,6 @@ MATCHER_P(NearWithPrecision, precision, "") {
return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision; return std::abs(std::get<0>(arg) - std::get<1>(arg)) < precision;
} }
template <typename F = float>
std::vector<F> CreateReferenceData(int num_elements) {
std::vector<F> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
return reference;
}
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
// Utility function to fill the GPU buffer. // Utility function to fill the GPU buffer.
@ -121,7 +110,11 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat32) {
}); });
auto ptr = tensor.GetCpuReadView().buffer<float>(); auto ptr = tensor.GetCpuReadView().buffer<float>();
ASSERT_NE(ptr, nullptr); ASSERT_NE(ptr, nullptr);
std::vector<float> reference = CreateReferenceData(num_elements); std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements), EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference)); testing::Pointwise(testing::FloatEq(), reference));
} }
@ -144,7 +137,11 @@ TEST_F(TensorAhwbGpuTest, TestGpuToCpuFloat16) {
}); });
auto ptr = tensor.GetCpuReadView().buffer<Float16>(); auto ptr = tensor.GetCpuReadView().buffer<Float16>();
ASSERT_NE(ptr, nullptr); ASSERT_NE(ptr, nullptr);
std::vector<Float16> reference = CreateReferenceData<Float16>(num_elements); std::vector<Float16> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
// Precision is set to a reasonable value for Float16. // Precision is set to a reasonable value for Float16.
EXPECT_THAT(absl::Span<const Float16>(ptr, num_elements), EXPECT_THAT(absl::Span<const Float16>(ptr, num_elements),
testing::Pointwise(NearWithPrecision(0.001), reference)); testing::Pointwise(NearWithPrecision(0.001), reference));
@ -169,7 +166,11 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
} }
auto ptr = tensor.GetCpuReadView().buffer<float>(); auto ptr = tensor.GetCpuReadView().buffer<float>();
ASSERT_NE(ptr, nullptr); ASSERT_NE(ptr, nullptr);
std::vector<float> reference = CreateReferenceData(num_elements); std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements), EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference)); testing::Pointwise(testing::FloatEq(), reference));
} }
@ -193,107 +194,17 @@ TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
} }
auto ptr = tensor.GetCpuReadView().buffer<float>(); auto ptr = tensor.GetCpuReadView().buffer<float>();
ASSERT_NE(ptr, nullptr); ASSERT_NE(ptr, nullptr);
std::vector<float> reference = CreateReferenceData(num_elements); std::vector<float> reference;
reference.resize(num_elements);
for (int i = 0; i < num_elements; i++) {
reference[i] = static_cast<float>(i) / 10.0f;
}
EXPECT_THAT(absl::Span<const float>(ptr, num_elements), EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
testing::Pointwise(testing::FloatEq(), reference)); testing::Pointwise(testing::FloatEq(), reference));
} }
std::vector<float> ReadGlBufferView(const Tensor::OpenGlBufferView& view,
int num_elements) {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, view.name());
int bytes = num_elements * sizeof(float);
void* ptr =
glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes, GL_MAP_READ_BIT);
ABSL_CHECK(ptr) << "glMapBufferRange failed: " << glGetError();
std::vector<float> data;
data.resize(num_elements);
std::memcpy(data.data(), ptr, bytes);
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
return data;
}
TEST_F(TensorAhwbGpuTest, TestGetOpenGlBufferReadViewNoAhwb) {
constexpr size_t kNumElements = 20;
std::vector<float> reference = CreateReferenceData(kNumElements);
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape({kNumElements}));
{
// Populate tensor on CPU and make sure view is destroyed
absl::c_copy(reference, tensor.GetCpuWriteView().buffer<float>());
}
RunInGlContext([&] {
// Triggers conversion to GL buffer.
auto ssbo_view = tensor.GetOpenGlBufferReadView();
ASSERT_NE(ssbo_view.name(), 0);
// ssbo_read_ must NOT be populated, as there's no AHWB associated with
// GL buffer
ASSERT_EQ(ssbo_view.ssbo_read_, nullptr);
std::vector<float> output = ReadGlBufferView(ssbo_view, kNumElements);
EXPECT_THAT(output, testing::Pointwise(testing::FloatEq(), reference));
});
}
TEST_F(TensorAhwbGpuTest, TestGetOpenGlBufferReadViewAhwbFromCpu) {
constexpr size_t kNumElements = 20;
std::vector<float> reference = CreateReferenceData(kNumElements);
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape({kNumElements}));
{
// Populate tensor on CPU and make sure view is destroyed
absl::c_copy(reference, tensor.GetCpuWriteView().buffer<float>());
}
{
// Make tensor to allocate ahwb and make sure view is destroyed.
ASSERT_NE(tensor.GetAHardwareBufferReadView().handle(), nullptr);
}
RunInGlContext([&] {
// Triggers conversion to GL buffer.
auto ssbo_view = tensor.GetOpenGlBufferReadView();
ASSERT_NE(ssbo_view.name(), 0);
// ssbo_read_ must be populated, so during view destruction it's set
// properly for further AHWB destruction
ASSERT_NE(ssbo_view.ssbo_read_, nullptr);
std::vector<float> output = ReadGlBufferView(ssbo_view, kNumElements);
EXPECT_THAT(output, testing::Pointwise(testing::FloatEq(), reference));
});
}
TEST_F(TensorAhwbGpuTest, TestGetOpenGlBufferReadViewAhwbFromGpu) {
constexpr size_t kNumElements = 20;
std::vector<float> reference = CreateReferenceData(kNumElements);
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape({kNumElements}));
{
// Make tensor to allocate ahwb and make sure view is destroyed.
ASSERT_NE(tensor.GetAHardwareBufferWriteView().handle(), nullptr);
}
RunInGlContext([&] {
FillGpuBuffer(tensor.GetOpenGlBufferWriteView().name(),
tensor.shape().num_elements(), tensor.element_type());
});
RunInGlContext([&] {
// Triggers conversion to GL buffer.
auto ssbo_view = tensor.GetOpenGlBufferReadView();
ASSERT_NE(ssbo_view.name(), 0);
// ssbo_read_ must be populated, so during view destruction it's set
// properly for further AHWB destruction
ASSERT_NE(ssbo_view.ssbo_read_, nullptr);
std::vector<float> output = ReadGlBufferView(ssbo_view, kNumElements);
EXPECT_THAT(output, testing::Pointwise(testing::FloatEq(), reference));
});
}
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
} // namespace mediapipe } // namespace
#endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 || #endif // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)) // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

View File

@ -1,5 +1,3 @@
#include <android/hardware_buffer.h>
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "testing/base/public/gmock.h" #include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h" #include "testing/base/public/gunit.h"

View File

@ -15,7 +15,6 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_YUV_IMAGE_H_ #ifndef MEDIAPIPE_FRAMEWORK_FORMATS_YUV_IMAGE_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_YUV_IMAGE_H_ #define MEDIAPIPE_FRAMEWORK_FORMATS_YUV_IMAGE_H_
#include <cstdint>
#include <functional> #include <functional>
#include <memory> #include <memory>
@ -121,12 +120,12 @@ class YUVImage {
// Convenience constructor // Convenience constructor
YUVImage(libyuv::FourCC fourcc, // YUVImage(libyuv::FourCC fourcc, //
std::unique_ptr<uint8_t[]> data_location, // std::unique_ptr<uint8[]> data_location, //
uint8_t* data0, int stride0, // uint8* data0, int stride0, //
uint8_t* data1, int stride1, // uint8* data1, int stride1, //
uint8_t* data2, int stride2, // uint8* data2, int stride2, //
int width, int height, int bit_depth = 8) { int width, int height, int bit_depth = 8) {
uint8_t* tmp = data_location.release(); uint8* tmp = data_location.release();
std::function<void()> deallocate = [tmp]() { delete[] tmp; }; std::function<void()> deallocate = [tmp]() { delete[] tmp; };
Initialize(fourcc, // Initialize(fourcc, //
deallocate, // deallocate, //
@ -139,13 +138,13 @@ class YUVImage {
// Convenience constructor to construct the YUVImage with data stored // Convenience constructor to construct the YUVImage with data stored
// in three unique_ptrs. // in three unique_ptrs.
YUVImage(libyuv::FourCC fourcc, // YUVImage(libyuv::FourCC fourcc, //
std::unique_ptr<uint8_t[]> data0, int stride0, // std::unique_ptr<uint8[]> data0, int stride0, //
std::unique_ptr<uint8_t[]> data1, int stride1, // std::unique_ptr<uint8[]> data1, int stride1, //
std::unique_ptr<uint8_t[]> data2, int stride2, // std::unique_ptr<uint8[]> data2, int stride2, //
int width, int height, int bit_depth = 8) { int width, int height, int bit_depth = 8) {
uint8_t* tmp0 = data0.release(); uint8* tmp0 = data0.release();
uint8_t* tmp1 = data1.release(); uint8* tmp1 = data1.release();
uint8_t* tmp2 = data2.release(); uint8* tmp2 = data2.release();
std::function<void()> deallocate = [tmp0, tmp1, tmp2]() { std::function<void()> deallocate = [tmp0, tmp1, tmp2]() {
delete[] tmp0; delete[] tmp0;
delete[] tmp1; delete[] tmp1;
@ -178,9 +177,9 @@ class YUVImage {
// pixel format it holds. // pixel format it holds.
void Initialize(libyuv::FourCC fourcc, // void Initialize(libyuv::FourCC fourcc, //
std::function<void()> deallocation_function, // std::function<void()> deallocation_function, //
uint8_t* data0, int stride0, // uint8* data0, int stride0, //
uint8_t* data1, int stride1, // uint8* data1, int stride1, //
uint8_t* data2, int stride2, // uint8* data2, int stride2, //
int width, int height, int bit_depth = 8) { int width, int height, int bit_depth = 8) {
Clear(); Clear();
deallocation_function_ = deallocation_function; deallocation_function_ = deallocation_function;
@ -215,7 +214,7 @@ class YUVImage {
// Getters. // Getters.
libyuv::FourCC fourcc() const { return fourcc_; } libyuv::FourCC fourcc() const { return fourcc_; }
const uint8_t* data(int index) const { return data_[index]; } const uint8* data(int index) const { return data_[index]; }
int stride(int index) const { return stride_[index]; } int stride(int index) const { return stride_[index]; }
int width() const { return width_; } int width() const { return width_; }
int height() const { return height_; } int height() const { return height_; }
@ -227,7 +226,7 @@ class YUVImage {
// Setters. // Setters.
void set_fourcc(libyuv::FourCC fourcc) { fourcc_ = fourcc; } void set_fourcc(libyuv::FourCC fourcc) { fourcc_ = fourcc; }
uint8_t* mutable_data(int index) { return data_[index]; } uint8* mutable_data(int index) { return data_[index]; }
void set_stride(int index, int stride) { stride_[index] = stride; } void set_stride(int index, int stride) { stride_[index] = stride; }
void set_width(int width) { width_ = width; } void set_width(int width) { width_ = width; }
void set_height(int height) { height_ = height; } void set_height(int height) { height_ = height; }
@ -242,7 +241,7 @@ class YUVImage {
std::function<void()> deallocation_function_; std::function<void()> deallocation_function_;
libyuv::FourCC fourcc_ = libyuv::FOURCC_ANY; libyuv::FourCC fourcc_ = libyuv::FOURCC_ANY;
uint8_t* data_[kMaxNumPlanes]; uint8* data_[kMaxNumPlanes];
int stride_[kMaxNumPlanes]; int stride_[kMaxNumPlanes];
int width_ = 0; int width_ = 0;
int height_ = 0; int height_ = 0;

View File

@ -26,7 +26,7 @@ def replace_suffix(string, old, new):
def mediapipe_ts_library( def mediapipe_ts_library(
name, name,
srcs = [], srcs,
visibility = None, visibility = None,
deps = [], deps = [],
testonly = 0, testonly = 0,

View File

@ -196,7 +196,6 @@ cc_library(
":gpu_buffer_format", ":gpu_buffer_format",
"//mediapipe/framework:executor", "//mediapipe/framework:executor",
"//mediapipe/framework:mediapipe_profiling", "//mediapipe/framework:mediapipe_profiling",
"//mediapipe/framework:port",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
@ -210,7 +209,6 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
] + select({ ] + select({
"//conditions:default": [], "//conditions:default": [],

View File

@ -26,9 +26,7 @@
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port.h" // IWYU pragma: keep
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
@ -50,17 +48,6 @@
namespace mediapipe { namespace mediapipe {
namespace internal_gl_context {
bool IsOpenGlVersionSameOrAbove(const OpenGlVersion& version,
const OpenGlVersion& expected_version) {
return (version.major == expected_version.major &&
version.minor >= expected_version.minor) ||
version.major > expected_version.major;
}
} // namespace internal_gl_context
static void SetThreadName(const char* name) { static void SetThreadName(const char* name) {
#if defined(__GLIBC_PREREQ) #if defined(__GLIBC_PREREQ)
#define LINUX_STYLE_SETNAME_NP __GLIBC_PREREQ(2, 12) #define LINUX_STYLE_SETNAME_NP __GLIBC_PREREQ(2, 12)
@ -651,11 +638,6 @@ class GlSyncWrapper {
// TODO: do something if the wait fails? // TODO: do something if the wait fails?
} }
// This method exists only for investigation purposes to distinguish stack
// traces: external vs. internal context.
// TODO: remove after glWaitSync crashes are resolved.
void WaitOnGpuExternalContext() { glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED); }
void WaitOnGpu() { void WaitOnGpu() {
if (!sync_) return; if (!sync_) return;
// WebGL2 specifies a waitSync call, but since cross-context // WebGL2 specifies a waitSync call, but since cross-context
@ -663,33 +645,6 @@ class GlSyncWrapper {
// a warning when it's called, so let's just skip the call. See // a warning when it's called, so let's just skip the call. See
// b/184637485 for details. // b/184637485 for details.
#ifndef __EMSCRIPTEN__ #ifndef __EMSCRIPTEN__
if (!GlContext::IsAnyContextCurrent()) {
// glWaitSync must be called on with some context current. Doing the
// opposite doesn't necessarily result in a crash or GL error. Hence,
// just logging an error and skipping the call.
ABSL_LOG_FIRST_N(ERROR, 1)
<< "An attempt to wait for a sync without any context current.";
return;
}
auto context = GlContext::GetCurrent();
if (context == nullptr) {
// This can happen when WaitOnGpu is invoked on an external context,
// created by other than GlContext::Create means.
WaitOnGpuExternalContext();
return;
}
// GlContext::ShouldUseFenceSync guards creation of sync objects, so this
// CHECK should never fail if clients use MediaPipe APIs in an intended way.
// TODO: remove after glWaitSync crashes are resolved.
ABSL_CHECK(context->ShouldUseFenceSync()) << absl::StrFormat(
"An attempt to wait for a sync when it should not be used. (OpenGL "
"Version "
"%d.%d)",
context->gl_major_version(), context->gl_minor_version());
glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED); glWaitSync(sync_, 0, GL_TIMEOUT_IGNORED);
#endif #endif
} }
@ -742,13 +697,10 @@ class GlFenceSyncPoint : public GlSyncPoint {
void Wait() override { void Wait() override {
if (!sync_) return; if (!sync_) return;
if (GlContext::IsAnyContextCurrent()) { gl_context_->Run([this] {
// TODO: must this run on the original context??
sync_.Wait(); sync_.Wait();
return; });
}
// In case a current GL context is not available, we fall back using the
// captured gl_context_.
gl_context_->Run([this] { sync_.Wait(); });
} }
void WaitOnGpu() override { void WaitOnGpu() override {
@ -860,25 +812,15 @@ class GlNopSyncPoint : public GlSyncPoint {
#endif #endif
bool GlContext::ShouldUseFenceSync() const { bool GlContext::ShouldUseFenceSync() const {
using internal_gl_context::OpenGlVersion; #ifdef __EMSCRIPTEN__
#if defined(__EMSCRIPTEN__)
// In Emscripten the glWaitSync function is non-null depending on linkopts, // In Emscripten the glWaitSync function is non-null depending on linkopts,
// but only works in a WebGL2 context. // but only works in a WebGL2 context, so fall back to use Finish if it is a
constexpr OpenGlVersion kMinVersionSyncAvaiable = {.major = 3, .minor = 0}; // WebGL1/ES2 context.
#elif defined(MEDIAPIPE_MOBILE) // TODO: apply this more generally once b/152794517 is fixed.
// OpenGL ES, glWaitSync is available since 3.0 return gl_major_version() > 2;
constexpr OpenGlVersion kMinVersionSyncAvaiable = {.major = 3, .minor = 0};
#else #else
// TODO: specify major/minor version per remaining platforms. return SymbolAvailable(&glWaitSync);
// By default, ignoring major/minor version requirement for backward #endif // __EMSCRIPTEN__
// compatibility.
constexpr OpenGlVersion kMinVersionSyncAvaiable = {.major = 0, .minor = 0};
#endif
return SymbolAvailable(&glWaitSync) &&
internal_gl_context::IsOpenGlVersionSameOrAbove(
{.major = gl_major_version(), .minor = gl_minor_version()},
kMinVersionSyncAvaiable);
} }
std::shared_ptr<GlSyncPoint> GlContext::CreateSyncToken() { std::shared_ptr<GlSyncPoint> GlContext::CreateSyncToken() {

View File

@ -71,8 +71,6 @@ typedef std::function<void()> GlVoidFunction;
typedef std::function<absl::Status()> GlStatusFunction; typedef std::function<absl::Status()> GlStatusFunction;
class GlContext; class GlContext;
// TODO: remove after glWaitSync crashes are resolved.
class GlSyncWrapper;
// Generic interface for synchronizing access to a shared resource from a // Generic interface for synchronizing access to a shared resource from a
// different context. This is an abstract class to keep users from // different context. This is an abstract class to keep users from
@ -192,7 +190,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// Like Run, but does not wait. // Like Run, but does not wait.
void RunWithoutWaiting(GlVoidFunction gl_func); void RunWithoutWaiting(GlVoidFunction gl_func);
// Returns a synchronization token for this GlContext. // Returns a synchronization token.
// This should not be called outside of the GlContext thread.
std::shared_ptr<GlSyncPoint> CreateSyncToken(); std::shared_ptr<GlSyncPoint> CreateSyncToken();
// If another part of the framework calls glFinish, it should call this // If another part of the framework calls glFinish, it should call this
@ -331,9 +330,6 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
SyncTokenTypeForTest type); SyncTokenTypeForTest type);
private: private:
// TODO: remove after glWaitSync crashes are resolved.
friend GlSyncWrapper;
GlContext(); GlContext();
bool ShouldUseFenceSync() const; bool ShouldUseFenceSync() const;
@ -492,18 +488,6 @@ ABSL_DEPRECATED(
const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
int plane); int plane);
namespace internal_gl_context {
struct OpenGlVersion {
int major;
int minor;
};
bool IsOpenGlVersionSameOrAbove(const OpenGlVersion& version,
const OpenGlVersion& expected_version);
} // namespace internal_gl_context
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_GPU_GL_CONTEXT_H_ #endif // MEDIAPIPE_GPU_GL_CONTEXT_H_

View File

@ -32,7 +32,7 @@ namespace mediapipe {
// TODO: Handle webGL "context lost" and "context restored" events. // TODO: Handle webGL "context lost" and "context restored" events.
GlContext::StatusOrGlContext GlContext::Create(std::nullptr_t nullp, GlContext::StatusOrGlContext GlContext::Create(std::nullptr_t nullp,
bool create_thread) { bool create_thread) {
return Create(static_cast<EMSCRIPTEN_WEBGL_CONTEXT_HANDLE>(0), create_thread); return Create(0, create_thread);
} }
GlContext::StatusOrGlContext GlContext::Create(const GlContext& share_context, GlContext::StatusOrGlContext GlContext::Create(const GlContext& share_context,

View File

@ -105,6 +105,7 @@ absl::Status GpuBufferToImageFrameCalculator::Process(CalculatorContext* cc) {
helper_.GetGlVersion()); helper_.GetGlVersion());
glReadPixels(0, 0, src.width(), src.height(), info.gl_format, glReadPixels(0, 0, src.width(), src.height(), info.gl_format,
info.gl_type, frame->MutablePixelData()); info.gl_type, frame->MutablePixelData());
glFlush();
cc->Outputs().Index(0).Add(frame.release(), cc->InputTimestamp()); cc->Outputs().Index(0).Add(frame.release(), cc->InputTimestamp());
src.Release(); src.Release();
}); });

View File

@ -18,7 +18,6 @@
#import <Accelerate/Accelerate.h> #import <Accelerate/Accelerate.h>
#include <atomic> #include <atomic>
#include <cstdint>
#import "GTMDefines.h" #import "GTMDefines.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
@ -49,7 +48,7 @@
std::atomic<int32_t> _framesInFlight; std::atomic<int32_t> _framesInFlight;
/// Used as a sequential timestamp for MediaPipe. /// Used as a sequential timestamp for MediaPipe.
mediapipe::Timestamp _frameTimestamp; mediapipe::Timestamp _frameTimestamp;
int64_t _frameNumber; int64 _frameNumber;
// Graph config modified to expose requested output streams. // Graph config modified to expose requested output streams.
mediapipe::CalculatorGraphConfig _config; mediapipe::CalculatorGraphConfig _config;
@ -91,8 +90,8 @@
&callbackInputName, &callbackInputName,
/*use_std_function=*/true); /*use_std_function=*/true);
// No matter what ownership qualifiers are put on the pointer, // No matter what ownership qualifiers are put on the pointer,
// NewPermanentCallback will still end up with a strong pointer to // NewPermanentCallback will still end up with a strong pointer to MPPGraph*.
// MPPGraph*. That is why we use void* instead. // That is why we use void* instead.
void* wrapperVoid = (__bridge void*)self; void* wrapperVoid = (__bridge void*)self;
_inputSidePackets[callbackInputName] = _inputSidePackets[callbackInputName] =
mediapipe::MakePacket<std::function<void(const mediapipe::Packet&)>>( mediapipe::MakePacket<std::function<void(const mediapipe::Packet&)>>(

View File

@ -14,8 +14,6 @@
#import "MPPTimestampConverter.h" #import "MPPTimestampConverter.h"
#include <cstdint>
@implementation MPPTimestampConverter { @implementation MPPTimestampConverter {
mediapipe::Timestamp _mediapipeTimestamp; mediapipe::Timestamp _mediapipeTimestamp;
mediapipe::Timestamp _lastTimestamp; mediapipe::Timestamp _lastTimestamp;
@ -39,7 +37,7 @@
- (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime { - (mediapipe::Timestamp)timestampForMediaTime:(CMTime)mediaTime {
Float64 sampleSeconds = Float64 sampleSeconds =
CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0; CMTIME_IS_VALID(mediaTime) ? CMTimeGetSeconds(mediaTime) : 0;
const int64_t sampleUsec = const int64 sampleUsec =
sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond; sampleSeconds * mediapipe::Timestamp::kTimestampUnitsPerSecond;
_mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset; _mediapipeTimestamp = mediapipe::Timestamp(sampleUsec) + _timestampOffset;
if (_mediapipeTimestamp <= _lastTimestamp) { if (_mediapipeTimestamp <= _lastTimestamp) {

View File

@ -14,8 +14,6 @@
#include "mediapipe/objc/util.h" #include "mediapipe/objc/util.h"
#include <cstdint>
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/log/absl_check.h" #include "absl/log/absl_check.h"
#include "absl/log/absl_log.h" #include "absl/log/absl_log.h"
@ -641,7 +639,7 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameForCVPixelBuffer(
} else { } else {
frame = absl::make_unique<mediapipe::ImageFrame>( frame = absl::make_unique<mediapipe::ImageFrame>(
image_format, width, height, bytes_per_row, image_format, width, height, bytes_per_row,
reinterpret_cast<uint8_t*>(base_address), [image_buffer](uint8_t* x) { reinterpret_cast<uint8*>(base_address), [image_buffer](uint8* x) {
CVPixelBufferUnlockBaseAddress(image_buffer, CVPixelBufferUnlockBaseAddress(image_buffer,
kCVPixelBufferLock_ReadOnly); kCVPixelBufferLock_ReadOnly);
CVPixelBufferRelease(image_buffer); CVPixelBufferRelease(image_buffer);

View File

@ -97,7 +97,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/holistic_landmarker:holistic_landmarker_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",

View File

@ -43,33 +43,6 @@ cc_test(
], ],
) )
cc_library(
name = "matrix",
hdrs = ["matrix.h"],
)
cc_library(
name = "matrix_converter",
srcs = ["matrix_converter.cc"],
hdrs = ["matrix_converter.h"],
deps = [
":matrix",
"@eigen_archive//:eigen3",
],
)
cc_test(
name = "matrix_converter_test",
srcs = ["matrix_converter_test.cc"],
deps = [
":matrix",
":matrix_converter",
"//mediapipe/framework/port:gtest",
"@com_google_googletest//:gtest_main",
"@eigen_archive//:eigen3",
],
)
cc_library( cc_library(
name = "landmark", name = "landmark",
hdrs = ["landmark.h"], hdrs = ["landmark.h"],

View File

@ -1,41 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_MATRIX_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_MATRIX_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
// Data are stored in column-major order by default.
struct Matrix {
// The number of rows in the matrix.
uint32_t rows;
// The number of columns in the matrix.
uint32_t cols;
// The matrix data stored in a column-first layout.
float* data;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_MATRIX_H_

View File

@ -1,43 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/matrix_converter.h"
#include <cstring>
#include "Eigen/Core"
#include "mediapipe/tasks/c/components/containers/matrix.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToMatrix(const Eigen::MatrixXf& in, ::Matrix* out) {
out->rows = in.rows();
out->cols = in.cols();
out->data = new float[out->rows * out->cols];
// Copies data from an Eigen matrix (default column-major as used by
// MediaPipe) to a C-style matrix, preserving the sequence of elements as per
// the Eigen matrix's internal storage (column-major order by default).
memcpy(out->data, in.data(), sizeof(float) * out->rows * out->cols);
}
void CppCloseMatrix(::Matrix* m) {
if (m->data) {
delete[] m->data;
m->data = nullptr;
}
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -1,30 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_MATRIX_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_MATRIX_CONVERTER_H_
#include "Eigen/Core"
#include "mediapipe/tasks/c/components/containers/matrix.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToMatrix(const Eigen::MatrixXf& in, ::Matrix* out);
void CppCloseMatrix(::Matrix* m);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_MATRIX_CONVERTER_H_

View File

@ -1,49 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/matrix_converter.h"
#include "Eigen/Core"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/matrix.h"
namespace mediapipe::tasks::c::components::containers {
TEST(MatrixConversionTest, ConvertsEigenMatrixToCMatrixAndFreesMemory) {
// Initialize an Eigen::MatrixXf
Eigen::MatrixXf cpp_matrix(2, 2);
cpp_matrix << 1.0f, 2.0f, 3.0f, 4.0f;
// Convert this Eigen matrix to C-style Matrix
::Matrix c_matrix;
CppConvertToMatrix(cpp_matrix, &c_matrix);
// Verify the conversion
EXPECT_EQ(c_matrix.rows, 2);
EXPECT_EQ(c_matrix.cols, 2);
ASSERT_NE(c_matrix.data, nullptr);
EXPECT_FLOAT_EQ(c_matrix.data[0], 1.0f);
EXPECT_FLOAT_EQ(c_matrix.data[1], 3.0f);
EXPECT_FLOAT_EQ(c_matrix.data[2], 2.0f);
EXPECT_FLOAT_EQ(c_matrix.data[3], 4.0f);
// Close the C-style Matrix
CppCloseMatrix(&c_matrix);
// Verify that memory is freed
EXPECT_EQ(c_matrix.data, nullptr);
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -1,149 +0,0 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "face_landmarker_result",
hdrs = ["face_landmarker_result.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/c/components/containers:landmark",
"//mediapipe/tasks/c/components/containers:matrix",
],
)
cc_library(
name = "face_landmarker_result_converter",
srcs = ["face_landmarker_result_converter.cc"],
hdrs = ["face_landmarker_result_converter.h"],
deps = [
":face_landmarker_result",
"//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/c/components/containers:category_converter",
"//mediapipe/tasks/c/components/containers:landmark",
"//mediapipe/tasks/c/components/containers:landmark_converter",
"//mediapipe/tasks/c/components/containers:matrix",
"//mediapipe/tasks/c/components/containers:matrix_converter",
"//mediapipe/tasks/cc/components/containers:category",
"//mediapipe/tasks/cc/components/containers:landmark",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_result",
],
)
cc_test(
name = "face_landmarker_result_converter_test",
srcs = ["face_landmarker_result_converter_test.cc"],
linkstatic = 1,
deps = [
":face_landmarker_result",
":face_landmarker_result_converter",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/components/containers:landmark",
"//mediapipe/tasks/cc/components/containers:category",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers:landmark",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_result",
"@com_google_googletest//:gtest_main",
"@eigen_archive//:eigen3",
],
)
cc_library(
name = "face_landmarker_lib",
srcs = ["face_landmarker.cc"],
hdrs = ["face_landmarker.h"],
visibility = ["//visibility:public"],
deps = [
":face_landmarker_result",
":face_landmarker_result_converter",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/face_landmarker",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_result",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "face_landmarker_test",
srcs = ["face_landmarker_test.cc"],
data = [
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
linkstatic = 1,
deps = [
":face_landmarker_lib",
":face_landmarker_result",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/c/components/containers:landmark",
"//mediapipe/tasks/c/vision/core:common",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
# bazel build -c opt --linkopt -s --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
# //mediapipe/tasks/c/vision/face_landmarker:libface_landmarker.so
cc_binary(
name = "libface_landmarker.so",
linkopts = [
"-Wl,-soname=libface_landmarker.so",
"-fvisibility=hidden",
],
linkshared = True,
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [":face_landmarker_lib"],
)
# bazel build --config darwin_arm64 -c opt --strip always --define MEDIAPIPE_DISABLE_GPU=1 \
# //mediapipe/tasks/c/vision/face_landmarker:libface_landmarker.dylib
cc_binary(
name = "libface_landmarker.dylib",
linkopts = [
"-Wl,-install_name,libface_landmarker.dylib",
"-fvisibility=hidden",
],
linkshared = True,
tags = [
"manual",
"nobuilder",
"notap",
],
deps = [":face_landmarker_lib"],
)

View File

@ -1,287 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <utility>
#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result.h"
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result_converter.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe::tasks::c::vision::face_landmarker {
namespace {
using ::mediapipe::tasks::c::components::containers::
CppCloseFaceLandmarkerResult;
using ::mediapipe::tasks::c::components::containers::
CppConvertToFaceLandmarkerResult;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::vision::CreateImageFromBuffer;
using ::mediapipe::tasks::vision::core::RunningMode;
using ::mediapipe::tasks::vision::face_landmarker::FaceLandmarker;
typedef ::mediapipe::tasks::vision::face_landmarker::FaceLandmarkerResult
CppFaceLandmarkerResult;
int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
*error_msg = strdup(status.ToString().c_str());
}
return status.raw_code();
}
} // namespace
void CppConvertToFaceLandmarkerOptions(
const FaceLandmarkerOptions& in,
mediapipe::tasks::vision::face_landmarker::FaceLandmarkerOptions* out) {
out->num_faces = in.num_faces;
out->min_face_detection_confidence = in.min_face_detection_confidence;
out->min_face_presence_confidence = in.min_face_presence_confidence;
out->min_tracking_confidence = in.min_tracking_confidence;
out->output_face_blendshapes = in.output_face_blendshapes;
out->output_facial_transformation_matrixes =
in.output_facial_transformation_matrixes;
}
FaceLandmarker* CppFaceLandmarkerCreate(const FaceLandmarkerOptions& options,
char** error_msg) {
auto cpp_options = std::make_unique<
::mediapipe::tasks::vision::face_landmarker::FaceLandmarkerOptions>();
CppConvertToBaseOptions(options.base_options, &cpp_options->base_options);
CppConvertToFaceLandmarkerOptions(options, cpp_options.get());
cpp_options->running_mode = static_cast<RunningMode>(options.running_mode);
// Enable callback for processing live stream data when the running mode is
// set to RunningMode::LIVE_STREAM.
if (cpp_options->running_mode == RunningMode::LIVE_STREAM) {
if (options.result_callback == nullptr) {
const absl::Status status = absl::InvalidArgumentError(
"Provided null pointer to callback function.");
ABSL_LOG(ERROR) << "Failed to create FaceLandmarker: " << status;
CppProcessError(status, error_msg);
return nullptr;
}
FaceLandmarkerOptions::result_callback_fn result_callback =
options.result_callback;
cpp_options->result_callback =
[result_callback](absl::StatusOr<CppFaceLandmarkerResult> cpp_result,
const Image& image, int64_t timestamp) {
char* error_msg = nullptr;
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
CppProcessError(cpp_result.status(), &error_msg);
result_callback({}, MpImage(), timestamp, error_msg);
free(error_msg);
return;
}
// Result is valid for the lifetime of the callback function.
FaceLandmarkerResult result;
CppConvertToFaceLandmarkerResult(*cpp_result, &result);
const auto& image_frame = image.GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {
.format = static_cast<::ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
result_callback(&result, mp_image, timestamp,
/* error_msg= */ nullptr);
CppCloseFaceLandmarkerResult(&result);
};
}
auto landmarker = FaceLandmarker::Create(std::move(cpp_options));
if (!landmarker.ok()) {
ABSL_LOG(ERROR) << "Failed to create FaceLandmarker: "
<< landmarker.status();
CppProcessError(landmarker.status(), error_msg);
return nullptr;
}
return landmarker->release();
}
int CppFaceLandmarkerDetect(void* landmarker, const MpImage& image,
FaceLandmarkerResult* result, char** error_msg) {
if (image.type == MpImage::GPU_BUFFER) {
const absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet.");
ABSL_LOG(ERROR) << "Detection failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image.image_frame.format),
image.image_frame.image_buffer, image.image_frame.width,
image.image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_landmarker = static_cast<FaceLandmarker*>(landmarker);
auto cpp_result = cpp_landmarker->Detect(*img);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToFaceLandmarkerResult(*cpp_result, result);
return 0;
}
int CppFaceLandmarkerDetectForVideo(void* landmarker, const MpImage& image,
int64_t timestamp_ms,
FaceLandmarkerResult* result,
char** error_msg) {
if (image.type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Detection failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image.image_frame.format),
image.image_frame.image_buffer, image.image_frame.width,
image.image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_landmarker = static_cast<FaceLandmarker*>(landmarker);
auto cpp_result = cpp_landmarker->DetectForVideo(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Detection failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToFaceLandmarkerResult(*cpp_result, result);
return 0;
}
int CppFaceLandmarkerDetectAsync(void* landmarker, const MpImage& image,
int64_t timestamp_ms, char** error_msg) {
if (image.type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("GPU Buffer not supported yet");
ABSL_LOG(ERROR) << "Detection failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image.image_frame.format),
image.image_frame.image_buffer, image.image_frame.width,
image.image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_landmarker = static_cast<FaceLandmarker*>(landmarker);
auto cpp_result = cpp_landmarker->DetectAsync(*img, timestamp_ms);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Data preparation for the landmark detection failed: "
<< cpp_result;
return CppProcessError(cpp_result, error_msg);
}
return 0;
}
void CppFaceLandmarkerCloseResult(FaceLandmarkerResult* result) {
CppCloseFaceLandmarkerResult(result);
}
int CppFaceLandmarkerClose(void* landmarker, char** error_msg) {
auto cpp_landmarker = static_cast<FaceLandmarker*>(landmarker);
auto result = cpp_landmarker->Close();
if (!result.ok()) {
ABSL_LOG(ERROR) << "Failed to close FaceLandmarker: " << result;
return CppProcessError(result, error_msg);
}
delete cpp_landmarker;
return 0;
}
} // namespace mediapipe::tasks::c::vision::face_landmarker
extern "C" {
void* face_landmarker_create(struct FaceLandmarkerOptions* options,
char** error_msg) {
return mediapipe::tasks::c::vision::face_landmarker::CppFaceLandmarkerCreate(
*options, error_msg);
}
int face_landmarker_detect_image(void* landmarker, const MpImage& image,
FaceLandmarkerResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::face_landmarker::CppFaceLandmarkerDetect(
landmarker, image, result, error_msg);
}
int face_landmarker_detect_for_video(void* landmarker, const MpImage& image,
int64_t timestamp_ms,
FaceLandmarkerResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::face_landmarker::
CppFaceLandmarkerDetectForVideo(landmarker, image, timestamp_ms, result,
error_msg);
}
int face_landmarker_detect_async(void* landmarker, const MpImage& image,
int64_t timestamp_ms, char** error_msg) {
return mediapipe::tasks::c::vision::face_landmarker::
CppFaceLandmarkerDetectAsync(landmarker, image, timestamp_ms, error_msg);
}
void face_landmarker_close_result(FaceLandmarkerResult* result) {
mediapipe::tasks::c::vision::face_landmarker::CppFaceLandmarkerCloseResult(
result);
}
int face_landmarker_close(void* landmarker, char** error_ms) {
return mediapipe::tasks::c::vision::face_landmarker::CppFaceLandmarkerClose(
landmarker, error_ms);
}
} // extern "C"

View File

@ -1,156 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_VISION_FACE_LANDMARKER_FACE_LANDMARKER_H_
#define MEDIAPIPE_TASKS_C_VISION_FACE_LANDMARKER_FACE_LANDMARKER_H_
#include "mediapipe/tasks/c/core/base_options.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif // MP_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
// The options for configuring a MediaPipe face landmarker task.
struct FaceLandmarkerOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options;
// The running mode of the task. Default to the image mode.
// FaceLandmarker has three running modes:
// 1) The image mode for recognizing face landmarks on single image inputs.
// 2) The video mode for recognizing face landmarks on the decoded frames of a
// video.
// 3) The live stream mode for recognizing face landmarks on the live stream
// of input data, such as from camera. In this mode, the "result_callback"
// below must be specified to receive the detection results asynchronously.
RunningMode running_mode;
// The maximum number of faces can be detected by the FaceLandmarker.
int num_faces = 1;
// The minimum confidence score for the face detection to be considered
// successful.
float min_face_detection_confidence = 0.5;
// The minimum confidence score of face presence score in the face landmark
// detection.
float min_face_presence_confidence = 0.5;
// The minimum confidence score for the face tracking to be considered
// successful.
float min_tracking_confidence = 0.5;
// Whether FaceLandmarker outputs face blendshapes classification. Face
// blendshapes are used for rendering the 3D face model.
bool output_face_blendshapes = false;
// Whether FaceLandmarker outputs facial transformation_matrix. Facial
// transformation matrix is used to transform the face landmarks in canonical
// face to the detected face, so that users can apply face effects on the
// detected landmarks.
bool output_facial_transformation_matrixes = false;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. Arguments of the callback function include:
// the pointer to recognition result, the image that result was obtained
// on, the timestamp relevant to recognition results and pointer to error
// message in case of any failure. The validity of the passed arguments is
// true for the lifetime of the callback function.
//
// A caller is responsible for closing face landmarker result.
typedef void (*result_callback_fn)(const FaceLandmarkerResult* result,
const MpImage& image, int64_t timestamp_ms,
char* error_msg);
result_callback_fn result_callback;
};
// Creates an FaceLandmarker from the provided `options`.
// Returns a pointer to the face landmarker on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT void* face_landmarker_create(struct FaceLandmarkerOptions* options,
char** error_msg);
// Performs face landmark detection on the input `image`. Returns `0` on
// success. If an error occurs, returns an error code and sets the error
// parameter to an an error message (if `error_msg` is not `nullptr`). You must
// free the memory allocated for the error message.
MP_EXPORT int face_landmarker_detect_image(void* landmarker,
const MpImage& image,
FaceLandmarkerResult* result,
char** error_msg);
// Performs face landmark detection on the provided video frame.
// Only use this method when the FaceLandmarker is created with the video
// running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int face_landmarker_detect_for_video(void* landmarker,
const MpImage& image,
int64_t timestamp_ms,
FaceLandmarkerResult* result,
char** error_msg);
// Sends live image data to face landmark detection, and the results will be
// available via the `result_callback` provided in the FaceLandmarkerOptions.
// Only use this method when the FaceLandmarker is created with the live
// stream running mode.
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the face landmarker. The input timestamps must be monotonically
// increasing.
// The `result_callback` provides:
// - The recognition results as an FaceLandmarkerResult object.
// - The const reference to the corresponding input image that the face
// landmarker runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int face_landmarker_detect_async(void* landmarker,
const MpImage& image,
int64_t timestamp_ms,
char** error_msg);
// Frees the memory allocated inside a FaceLandmarkerResult result.
// Does not free the result pointer itself.
MP_EXPORT void face_landmarker_close_result(FaceLandmarkerResult* result);
// Frees face landmarker.
// If an error occurs, returns an error code and sets the error parameter to an
// an error message (if `error_msg` is not `nullptr`). You must free the memory
// allocated for the error message.
MP_EXPORT int face_landmarker_close(void* landmarker, char** error_msg);
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_VISION_FACE_LANDMARKER_FACE_LANDMARKER_H_

View File

@ -1,59 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_VISION_FACE_LANDMARKER_RESULT_FACE_LANDMARKER_RESULT_H_
#define MEDIAPIPE_TASKS_C_VISION_FACE_LANDMARKER_RESULT_FACE_LANDMARKER_RESULT_H_
#include <cstdint>
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/landmark.h"
#include "mediapipe/tasks/c/components/containers/matrix.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif // MP_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
// The hand landmarker result from HandLandmarker, where each vector
// element represents a single hand detected in the image.
struct FaceLandmarkerResult {
// Detected face landmarks in normalized image coordinates.
struct NormalizedLandmarks* face_landmarks;
// The number of elements in the face_landmarks array.
uint32_t face_landmarks_count;
// Optional face blendshapes results.
struct Categories* face_blendshapes;
// The number of elements in the face_blendshapes array.
uint32_t face_blendshapes_count;
// Optional facial transformation matrixes.
struct Matrix* facial_transformation_matrixes;
// The number of elements in the facial_transformation_matrixes array.
uint32_t facial_transformation_matrixes_count;
};
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_VISION_FACE_LANDMARKER_RESULT_FACE_LANDMARKER_RESULT_H_

View File

@ -1,117 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result_converter.h"
#include <cstdint>
#include <vector>
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/category_converter.h"
#include "mediapipe/tasks/c/components/containers/landmark.h"
#include "mediapipe/tasks/c/components/containers/landmark_converter.h"
#include "mediapipe/tasks/c/components/containers/matrix.h"
#include "mediapipe/tasks/c/components/containers/matrix_converter.h"
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h"
namespace mediapipe::tasks::c::components::containers {
using CppCategory = ::mediapipe::tasks::components::containers::Category;
using CppNormalizedLandmark =
::mediapipe::tasks::components::containers::NormalizedLandmark;
void CppConvertToFaceLandmarkerResult(
const ::mediapipe::tasks::vision::face_landmarker::FaceLandmarkerResult& in,
FaceLandmarkerResult* out) {
out->face_landmarks_count = in.face_landmarks.size();
out->face_landmarks = new NormalizedLandmarks[out->face_landmarks_count];
for (uint32_t i = 0; i < out->face_landmarks_count; ++i) {
std::vector<CppNormalizedLandmark> cpp_normalized_landmarks;
for (uint32_t j = 0; j < in.face_landmarks[i].landmarks.size(); ++j) {
const auto& cpp_landmark = in.face_landmarks[i].landmarks[j];
cpp_normalized_landmarks.push_back(cpp_landmark);
}
CppConvertToNormalizedLandmarks(cpp_normalized_landmarks,
&out->face_landmarks[i]);
}
if (in.face_blendshapes.has_value()) {
out->face_blendshapes_count = in.face_blendshapes->size();
out->face_blendshapes = new Categories[out->face_blendshapes_count];
for (uint32_t i = 0; i < out->face_blendshapes_count; ++i) {
uint32_t categories_count =
in.face_blendshapes.value()[i].categories.size();
out->face_blendshapes[i].categories_count = categories_count;
out->face_blendshapes[i].categories = new Category[categories_count];
for (uint32_t j = 0; j < categories_count; ++j) {
const auto& cpp_category = in.face_blendshapes.value()[i].categories[j];
CppConvertToCategory(cpp_category,
&out->face_blendshapes[i].categories[j]);
}
}
} else {
out->face_blendshapes_count = 0;
out->face_blendshapes = nullptr;
}
if (in.facial_transformation_matrixes.has_value()) {
out->facial_transformation_matrixes_count =
in.facial_transformation_matrixes.value().size();
out->facial_transformation_matrixes =
new ::Matrix[out->facial_transformation_matrixes_count];
for (uint32_t i = 0; i < out->facial_transformation_matrixes_count; ++i) {
CppConvertToMatrix(in.facial_transformation_matrixes.value()[i],
&out->facial_transformation_matrixes[i]);
}
} else {
out->facial_transformation_matrixes_count = 0;
out->facial_transformation_matrixes = nullptr;
}
}
void CppCloseFaceLandmarkerResult(FaceLandmarkerResult* result) {
for (uint32_t i = 0; i < result->face_blendshapes_count; ++i) {
for (uint32_t j = 0; j < result->face_blendshapes[i].categories_count;
++j) {
CppCloseCategory(&result->face_blendshapes[i].categories[j]);
}
delete[] result->face_blendshapes[i].categories;
}
delete[] result->face_blendshapes;
for (uint32_t i = 0; i < result->face_landmarks_count; ++i) {
CppCloseNormalizedLandmarks(&result->face_landmarks[i]);
}
delete[] result->face_landmarks;
for (uint32_t i = 0; i < result->facial_transformation_matrixes_count; ++i) {
CppCloseMatrix(&result->facial_transformation_matrixes[i]);
}
delete[] result->facial_transformation_matrixes;
result->face_blendshapes_count = 0;
result->face_landmarks_count = 0;
result->facial_transformation_matrixes_count = 0;
result->face_blendshapes = nullptr;
result->face_landmarks = nullptr;
result->facial_transformation_matrixes = nullptr;
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -1,32 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_FACE_LANDMARKER_RESULT_CONVERTER_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_FACE_LANDMARKER_RESULT_CONVERTER_H_
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h"
namespace mediapipe::tasks::c::components::containers {
void CppConvertToFaceLandmarkerResult(
const mediapipe::tasks::vision::face_landmarker::FaceLandmarkerResult& in,
FaceLandmarkerResult* out);
void CppCloseFaceLandmarkerResult(FaceLandmarkerResult* result);
} // namespace mediapipe::tasks::c::components::containers
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_FACE_LANDMARKER_RESULT_CONVERTER_H_

View File

@ -1,154 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result_converter.h"
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "Eigen/Core"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/landmark.h"
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/landmark.h"
#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h"
namespace mediapipe::tasks::c::components::containers {
void InitFaceLandmarkerResult(
::mediapipe::tasks::vision::face_landmarker::FaceLandmarkerResult*
cpp_result) {
// Initialize face_landmarks
mediapipe::tasks::components::containers::NormalizedLandmark
cpp_normalized_landmark = {/* x= */ 0.1f, /* y= */ 0.2f, /* z= */ 0.3f};
mediapipe::tasks::components::containers::NormalizedLandmarks
cpp_normalized_landmarks;
cpp_normalized_landmarks.landmarks.push_back(cpp_normalized_landmark);
cpp_result->face_landmarks.push_back(cpp_normalized_landmarks);
// Initialize face_blendshapes
mediapipe::tasks::components::containers::Category cpp_category = {
/* index= */ 1,
/* score= */ 0.8f,
/* category_name= */ "blendshape_label_1",
/* display_name= */ "blendshape_display_name_1"};
mediapipe::tasks::components::containers::Classifications
classifications_for_blendshapes;
classifications_for_blendshapes.categories.push_back(cpp_category);
cpp_result->face_blendshapes =
std::vector<mediapipe::tasks::components::containers::Classifications>{
classifications_for_blendshapes};
cpp_result->face_blendshapes->push_back(classifications_for_blendshapes);
// Initialize facial_transformation_matrixes
Eigen::MatrixXf cpp_matrix(2, 2);
cpp_matrix << 1.0f, 2.0f, 3.0f, 4.0f;
cpp_result->facial_transformation_matrixes = std::vector<Matrix>{cpp_matrix};
}
TEST(FaceLandmarkerResultConverterTest, ConvertsCustomResult) {
// Initialize a C++ FaceLandmarkerResult
::mediapipe::tasks::vision::face_landmarker::FaceLandmarkerResult cpp_result;
InitFaceLandmarkerResult(&cpp_result);
FaceLandmarkerResult c_result;
CppConvertToFaceLandmarkerResult(cpp_result, &c_result);
// Verify conversion of face_landmarks
EXPECT_EQ(c_result.face_landmarks_count, cpp_result.face_landmarks.size());
for (uint32_t i = 0; i < c_result.face_landmarks_count; ++i) {
EXPECT_EQ(c_result.face_landmarks[i].landmarks_count,
cpp_result.face_landmarks[i].landmarks.size());
for (uint32_t j = 0; j < c_result.face_landmarks[i].landmarks_count; ++j) {
const auto& cpp_landmark = cpp_result.face_landmarks[i].landmarks[j];
EXPECT_FLOAT_EQ(c_result.face_landmarks[i].landmarks[j].x,
cpp_landmark.x);
EXPECT_FLOAT_EQ(c_result.face_landmarks[i].landmarks[j].y,
cpp_landmark.y);
EXPECT_FLOAT_EQ(c_result.face_landmarks[i].landmarks[j].z,
cpp_landmark.z);
}
}
// Verify conversion of face_blendshapes
EXPECT_EQ(c_result.face_blendshapes_count,
cpp_result.face_blendshapes.value().size());
for (uint32_t i = 0; i < c_result.face_blendshapes_count; ++i) {
const auto& cpp_face_blendshapes = cpp_result.face_blendshapes.value();
EXPECT_EQ(c_result.face_blendshapes[i].categories_count,
cpp_face_blendshapes[i].categories.size());
for (uint32_t j = 0; j < c_result.face_blendshapes[i].categories_count;
++j) {
const auto& cpp_category = cpp_face_blendshapes[i].categories[j];
EXPECT_EQ(c_result.face_blendshapes[i].categories[j].index,
cpp_category.index);
EXPECT_FLOAT_EQ(c_result.face_blendshapes[i].categories[j].score,
cpp_category.score);
EXPECT_EQ(
std::string(c_result.face_blendshapes[i].categories[j].category_name),
cpp_category.category_name);
}
}
// Verify conversion of facial_transformation_matrixes
EXPECT_EQ(c_result.facial_transformation_matrixes_count,
cpp_result.facial_transformation_matrixes.value().size());
for (uint32_t i = 0; i < c_result.facial_transformation_matrixes_count; ++i) {
const auto& cpp_facial_transformation_matrixes =
cpp_result.facial_transformation_matrixes.value();
// Assuming Matrix struct contains data array and dimensions
const auto& cpp_matrix = cpp_facial_transformation_matrixes[i];
EXPECT_EQ(c_result.facial_transformation_matrixes[i].rows,
cpp_matrix.rows());
EXPECT_EQ(c_result.facial_transformation_matrixes[i].cols,
cpp_matrix.cols());
// Check each element of the matrix
for (int32_t row = 0; row < cpp_matrix.rows(); ++row) {
for (int32_t col = 0; col < cpp_matrix.cols(); ++col) {
size_t index = col * cpp_matrix.rows() + row; // Column-major index
EXPECT_FLOAT_EQ(c_result.facial_transformation_matrixes[i].data[index],
cpp_matrix(row, col));
}
}
}
CppCloseFaceLandmarkerResult(&c_result);
}
TEST(FaceLandmarkerResultConverterTest, FreesMemory) {
::mediapipe::tasks::vision::face_landmarker::FaceLandmarkerResult cpp_result;
InitFaceLandmarkerResult(&cpp_result);
FaceLandmarkerResult c_result;
CppConvertToFaceLandmarkerResult(cpp_result, &c_result);
EXPECT_NE(c_result.face_blendshapes, nullptr);
EXPECT_NE(c_result.face_landmarks, nullptr);
EXPECT_NE(c_result.facial_transformation_matrixes, nullptr);
CppCloseFaceLandmarkerResult(&c_result);
EXPECT_EQ(c_result.face_blendshapes, nullptr);
EXPECT_EQ(c_result.face_landmarks, nullptr);
EXPECT_EQ(c_result.facial_transformation_matrixes, nullptr);
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -1,292 +0,0 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker.h"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/landmark.h"
#include "mediapipe/tasks/c/vision/core/common.h"
#include "mediapipe/tasks/c/vision/face_landmarker/face_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kModelName[] = "face_landmarker_v2_with_blendshapes.task";
constexpr char kImageFile[] = "portrait.jpg";
constexpr float kLandmarksPrecision = 0.03;
constexpr float kBlendshapesPrecision = 0.12;
constexpr float kFacialTransformationMatrixPrecision = 0.05;
constexpr int kIterations = 100;
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
void AssertHandLandmarkerResult(const FaceLandmarkerResult* result,
const float blendshapes_precision,
const float landmark_precision,
const float matrix_precison) {
// Expects to have the same number of faces detected.
EXPECT_EQ(result->face_blendshapes_count, 1);
// Actual blendshapes matches expected blendshapes.
EXPECT_EQ(
std::string{result->face_blendshapes[0].categories[0].category_name},
"_neutral");
EXPECT_NEAR(result->face_blendshapes[0].categories[0].score, 0.0f,
blendshapes_precision);
// Actual landmarks match expected landmarks.
EXPECT_NEAR(result->face_landmarks[0].landmarks[0].x, 0.4977f,
landmark_precision);
EXPECT_NEAR(result->face_landmarks[0].landmarks[0].y, 0.2485f,
landmark_precision);
EXPECT_NEAR(result->face_landmarks[0].landmarks[0].z, -0.0305f,
landmark_precision);
// Expects to have at least one facial transformation matrix.
EXPECT_GE(result->facial_transformation_matrixes_count, 1);
// Actual matrix matches expected matrix.
// Assuming the expected matrix is 2x2 for demonstration.
const float expected_matrix[4] = {0.9991f, 0.0166f, -0.0374f, 0.0f};
for (int i = 0; i < 4; ++i) {
printf(">> %f <<", result->facial_transformation_matrixes[0].data[i]);
EXPECT_NEAR(result->facial_transformation_matrixes[0].data[i],
expected_matrix[i], matrix_precison);
}
}
TEST(FaceLandmarkerTest, ImageModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
FaceLandmarkerOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* num_faces= */ 1,
/* min_face_detection_confidence= */ 0.5,
/* min_face_presence_confidence= */ 0.5,
/* min_tracking_confidence= */ 0.5,
/* output_face_blendshapes = */ true,
/* output_facial_transformation_matrixes = */ true,
};
void* landmarker = face_landmarker_create(&options, /* error_msg */ nullptr);
EXPECT_NE(landmarker, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
FaceLandmarkerResult result;
face_landmarker_detect_image(landmarker, mp_image, &result,
/* error_msg */ nullptr);
AssertHandLandmarkerResult(&result, kBlendshapesPrecision,
kLandmarksPrecision,
kFacialTransformationMatrixPrecision);
face_landmarker_close_result(&result);
face_landmarker_close(landmarker, /* error_msg */ nullptr);
}
TEST(FaceLandmarkerTest, VideoModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
FaceLandmarkerOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::VIDEO,
/* num_faces= */ 1,
/* min_face_detection_confidence= */ 0.5,
/* min_face_presence_confidence= */ 0.5,
/* min_tracking_confidence= */ 0.5,
/* output_face_blendshapes = */ true,
/* output_facial_transformation_matrixes = */ true,
};
void* landmarker = face_landmarker_create(&options,
/* error_msg */ nullptr);
EXPECT_NE(landmarker, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
FaceLandmarkerResult result;
face_landmarker_detect_for_video(landmarker, mp_image, i, &result,
/* error_msg */ nullptr);
AssertHandLandmarkerResult(&result, kBlendshapesPrecision,
kLandmarksPrecision,
kFacialTransformationMatrixPrecision);
face_landmarker_close_result(&result);
}
face_landmarker_close(landmarker, /* error_msg */ nullptr);
}
// A structure to support LiveStreamModeTest below. This structure holds a
// static method `Fn` for a callback function of C API. A `static` qualifier
// allows to take an address of the method to follow API style. Another static
// struct member is `last_timestamp` that is used to verify that current
// timestamp is greater than the previous one.
struct LiveStreamModeCallback {
static int64_t last_timestamp;
static void Fn(const FaceLandmarkerResult* landmarker_result,
const MpImage& image, int64_t timestamp, char* error_msg) {
ASSERT_NE(landmarker_result, nullptr);
ASSERT_EQ(error_msg, nullptr);
AssertHandLandmarkerResult(landmarker_result, kBlendshapesPrecision,
kLandmarksPrecision,
kFacialTransformationMatrixPrecision);
EXPECT_GT(image.image_frame.width, 0);
EXPECT_GT(image.image_frame.height, 0);
EXPECT_GT(timestamp, last_timestamp);
++last_timestamp;
}
};
int64_t LiveStreamModeCallback::last_timestamp = -1;
TEST(FaceLandmarkerTest, LiveStreamModeTest) {
const auto image = DecodeImageFromFile(GetFullPath(kImageFile));
ASSERT_TRUE(image.ok());
const std::string model_path = GetFullPath(kModelName);
FaceLandmarkerOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::LIVE_STREAM,
/* num_faces= */ 1,
/* min_face_detection_confidence= */ 0.5,
/* min_face_presence_confidence= */ 0.5,
/* min_tracking_confidence= */ 0.5,
/* output_face_blendshapes = */ true,
/* output_facial_transformation_matrixes = */ true,
/* result_callback= */ LiveStreamModeCallback::Fn,
};
void* landmarker = face_landmarker_create(&options, /* error_msg */
nullptr);
EXPECT_NE(landmarker, nullptr);
const auto& image_frame = image->GetImageFrameSharedPtr();
const MpImage mp_image = {
.type = MpImage::IMAGE_FRAME,
.image_frame = {.format = static_cast<ImageFormat>(image_frame->Format()),
.image_buffer = image_frame->PixelData(),
.width = image_frame->Width(),
.height = image_frame->Height()}};
for (int i = 0; i < kIterations; ++i) {
EXPECT_GE(face_landmarker_detect_async(landmarker, mp_image, i,
/* error_msg */ nullptr),
0);
}
face_landmarker_close(landmarker, /* error_msg */ nullptr);
// Due to the flow limiter, the total of outputs might be smaller than the
// number of iterations.
EXPECT_LE(LiveStreamModeCallback::last_timestamp, kIterations);
EXPECT_GT(LiveStreamModeCallback::last_timestamp, 0);
}
TEST(FaceLandmarkerTest, InvalidArgumentHandling) {
// It is an error to set neither the asset buffer nor the path.
FaceLandmarkerOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
/* running_mode= */ RunningMode::IMAGE,
/* num_faces= */ 1,
/* min_face_detection_confidence= */ 0.5,
/* min_face_presence_confidence= */ 0.5,
/* min_tracking_confidence= */ 0.5,
/* output_face_blendshapes = */ true,
/* output_facial_transformation_matrixes = */ true,
};
char* error_msg;
void* landmarker = face_landmarker_create(&options, &error_msg);
EXPECT_EQ(landmarker, nullptr);
EXPECT_THAT(
error_msg,
HasSubstr("INVALID_ARGUMENT: BLENDSHAPES Tag and blendshapes model must "
"be both set. Get BLENDSHAPES is set: true, blendshapes model "
"is set: false [MediaPipeTasksStatus='601']"));
free(error_msg);
}
TEST(FaceLandmarkerTest, FailedRecognitionHandling) {
const std::string model_path = GetFullPath(kModelName);
FaceLandmarkerOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* num_faces= */ 1,
/* min_face_detection_confidence= */ 0.5,
/* min_face_presence_confidence= */ 0.5,
/* min_tracking_confidence= */ 0.5,
/* output_face_blendshapes = */ true,
/* output_facial_transformation_matrixes = */ true,
};
void* landmarker = face_landmarker_create(&options, /* error_msg */
nullptr);
EXPECT_NE(landmarker, nullptr);
const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
FaceLandmarkerResult result;
char* error_msg;
face_landmarker_detect_image(landmarker, mp_image, &result, &error_msg);
EXPECT_THAT(error_msg, HasSubstr("GPU Buffer not supported yet"));
free(error_msg);
face_landmarker_close(landmarker, /* error_msg */ nullptr);
}
} // namespace

View File

@ -68,7 +68,7 @@ struct HandLandmarkerOptions {
// true for the lifetime of the callback function. // true for the lifetime of the callback function.
// //
// A caller is responsible for closing hand landmarker result. // A caller is responsible for closing hand landmarker result.
typedef void (*result_callback_fn)(const HandLandmarkerResult* result, typedef void (*result_callback_fn)(HandLandmarkerResult* result,
const MpImage& image, int64_t timestamp_ms, const MpImage& image, int64_t timestamp_ms,
char* error_msg); char* error_msg);
result_callback_fn result_callback; result_callback_fn result_callback;

View File

@ -47,7 +47,7 @@ std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name); return JoinPath("./", kTestDataDirectory, file_name);
} }
void AssertHandLandmarkerResult(const HandLandmarkerResult* result, void MatchesHandLandmarkerResult(HandLandmarkerResult* result,
const float score_precision, const float score_precision,
const float landmark_precision) { const float landmark_precision) {
// Expects to have the same number of hands detected. // Expects to have the same number of hands detected.
@ -104,7 +104,7 @@ TEST(HandLandmarkerTest, ImageModeTest) {
HandLandmarkerResult result; HandLandmarkerResult result;
hand_landmarker_detect_image(landmarker, mp_image, &result, hand_landmarker_detect_image(landmarker, mp_image, &result,
/* error_msg */ nullptr); /* error_msg */ nullptr);
AssertHandLandmarkerResult(&result, kScorePrecision, kLandmarkPrecision); MatchesHandLandmarkerResult(&result, kScorePrecision, kLandmarkPrecision);
hand_landmarker_close_result(&result); hand_landmarker_close_result(&result);
hand_landmarker_close(landmarker, /* error_msg */ nullptr); hand_landmarker_close(landmarker, /* error_msg */ nullptr);
} }
@ -141,7 +141,7 @@ TEST(HandLandmarkerTest, VideoModeTest) {
hand_landmarker_detect_for_video(landmarker, mp_image, i, &result, hand_landmarker_detect_for_video(landmarker, mp_image, i, &result,
/* error_msg */ nullptr); /* error_msg */ nullptr);
AssertHandLandmarkerResult(&result, kScorePrecision, kLandmarkPrecision); MatchesHandLandmarkerResult(&result, kScorePrecision, kLandmarkPrecision);
hand_landmarker_close_result(&result); hand_landmarker_close_result(&result);
} }
hand_landmarker_close(landmarker, /* error_msg */ nullptr); hand_landmarker_close(landmarker, /* error_msg */ nullptr);
@ -154,11 +154,11 @@ TEST(HandLandmarkerTest, VideoModeTest) {
// timestamp is greater than the previous one. // timestamp is greater than the previous one.
struct LiveStreamModeCallback { struct LiveStreamModeCallback {
static int64_t last_timestamp; static int64_t last_timestamp;
static void Fn(const HandLandmarkerResult* landmarker_result, static void Fn(HandLandmarkerResult* landmarker_result, const MpImage& image,
const MpImage& image, int64_t timestamp, char* error_msg) { int64_t timestamp, char* error_msg) {
ASSERT_NE(landmarker_result, nullptr); ASSERT_NE(landmarker_result, nullptr);
ASSERT_EQ(error_msg, nullptr); ASSERT_EQ(error_msg, nullptr);
AssertHandLandmarkerResult(landmarker_result, kScorePrecision, MatchesHandLandmarkerResult(landmarker_result, kScorePrecision,
kLandmarkPrecision); kLandmarkPrecision);
EXPECT_GT(image.image_frame.width, 0); EXPECT_GT(image.image_frame.width, 0);
EXPECT_GT(image.image_frame.height, 0); EXPECT_GT(image.image_frame.height, 0);
@ -183,7 +183,7 @@ TEST(HandLandmarkerTest, LiveStreamModeTest) {
/* min_hand_detection_confidence= */ 0.5, /* min_hand_detection_confidence= */ 0.5,
/* min_hand_presence_confidence= */ 0.5, /* min_hand_presence_confidence= */ 0.5,
/* min_tracking_confidence= */ 0.5, /* min_tracking_confidence= */ 0.5,
/* result_callback_fn= */ LiveStreamModeCallback::Fn, /* result_callback= */ LiveStreamModeCallback::Fn,
}; };
void* landmarker = hand_landmarker_create(&options, /* error_msg */ nullptr); void* landmarker = hand_landmarker_create(&options, /* error_msg */ nullptr);

View File

@ -121,8 +121,6 @@ class StableDiffusionIterateCalculator : public Node {
if (handle_) dlclose(handle_); if (handle_) dlclose(handle_);
} }
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
@ -190,11 +188,6 @@ class StableDiffusionIterateCalculator : public Node {
bool emit_empty_packet_; bool emit_empty_packet_;
}; };
absl::Status StableDiffusionIterateCalculator::UpdateContract(
CalculatorContract* cc) {
return absl::OkStatus();
}
absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) { absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
StableDiffusionIterateCalculatorOptions options; StableDiffusionIterateCalculatorOptions options;
if (kOptionsIn(cc).IsEmpty()) { if (kOptionsIn(cc).IsEmpty()) {
@ -212,11 +205,7 @@ absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
if (options.file_folder().empty()) { if (options.file_folder().empty()) {
std::strcpy(config.model_dir, "bins/"); // NOLINT std::strcpy(config.model_dir, "bins/"); // NOLINT
} else { } else {
std::string file_folder = options.file_folder(); std::strcpy(config.model_dir, options.file_folder().c_str()); // NOLINT
if (!file_folder.empty() && file_folder.back() != '/') {
file_folder.push_back('/');
}
std::strcpy(config.model_dir, file_folder.c_str()); // NOLINT
} }
MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir)) MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir))
<< config.model_dir; << config.model_dir;

View File

@ -59,7 +59,6 @@ CALCULATORS_AND_GRAPHS = [
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarker_graph",
] ]
strip_api_include_path_prefix( strip_api_include_path_prefix(
@ -107,9 +106,6 @@ strip_api_include_path_prefix(
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetector.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorOptions.h",
"//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h", "//mediapipe/tasks/ios/vision/object_detector:sources/MPPObjectDetectorResult.h",
"//mediapipe/tasks/ios/vision/pose_landmarker:sources/MPPPoseLandmarker.h",
"//mediapipe/tasks/ios/vision/pose_landmarker:sources/MPPPoseLandmarkerOptions.h",
"//mediapipe/tasks/ios/vision/pose_landmarker:sources/MPPPoseLandmarkerResult.h",
], ],
) )
@ -210,9 +206,6 @@ apple_static_xcframework(
":MPPObjectDetector.h", ":MPPObjectDetector.h",
":MPPObjectDetectorOptions.h", ":MPPObjectDetectorOptions.h",
":MPPObjectDetectorResult.h", ":MPPObjectDetectorResult.h",
":MPPPoseLandmarker.h",
":MPPPoseLandmarkerOptions.h",
":MPPPoseLandmarkerResult.h",
], ],
deps = [ deps = [
"//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetector", "//mediapipe/tasks/ios/vision/face_detector:MPPFaceDetector",
@ -222,7 +215,6 @@ apple_static_xcframework(
"//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier", "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifier",
"//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenter", "//mediapipe/tasks/ios/vision/image_segmenter:MPPImageSegmenter",
"//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector", "//mediapipe/tasks/ios/vision/object_detector:MPPObjectDetector",
"//mediapipe/tasks/ios/vision/pose_landmarker:MPPPoseLandmarker",
], ],
) )

View File

@ -30,20 +30,6 @@ namespace {
using ::mediapipe::ImageFormat; using ::mediapipe::ImageFormat;
using ::mediapipe::ImageFrame; using ::mediapipe::ImageFrame;
vImage_Buffer CreateEmptyVImageBufferFromImageFrame(ImageFrame &imageFrame, bool shouldAllocate) {
UInt8 *data = shouldAllocate ? new UInt8[imageFrame.Height() * imageFrame.WidthStep()] : nullptr;
return {.data = data,
.height = static_cast<vImagePixelCount>(imageFrame.Height()),
.width = static_cast<vImagePixelCount>(imageFrame.Width()),
.rowBytes = static_cast<size_t>(imageFrame.WidthStep())};
}
vImage_Buffer CreateVImageBufferFromImageFrame(ImageFrame &imageFrame) {
vImage_Buffer imageBuffer = CreateEmptyVImageBufferFromImageFrame(imageFrame, false);
imageBuffer.data = imageFrame.MutablePixelData();
return imageBuffer;
}
vImage_Buffer allocatedVImageBuffer(vImagePixelCount width, vImagePixelCount height, vImage_Buffer allocatedVImageBuffer(vImagePixelCount width, vImagePixelCount height,
size_t rowBytes) { size_t rowBytes) {
UInt8 *data = new UInt8[height * rowBytes]; UInt8 *data = new UInt8[height * rowBytes];
@ -54,8 +40,6 @@ static void FreeDataProviderReleaseCallback(void *buffer, const void *data, size
delete[] (vImage_Buffer *)buffer; delete[] (vImage_Buffer *)buffer;
} }
static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { free(refCon); }
} // namespace } // namespace
@interface MPPPixelDataUtils : NSObject @interface MPPPixelDataUtils : NSObject
@ -67,10 +51,6 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
pixelBufferFormat:(OSType)pixelBufferFormatType pixelBufferFormat:(OSType)pixelBufferFormatType
error:(NSError **)error; error:(NSError **)error;
+ (UInt8 *)pixelDataFromImageFrame:(ImageFrame &)imageFrame
shouldCopy:(BOOL)shouldCopy
error:(NSError **)error;
@end @end
@interface MPPCVPixelBufferUtils : NSObject @interface MPPCVPixelBufferUtils : NSObject
@ -78,24 +58,6 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
+ (std::unique_ptr<ImageFrame>)imageFrameFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer + (std::unique_ptr<ImageFrame>)imageFrameFromCVPixelBuffer:(CVPixelBufferRef)pixelBuffer
error:(NSError **)error; error:(NSError **)error;
// This method is used to create CVPixelBuffer from output images of tasks like `FaceStylizer` only
// when the input `MPImage` source type is `pixelBuffer`.
// Always copies the pixel data of the image frame to the created `CVPixelBuffer`.
//
// The only possible 32 RGBA pixel format of input `CVPixelBuffer` is `kCVPixelFormatType_32BGRA`.
// But Mediapipe does not support inference on images of format `BGRA`. Hence the channels of the
// underlying pixel data of `CVPixelBuffer` are permuted to the supported RGBA format before passing
// them to the task for inference. The pixel format of the output images of any MediaPipe task will
// be the same as the pixel format of the input image. (RGBA in this case).
//
// Since creation of `CVPixelBuffer` from the output image pixels with a format of
// `kCVPixelFormatType_32RGBA` is not possible, the channels of the output C++ image `RGBA` have to
// be permuted to the format `BGRA`. When the pixels are copied to create `CVPixelBuffer` this does
// not pose a challenge.
//
// TODO: Investigate if permuting channels of output `mediapipe::Image` in place is possible for
// creating `CVPixelBuffer`s without copying the underlying pixels.
+ (CVPixelBufferRef)cvPixelBufferFromImageFrame:(ImageFrame &)imageFrame error:(NSError **)error;
@end @end
@interface MPPCGImageUtils : NSObject @interface MPPCGImageUtils : NSObject
@ -137,9 +99,6 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
// Convert the raw pixel data to RGBA format and un-premultiply the alpha from the R, G, B values // Convert the raw pixel data to RGBA format and un-premultiply the alpha from the R, G, B values
// since MediaPipe C++ APIs only accept un pre-multiplied channels. // since MediaPipe C++ APIs only accept un pre-multiplied channels.
//
// This method is commonly used for `MPImage`s of all source types. Hence supporting BGRA and RGBA
// formats. Only `pixelBuffer` source type is restricted to `BGRA` format.
switch (pixelBufferFormatType) { switch (pixelBufferFormatType) {
case kCVPixelFormatType_32RGBA: { case kCVPixelFormatType_32RGBA: {
destBuffer = allocatedVImageBuffer((vImagePixelCount)width, (vImagePixelCount)height, destBuffer = allocatedVImageBuffer((vImagePixelCount)width, (vImagePixelCount)height,
@ -148,8 +107,6 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
break; break;
} }
case kCVPixelFormatType_32BGRA: { case kCVPixelFormatType_32BGRA: {
// Permute channels to `RGBA` since MediaPipe tasks don't support inference on images of
// format `BGRA`.
const uint8_t permute_map[4] = {2, 1, 0, 3}; const uint8_t permute_map[4] = {2, 1, 0, 3};
destBuffer = allocatedVImageBuffer((vImagePixelCount)width, (vImagePixelCount)height, destBuffer = allocatedVImageBuffer((vImagePixelCount)width, (vImagePixelCount)height,
destinationBytesPerRow); destinationBytesPerRow);
@ -163,7 +120,8 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
default: { default: {
[MPPCommonUtils createCustomError:error [MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError withCode:MPPTasksErrorCodeInvalidArgumentError
description:@"Some internal error occured."]; description:@"Invalid source pixel buffer format. Expecting one of "
@"kCVPixelFormatType_32RGBA, kCVPixelFormatType_32BGRA"];
return nullptr; return nullptr;
} }
} }
@ -181,46 +139,6 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
static_cast<uint8 *>(destBuffer.data)); static_cast<uint8 *>(destBuffer.data));
} }
+ (UInt8 *)pixelDataFromImageFrame:(ImageFrame &)imageFrame
shouldCopy:(BOOL)shouldCopy
error:(NSError **)error {
vImage_Buffer sourceBuffer = CreateVImageBufferFromImageFrame(imageFrame);
// Pre-multiply the raw pixels from a `mediapipe::Image` before creating a `CGImage` to ensure
// that pixels are displayed correctly irrespective of their alpha values.
vImage_Error premultiplyError;
vImage_Buffer destinationBuffer;
switch (imageFrame.Format()) {
case ImageFormat::SRGBA: {
destinationBuffer =
shouldCopy ? CreateEmptyVImageBufferFromImageFrame(imageFrame, true) : sourceBuffer;
premultiplyError =
vImagePremultiplyData_RGBA8888(&sourceBuffer, &destinationBuffer, kvImageNoFlags);
break;
}
default: {
[MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInternalError
description:@"An error occured while processing the output image "
@"pixels of the vision task."];
return nullptr;
}
}
if (premultiplyError != kvImageNoError) {
[MPPCommonUtils
createCustomError:error
withCode:MPPTasksErrorCodeInternalError
description:
@"An error occured while processing the output image pixels of the vision task."];
return nullptr;
}
return (UInt8 *)destinationBuffer.data;
}
@end @end
@implementation MPPCVPixelBufferUtils @implementation MPPCVPixelBufferUtils
@ -231,8 +149,7 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
std::unique_ptr<ImageFrame> imageFrame = nullptr; std::unique_ptr<ImageFrame> imageFrame = nullptr;
switch (pixelBufferFormat) { switch (pixelBufferFormat) {
// Core Video only supports pixel data of order BGRA for 32 bit RGBA images. case kCVPixelFormatType_32RGBA:
// Thus other formats like `kCVPixelFormatType_32BGRA` don't need to be accounted for.
case kCVPixelFormatType_32BGRA: { case kCVPixelFormatType_32BGRA: {
CVPixelBufferLockBaseAddress(pixelBuffer, 0); CVPixelBufferLockBaseAddress(pixelBuffer, 0);
imageFrame = [MPPPixelDataUtils imageFrame = [MPPPixelDataUtils
@ -248,58 +165,15 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
default: { default: {
[MPPCommonUtils createCustomError:error [MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError withCode:MPPTasksErrorCodeInvalidArgumentError
description:@"Unsupported pixel format for CVPixelBuffer. Expected " description:@"Unsupported pixel format for CVPixelBuffer. Supported "
@"kCVPixelFormatType_32BGRA"]; @"pixel format types are kCVPixelFormatType_32BGRA and "
@"kCVPixelFormatType_32RGBA"];
} }
} }
return imageFrame; return imageFrame;
} }
+ (CVPixelBufferRef)cvPixelBufferFromImageFrame:(ImageFrame &)imageFrame error:(NSError **)error {
if (imageFrame.Format() != ImageFormat::SRGBA) {
[MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInternalError
description:@"An error occured while creating a CVPixelBuffer from the "
@"output image of the vision task."];
return nullptr;
}
UInt8 *pixelData = [MPPPixelDataUtils pixelDataFromImageFrame:imageFrame
shouldCopy:YES
error:error];
if (!pixelData) {
return nullptr;
}
const uint8_t permute_map[4] = {2, 1, 0, 3};
vImage_Buffer sourceBuffer = CreateEmptyVImageBufferFromImageFrame(imageFrame, NO);
sourceBuffer.data = pixelData;
if (vImagePermuteChannels_ARGB8888(&sourceBuffer, &sourceBuffer, permute_map, kvImageNoFlags) ==
kvImageNoError) {
CVPixelBufferRef outputBuffer;
OSType pixelBufferFormatType = kCVPixelFormatType_32BGRA;
// Since data is copied, pass in a release callback that will be invoked when the pixel buffer
// is destroyed.
if (CVPixelBufferCreateWithBytes(kCFAllocatorDefault, imageFrame.Width(), imageFrame.Height(),
pixelBufferFormatType, pixelData, imageFrame.WidthStep(),
FreeRefConReleaseCallback, pixelData, nullptr,
&outputBuffer) == kCVReturnSuccess) {
return outputBuffer;
}
}
[MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInternalError
description:@"An error occured while creating a CVPixelBuffer from the "
@"output image of the vision task."];
return nullptr;
}
@end @end
@implementation MPPCGImageUtils @implementation MPPCGImageUtils
@ -358,14 +232,7 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
CGBitmapInfo bitmapInfo = kCGImageAlphaNoneSkipLast | kCGBitmapByteOrderDefault; CGBitmapInfo bitmapInfo = kCGImageAlphaNoneSkipLast | kCGBitmapByteOrderDefault;
ImageFrame *internalImageFrame = imageFrame.get(); ImageFrame *internalImageFrame = imageFrame.get();
size_t channelCount = 4;
UInt8 *pixelData = [MPPPixelDataUtils pixelDataFromImageFrame:*internalImageFrame
shouldCopy:shouldCopyPixelData
error:error];
if (!pixelData) {
return nullptr;
}
switch (internalImageFrame->Format()) { switch (internalImageFrame->Format()) {
case ImageFormat::SRGBA: { case ImageFormat::SRGBA: {
@ -375,41 +242,56 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
default: default:
[MPPCommonUtils createCustomError:error [MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInternalError withCode:MPPTasksErrorCodeInternalError
description:@"An error occured while creating a CGImage from the " description:@"An internal error occured."];
@"output image of the vision task."];
return nullptr; return nullptr;
} }
size_t bitsPerComponent = 8;
vImage_Buffer sourceBuffer = {
.data = (void *)internalImageFrame->MutablePixelData(),
.height = static_cast<vImagePixelCount>(internalImageFrame->Height()),
.width = static_cast<vImagePixelCount>(internalImageFrame->Width()),
.rowBytes = static_cast<size_t>(internalImageFrame->WidthStep())};
vImage_Buffer destBuffer;
CGDataProviderReleaseDataCallback callback = nullptr; CGDataProviderReleaseDataCallback callback = nullptr;
if (shouldCopyPixelData) {
destBuffer = allocatedVImageBuffer(static_cast<vImagePixelCount>(internalImageFrame->Width()),
static_cast<vImagePixelCount>(internalImageFrame->Height()),
static_cast<size_t>(internalImageFrame->WidthStep()));
callback = FreeDataProviderReleaseCallback;
} else {
destBuffer = sourceBuffer;
}
// Pre-multiply the raw pixels from a `mediapipe::Image` before creating a `CGImage` to ensure
// that pixels are displayed correctly irrespective of their alpha values.
vImage_Error premultiplyError =
vImagePremultiplyData_RGBA8888(&sourceBuffer, &destBuffer, kvImageNoFlags);
if (premultiplyError != kvImageNoError) {
[MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInternalError
description:@"An internal error occured."];
return nullptr;
}
CGDataProviderRef provider = CGDataProviderCreateWithData( CGDataProviderRef provider = CGDataProviderCreateWithData(
pixelData, pixelData, internalImageFrame->WidthStep() * internalImageFrame->Height(), destBuffer.data, destBuffer.data,
callback); internalImageFrame->WidthStep() * internalImageFrame->Height(), callback);
CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB(); CGColorSpaceRef colorSpace = CGColorSpaceCreateDeviceRGB();
CGImageRef cgImageRef =
CGImageRef cgImageRef = nullptr;
if (provider && colorSpace) {
size_t bitsPerComponent = 8;
size_t channelCount = 4;
cgImageRef =
CGImageCreate(internalImageFrame->Width(), internalImageFrame->Height(), bitsPerComponent, CGImageCreate(internalImageFrame->Width(), internalImageFrame->Height(), bitsPerComponent,
bitsPerComponent * channelCount, internalImageFrame->WidthStep(), colorSpace, bitsPerComponent * channelCount, internalImageFrame->WidthStep(), colorSpace,
bitmapInfo, provider, nullptr, YES, kCGRenderingIntentDefault); bitmapInfo, provider, nullptr, YES, kCGRenderingIntentDefault);
}
// Can safely pass `NULL` to these functions according to iOS docs.
CGDataProviderRelease(provider); CGDataProviderRelease(provider);
CGColorSpaceRelease(colorSpace); CGColorSpaceRelease(colorSpace);
if (!cgImageRef) {
[MPPCommonUtils createCustomError:error
withCode:MPPTasksErrorCodeInternalError
description:@"An error occured while converting the output image of the "
@"vision task to a CGImage."];
}
return cgImageRef; return cgImageRef;
} }
@ -465,30 +347,8 @@ static void FreeRefConReleaseCallback(void *refCon, const void *baseAddress) { f
return [self initWithUIImage:image orientation:sourceImage.orientation error:nil]; return [self initWithUIImage:image orientation:sourceImage.orientation error:nil];
} }
case MPPImageSourceTypePixelBuffer: {
if (!shouldCopyPixelData) {
// TODO: Investigate possibility of permuting channels of `mediapipe::Image` returned by
// vision tasks in place to ensure that we can support creating `CVPixelBuffer`s without
// copying the pixel data.
[MPPCommonUtils
createCustomError:error
withCode:MPPTasksErrorCodeInvalidArgumentError
description:
@"When the source type is pixel buffer, you cannot request uncopied data"];
return nil;
}
CVPixelBufferRef pixelBuffer =
[MPPCVPixelBufferUtils cvPixelBufferFromImageFrame:*(image.GetImageFrameSharedPtr())
error:error];
MPPImage *image = [self initWithPixelBuffer:pixelBuffer
orientation:sourceImage.orientation
error:nil];
CVPixelBufferRelease(pixelBuffer);
return image;
}
default: default:
// TODO Implement CMSampleBuffer. // TODO Implement Other Source Types.
return nil; return nil;
} }
} }

View File

@ -82,8 +82,10 @@ NS_SWIFT_NAME(ImageClassifier)
* `.image`. * `.image`.
* *
* This method supports classification of RGBA images. If your `MPImage` has a source type * This method supports classification of RGBA images. If your `MPImage` has a source type
* of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use * ofm`.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must have one of the following
* `kCVPixelFormatType_32BGRA` as its pixel format. * pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
* *
* If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha * If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha
* channel. * channel.
@ -102,9 +104,11 @@ NS_SWIFT_NAME(ImageClassifier)
* of the provided `MPImage`. Only use this method when the `ImageClassifier` is created with * of the provided `MPImage`. Only use this method when the `ImageClassifier` is created with
* running mode, `.image`. * running mode, `.image`.
* *
* This method supports classification of RGBA images. If your `MPImage` has a source type * This method supports classification of RGBA images. If your `MPImage` has a source type of
* of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use * `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must have one of the following
* `kCVPixelFormatType_32BGRA` as its pixel format. * pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
* *
* If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha * If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha
* channel. * channel.
@ -129,9 +133,11 @@ NS_SWIFT_NAME(ImageClassifier)
* It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must
* be monotonically increasing. * be monotonically increasing.
* *
* This method supports classification of RGBA images. If your `MPImage` has a source type * This method supports classification of RGBA images. If your `MPImage` has a source type of
* of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use * `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must have one of the following
* `kCVPixelFormatType_32BGRA` as its pixel format. * pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
* *
* If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha * If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha
* channel. * channel.
@ -155,9 +161,11 @@ NS_SWIFT_NAME(ImageClassifier)
* It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must
* be monotonically increasing. * be monotonically increasing.
* *
* This method supports classification of RGBA images. If your `MPImage` has a source type * This method supports classification of RGBA images. If your `MPImage` has a source type of
* of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use * `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must have one of the following
* `kCVPixelFormatType_32BGRA` as its pixel format. * pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
* *
* If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha * If your `MPImage` has a source type of `.image` ensure that the color space is RGB with an Alpha
* channel. * channel.
@ -191,9 +199,11 @@ NS_SWIFT_NAME(ImageClassifier)
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing. * to the image classifier. The input timestamps must be monotonically increasing.
* *
* This method supports classification of RGBA images. If your `MPImage` has a source type * This method supports classification of RGBA images. If your `MPImage` has a source type of
* of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use * .pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must have one of the following
* `kCVPixelFormatType_32BGRA` as its pixel format. * pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
* *
* If the input `MPImage` has a source type of `.image` ensure that the color space is RGB with an * If the input `MPImage` has a source type of `.image` ensure that the color space is RGB with an
* Alpha channel. * Alpha channel.
@ -228,15 +238,17 @@ NS_SWIFT_NAME(ImageClassifier)
* It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent
* to the image classifier. The input timestamps must be monotonically increasing. * to the image classifier. The input timestamps must be monotonically increasing.
* *
* This method supports classification of RGBA images. If your `MPImage` has a source type * This method supports classification of RGBA images. If your `MPImage` has a source type of
* of `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must use * `.pixelBuffer` or `.sampleBuffer`, the underlying pixel buffer must have one of the following
* `kCVPixelFormatType_32BGRA` as its pixel format. * pixel format types:
* 1. kCVPixelFormatType_32BGRA
* 2. kCVPixelFormatType_32RGBA
* *
* If the input `MPImage` has a source type of `.image` ensure that the color space is RGB with an * If the input `MPImage` has a source type of `.image` ensure that the color space is RGB with an
* Alpha channel. * Alpha channel.
* *
* If this method is used for classifying live camera frames using `AVFoundation`, ensure that you * If this method is used for classifying live camera frames using `AVFoundation`, ensure that you
* request `AVCaptureVideoDataOutput` to output frames in `kCMPixelFormat_32BGRA` using its * request `AVCaptureVideoDataOutput` to output frames in `kCMPixelFormat_32RGBA` using its
* `videoSettings` property. * `videoSettings` property.
* *
* @param image A live stream image data of type `MPImage` on which image classification is to be * @param image A live stream image data of type `MPImage` on which image classification is to be

View File

@ -36,8 +36,7 @@ using PoseLandmarkerGraphOptionsProto =
optionsProto->MutableExtension(PoseLandmarkerGraphOptionsProto::ext); optionsProto->MutableExtension(PoseLandmarkerGraphOptionsProto::ext);
poseLandmarkerGraphOptions->Clear(); poseLandmarkerGraphOptions->Clear();
[self.baseOptions copyToProto:poseLandmarkerGraphOptions->mutable_base_options() [self.baseOptions copyToProto:poseLandmarkerGraphOptions->mutable_base_options()];
withUseStreamMode:self.runningMode != MPPRunningModeImage];
poseLandmarkerGraphOptions->set_min_tracking_confidence(self.minTrackingConfidence); poseLandmarkerGraphOptions->set_min_tracking_confidence(self.minTrackingConfidence);
PoseLandmarksDetectorGraphOptionsProto *poseLandmarksDetectorGraphOptions = PoseLandmarksDetectorGraphOptionsProto *poseLandmarksDetectorGraphOptions =

View File

@ -479,7 +479,7 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
* Sets minimum confidence score for the face landmark detection to be considered successful. * Sets minimum confidence score for the face landmark detection to be considered successful.
* Defaults to 0.5. * Defaults to 0.5.
*/ */
public abstract Builder setMinFacePresenceConfidence(Float value); public abstract Builder setMinFaceLandmarksConfidence(Float value);
/** /**
* The minimum confidence score for the pose detection to be considered successful. Defaults * The minimum confidence score for the pose detection to be considered successful. Defaults
@ -497,7 +497,7 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
* The minimum confidence score for the pose landmarks detection to be considered successful. * The minimum confidence score for the pose landmarks detection to be considered successful.
* Defaults to 0.5. * Defaults to 0.5.
*/ */
public abstract Builder setMinPosePresenceConfidence(Float value); public abstract Builder setMinPoseLandmarksConfidence(Float value);
/** /**
* The minimum confidence score for the hand landmark detection to be considered successful. * The minimum confidence score for the hand landmark detection to be considered successful.
@ -555,13 +555,13 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
abstract Optional<Float> minFaceSuppressionThreshold(); abstract Optional<Float> minFaceSuppressionThreshold();
abstract Optional<Float> minFacePresenceConfidence(); abstract Optional<Float> minFaceLandmarksConfidence();
abstract Optional<Float> minPoseDetectionConfidence(); abstract Optional<Float> minPoseDetectionConfidence();
abstract Optional<Float> minPoseSuppressionThreshold(); abstract Optional<Float> minPoseSuppressionThreshold();
abstract Optional<Float> minPosePresenceConfidence(); abstract Optional<Float> minPoseLandmarksConfidence();
abstract Optional<Float> minHandLandmarksConfidence(); abstract Optional<Float> minHandLandmarksConfidence();
@ -578,10 +578,10 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
.setRunningMode(RunningMode.IMAGE) .setRunningMode(RunningMode.IMAGE)
.setMinFaceDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD) .setMinFaceDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinFaceSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD) .setMinFaceSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD)
.setMinFacePresenceConfidence(DEFAULT_PRESENCE_THRESHOLD) .setMinFaceLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinPoseDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD) .setMinPoseDetectionConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinPoseSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD) .setMinPoseSuppressionThreshold(DEFAULT_SUPPRESION_THRESHOLD)
.setMinPosePresenceConfidence(DEFAULT_PRESENCE_THRESHOLD) .setMinPoseLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setMinHandLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD) .setMinHandLandmarksConfidence(DEFAULT_PRESENCE_THRESHOLD)
.setOutputFaceBlendshapes(DEFAULT_OUTPUT_FACE_BLENDSHAPES) .setOutputFaceBlendshapes(DEFAULT_OUTPUT_FACE_BLENDSHAPES)
.setOutputPoseSegmentationMasks(DEFAULT_OUTPUT_SEGMENTATION_MASKS); .setOutputPoseSegmentationMasks(DEFAULT_OUTPUT_SEGMENTATION_MASKS);
@ -616,12 +616,12 @@ public final class HolisticLandmarker extends BaseVisionTaskApi {
// Configure pose detector options. // Configure pose detector options.
minPoseDetectionConfidence().ifPresent(poseDetectorGraphOptions::setMinDetectionConfidence); minPoseDetectionConfidence().ifPresent(poseDetectorGraphOptions::setMinDetectionConfidence);
minPoseSuppressionThreshold().ifPresent(poseDetectorGraphOptions::setMinSuppressionThreshold); minPoseSuppressionThreshold().ifPresent(poseDetectorGraphOptions::setMinSuppressionThreshold);
minPosePresenceConfidence().ifPresent(poseLandmarkerGraphOptions::setMinDetectionConfidence); minPoseLandmarksConfidence().ifPresent(poseLandmarkerGraphOptions::setMinDetectionConfidence);
// Configure face detector options. // Configure face detector options.
minFaceDetectionConfidence().ifPresent(faceDetectorGraphOptions::setMinDetectionConfidence); minFaceDetectionConfidence().ifPresent(faceDetectorGraphOptions::setMinDetectionConfidence);
minFaceSuppressionThreshold().ifPresent(faceDetectorGraphOptions::setMinSuppressionThreshold); minFaceSuppressionThreshold().ifPresent(faceDetectorGraphOptions::setMinSuppressionThreshold);
minFacePresenceConfidence() minFaceLandmarksConfidence()
.ifPresent(faceLandmarksDetectorGraphOptions::setMinDetectionConfidence); .ifPresent(faceLandmarksDetectorGraphOptions::setMinDetectionConfidence);
holisticLandmarkerGraphOptions holisticLandmarkerGraphOptions

View File

@ -49,6 +49,5 @@ py_library(
"//mediapipe/calculators/core:flow_limiter_calculator_py_pb2", "//mediapipe/calculators/core:flow_limiter_calculator_py_pb2",
"//mediapipe/framework:calculator_options_py_pb2", "//mediapipe/framework:calculator_options_py_pb2",
"//mediapipe/framework:calculator_py_pb2", "//mediapipe/framework:calculator_py_pb2",
"@com_google_protobuf//:protobuf_python",
], ],
) )

View File

@ -14,8 +14,9 @@
"""MediaPipe Tasks' task info data class.""" """MediaPipe Tasks' task info data class."""
import dataclasses import dataclasses
from typing import Any, List from typing import Any, List
from google.protobuf import any_pb2
from mediapipe.calculators.core import flow_limiter_calculator_pb2 from mediapipe.calculators.core import flow_limiter_calculator_pb2
from mediapipe.framework import calculator_options_pb2 from mediapipe.framework import calculator_options_pb2
from mediapipe.framework import calculator_pb2 from mediapipe.framework import calculator_pb2
@ -79,34 +80,21 @@ class TaskInfo:
raise ValueError( raise ValueError(
'`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.' '`task_options` doesn`t provide `to_pb2()` method to convert itself to be a protobuf object.'
) )
task_subgraph_options = calculator_options_pb2.CalculatorOptions()
task_options_proto = self.task_options.to_pb2() task_options_proto = self.task_options.to_pb2()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
node_config = calculator_pb2.CalculatorGraphConfig.Node( task_options_proto)
if not enable_flow_limiting:
return calculator_pb2.CalculatorGraphConfig(
node=[
calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph, calculator=self.task_graph,
input_stream=self.input_streams, input_stream=self.input_streams,
output_stream=self.output_streams, output_stream=self.output_streams,
) options=task_subgraph_options)
],
if hasattr(task_options_proto, 'ext'):
# Use the extension mechanism for task_subgraph_options (proto2)
task_subgraph_options = calculator_options_pb2.CalculatorOptions()
task_subgraph_options.Extensions[task_options_proto.ext].CopyFrom(
task_options_proto
)
node_config.options.CopyFrom(task_subgraph_options)
else:
# Use the Any type for task_subgraph_options (proto3)
task_subgraph_options = any_pb2.Any()
task_subgraph_options.Pack(self.task_options.to_pb2())
node_config.node_options.append(task_subgraph_options)
if not enable_flow_limiting:
return calculator_pb2.CalculatorGraphConfig(
node=[node_config],
input_stream=self.input_streams, input_stream=self.input_streams,
output_stream=self.output_streams, output_stream=self.output_streams)
)
# When a FlowLimiterCalculator is inserted to lower the overall graph # When a FlowLimiterCalculator is inserted to lower the overall graph
# latency, the task doesn't guarantee that each input must have the # latency, the task doesn't guarantee that each input must have the
# corresponding output. # corresponding output.
@ -132,8 +120,13 @@ class TaskInfo:
], ],
options=flow_limiter_options) options=flow_limiter_options)
config = calculator_pb2.CalculatorGraphConfig( config = calculator_pb2.CalculatorGraphConfig(
node=[node_config, flow_limiter], node=[
input_stream=self.input_streams, calculator_pb2.CalculatorGraphConfig.Node(
calculator=self.task_graph,
input_stream=task_subgraph_inputs,
output_stream=self.output_streams, output_stream=self.output_streams,
) options=task_subgraph_options), flow_limiter
],
input_stream=self.input_streams,
output_stream=self.output_streams)
return config return config

View File

@ -194,27 +194,6 @@ py_test(
], ],
) )
py_test(
name = "holistic_landmarker_test",
srcs = ["holistic_landmarker_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
tags = ["not_run:arm"],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:holistic_landmarker",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"@com_google_protobuf//:protobuf_python",
],
)
py_test( py_test(
name = "face_aligner_test", name = "face_aligner_test",
srcs = ["face_aligner_test.py"], srcs = ["face_aligner_test.py"],

View File

@ -1,544 +0,0 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for holistic landmarker."""
import enum
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import holistic_landmarker
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult
_HolisticResultProto = holistic_result_pb2.HolisticResult
_BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image
_HolisticLandmarker = holistic_landmarker.HolisticLandmarker
_HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'holistic_landmarker.task'
_POSE_IMAGE = 'male_full_height_hands.jpg'
_CAT_IMAGE = 'cat.jpg'
_EXPECTED_HOLISTIC_RESULT = 'male_full_height_hands_result_cpu.pbtxt'
_IMAGE_WIDTH = 638
_IMAGE_HEIGHT = 1000
_LANDMARKS_MARGIN = 0.03
_BLENDSHAPES_MARGIN = 0.13
_VIDEO_LANDMARKS_MARGIN = 0.03
_VIDEO_BLENDSHAPES_MARGIN = 0.31
_LIVE_STREAM_LANDMARKS_MARGIN = 0.03
_LIVE_STREAM_BLENDSHAPES_MARGIN = 0.31
def _get_expected_holistic_landmarker_result(
file_path: str,
) -> HolisticLandmarkerResult:
holistic_result_file_path = test_utils.get_test_data_path(file_path)
with open(holistic_result_file_path, 'rb') as f:
holistic_result_proto = _HolisticResultProto()
# Use this if a .pb file is available.
# holistic_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), holistic_result_proto)
holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2(
holistic_result_proto
)
return holistic_landmarker_result
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class HolisticLandmarkerTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POSE_IMAGE)
)
self.model_path = test_utils.get_test_data_path(
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE
)
def _expect_landmarks_correct(
self, actual_landmarks, expected_landmarks, margin
):
# Expects to have the same number of landmarks detected.
self.assertLen(actual_landmarks, len(expected_landmarks))
for i, elem in enumerate(actual_landmarks):
self.assertAlmostEqual(elem.x, expected_landmarks[i].x, delta=margin)
self.assertAlmostEqual(elem.y, expected_landmarks[i].y, delta=margin)
def _expect_blendshapes_correct(
self, actual_blendshapes, expected_blendshapes, margin
):
# Expects to have the same number of blendshapes.
self.assertLen(actual_blendshapes, len(expected_blendshapes))
for i, elem in enumerate(actual_blendshapes):
self.assertEqual(elem.index, expected_blendshapes[i].index)
self.assertEqual(
elem.category_name, expected_blendshapes[i].category_name
)
self.assertAlmostEqual(
elem.score,
expected_blendshapes[i].score,
delta=margin,
)
def _expect_holistic_landmarker_results_correct(
self,
actual_result: HolisticLandmarkerResult,
expected_result: HolisticLandmarkerResult,
output_segmentation_mask: bool,
landmarks_margin: float,
blendshapes_margin: float,
):
self._expect_landmarks_correct(
actual_result.pose_landmarks,
expected_result.pose_landmarks,
landmarks_margin,
)
self._expect_landmarks_correct(
actual_result.face_landmarks,
expected_result.face_landmarks,
landmarks_margin,
)
self._expect_blendshapes_correct(
actual_result.face_blendshapes,
expected_result.face_blendshapes,
blendshapes_margin,
)
if output_segmentation_mask:
self.assertIsInstance(actual_result.segmentation_mask, _Image)
self.assertEqual(actual_result.segmentation_mask.width, _IMAGE_WIDTH)
self.assertEqual(actual_result.segmentation_mask.height, _IMAGE_HEIGHT)
else:
self.assertIsNone(actual_result.segmentation_mask)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _HolisticLandmarker.create_from_model_path(
self.model_path
) as landmarker:
self.assertIsInstance(landmarker, _HolisticLandmarker)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _HolisticLandmarkerOptions(base_options=base_options)
with _HolisticLandmarker.create_from_options(options) as landmarker:
self.assertIsInstance(landmarker, _HolisticLandmarker)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite'
)
options = _HolisticLandmarkerOptions(base_options=base_options)
_HolisticLandmarker.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _HolisticLandmarkerOptions(base_options=base_options)
landmarker = _HolisticLandmarker.create_from_options(options)
self.assertIsInstance(landmarker, _HolisticLandmarker)
@parameterized.parameters(
(
ModelFileType.FILE_NAME,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
(
ModelFileType.FILE_CONTENT,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
(
ModelFileType.FILE_NAME,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
(
ModelFileType.FILE_CONTENT,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
)
def test_detect(
self,
model_file_type,
model_name,
output_segmentation_mask,
expected_holistic_landmarker_result,
):
# Creates holistic landmarker.
model_path = test_utils.get_test_data_path(model_name)
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _HolisticLandmarkerOptions(
base_options=base_options,
output_face_blendshapes=True
if expected_holistic_landmarker_result.face_blendshapes
else False,
output_segmentation_mask=output_segmentation_mask,
)
landmarker = _HolisticLandmarker.create_from_options(options)
# Performs holistic landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
self._expect_holistic_landmarker_results_correct(
detection_result,
expected_holistic_landmarker_result,
output_segmentation_mask,
_LANDMARKS_MARGIN,
_BLENDSHAPES_MARGIN,
)
# Closes the holistic landmarker explicitly when the holistic landmarker is
# not used in a context.
landmarker.close()
@parameterized.parameters(
(
ModelFileType.FILE_NAME,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
(
ModelFileType.FILE_CONTENT,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
)
def test_detect_in_context(
self,
model_file_type,
model_name,
output_segmentation_mask,
expected_holistic_landmarker_result,
):
# Creates holistic landmarker.
model_path = test_utils.get_test_data_path(model_name)
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _HolisticLandmarkerOptions(
base_options=base_options,
output_face_blendshapes=True
if expected_holistic_landmarker_result.face_blendshapes
else False,
output_segmentation_mask=output_segmentation_mask,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
# Performs holistic landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
self._expect_holistic_landmarker_results_correct(
detection_result,
expected_holistic_landmarker_result,
output_segmentation_mask,
_LANDMARKS_MARGIN,
_BLENDSHAPES_MARGIN,
)
def test_empty_detection_outputs(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path)
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
# Load the cat image.
cat_test_image = _Image.create_from_file(
test_utils.get_test_data_path(_CAT_IMAGE)
)
# Performs holistic landmarks detection on the input.
detection_result = landmarker.detect(cat_test_image)
self.assertEmpty(detection_result.face_landmarks)
self.assertEmpty(detection_result.pose_landmarks)
self.assertEmpty(detection_result.pose_world_landmarks)
self.assertEmpty(detection_result.left_hand_landmarks)
self.assertEmpty(detection_result.left_hand_world_landmarks)
self.assertEmpty(detection_result.right_hand_landmarks)
self.assertEmpty(detection_result.right_hand_world_landmarks)
self.assertIsNone(detection_result.face_blendshapes)
self.assertIsNone(detection_result.segmentation_mask)
def test_missing_result_callback(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
)
with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _HolisticLandmarker.create_from_options(
options
) as unused_landmarker:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock(),
)
with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _HolisticLandmarker.create_from_options(
options
) as unused_landmarker:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
landmarker.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
landmarker.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
landmarker.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(
ValueError, r'not initialized with the live stream mode'
):
landmarker.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
unused_result = landmarker.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters(
(
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
(
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
)
def test_detect_for_video(
self,
model_name,
output_segmentation_mask,
expected_holistic_landmarker_result,
):
# Creates holistic landmarker.
model_path = test_utils.get_test_data_path(model_name)
base_options = _BaseOptions(model_asset_path=model_path)
options = _HolisticLandmarkerOptions(
base_options=base_options,
running_mode=_RUNNING_MODE.VIDEO,
output_face_blendshapes=True
if expected_holistic_landmarker_result.face_blendshapes
else False,
output_segmentation_mask=output_segmentation_mask,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
# Performs holistic landmarks detection on the input.
detection_result = landmarker.detect_for_video(
self.test_image, timestamp
)
# Comparing results.
self._expect_holistic_landmarker_results_correct(
detection_result,
expected_holistic_landmarker_result,
output_segmentation_mask,
_VIDEO_LANDMARKS_MARGIN,
_VIDEO_BLENDSHAPES_MARGIN,
)
def test_calling_detect_in_live_stream_mode(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(
ValueError, r'not initialized with the image mode'
):
landmarker.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(
ValueError, r'not initialized with the video mode'
):
landmarker.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock(),
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
landmarker.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'
):
landmarker.detect_async(self.test_image, 0)
@parameterized.parameters(
(
_POSE_IMAGE,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
False,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
(
_POSE_IMAGE,
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
True,
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
),
)
def test_detect_async_calls(
self,
image_path,
model_name,
output_segmentation_mask,
expected_holistic_landmarker_result,
):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path)
)
observed_timestamp_ms = -1
def check_result(
result: HolisticLandmarkerResult,
output_image: _Image,
timestamp_ms: int,
):
# Comparing results.
self._expect_holistic_landmarker_results_correct(
result,
expected_holistic_landmarker_result,
output_segmentation_mask,
_LIVE_STREAM_LANDMARKS_MARGIN,
_LIVE_STREAM_BLENDSHAPES_MARGIN,
)
self.assertTrue(
np.array_equal(output_image.numpy_view(), test_image.numpy_view())
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
model_path = test_utils.get_test_data_path(model_name)
options = _HolisticLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
output_face_blendshapes=True
if expected_holistic_landmarker_result.face_blendshapes
else False,
output_segmentation_mask=output_segmentation_mask,
result_callback=check_result,
)
with _HolisticLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
landmarker.detect_async(test_image, timestamp)
if __name__ == '__main__':
absltest.main()

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder: load py_library
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
@ -242,30 +243,6 @@ py_library(
], ],
) )
py_library(
name = "holistic_landmarker",
srcs = [
"holistic_landmarker.py",
],
deps = [
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_result_py_pb2",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_library( py_library(
name = "face_stylizer", name = "face_stylizer",
srcs = [ srcs = [

View File

@ -1,576 +0,0 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe holistic landmarker task."""
import dataclasses
from typing import Callable, List, Mapping, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_landmarker_graph_options_pb2
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_HolisticResultProto = holistic_result_pb2.HolisticResult
_HolisticLandmarkerGraphOptionsProto = (
holistic_landmarker_graph_options_pb2.HolisticLandmarkerGraphOptions
)
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_POSE_LANDMARKS_STREAM_NAME = 'pose_landmarks'
_POSE_LANDMARKS_TAG_NAME = 'POSE_LANDMARKS'
_POSE_WORLD_LANDMARKS_STREAM_NAME = 'pose_world_landmarks'
_POSE_WORLD_LANDMARKS_TAG = 'POSE_WORLD_LANDMARKS'
_POSE_SEGMENTATION_MASK_STREAM_NAME = 'pose_segmentation_mask'
_POSE_SEGMENTATION_MASK_TAG = 'POSE_SEGMENTATION_MASK'
_FACE_LANDMARKS_STREAM_NAME = 'face_landmarks'
_FACE_LANDMARKS_TAG = 'FACE_LANDMARKS'
_FACE_BLENDSHAPES_STREAM_NAME = 'extra_blendshapes'
_FACE_BLENDSHAPES_TAG = 'FACE_BLENDSHAPES'
_LEFT_HAND_LANDMARKS_STREAM_NAME = 'left_hand_landmarks'
_LEFT_HAND_LANDMARKS_TAG = 'LEFT_HAND_LANDMARKS'
_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME = 'left_hand_world_landmarks'
_LEFT_HAND_WORLD_LANDMARKS_TAG = 'LEFT_HAND_WORLD_LANDMARKS'
_RIGHT_HAND_LANDMARKS_STREAM_NAME = 'right_hand_landmarks'
_RIGHT_HAND_LANDMARKS_TAG = 'RIGHT_HAND_LANDMARKS'
_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME = 'right_hand_world_landmarks'
_RIGHT_HAND_WORLD_LANDMARKS_TAG = 'RIGHT_HAND_WORLD_LANDMARKS'
_TASK_GRAPH_NAME = (
'mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph'
)
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class HolisticLandmarkerResult:
"""The holistic landmarks result from HolisticLandmarker, where each vector element represents a single holistic detected in the image.
Attributes:
face_landmarks: Detected face landmarks in normalized image coordinates.
pose_landmarks: Detected pose landmarks in normalized image coordinates.
pose_world_landmarks: Detected pose world landmarks in image coordinates.
left_hand_landmarks: Detected left hand landmarks in normalized image
coordinates.
left_hand_world_landmarks: Detected left hand landmarks in image
coordinates.
right_hand_landmarks: Detected right hand landmarks in normalized image
coordinates.
right_hand_world_landmarks: Detected right hand landmarks in image
coordinates.
face_blendshapes: Optional face blendshapes.
segmentation_mask: Optional segmentation mask for pose.
"""
face_landmarks: List[landmark_module.NormalizedLandmark]
pose_landmarks: List[landmark_module.NormalizedLandmark]
pose_world_landmarks: List[landmark_module.Landmark]
left_hand_landmarks: List[landmark_module.NormalizedLandmark]
left_hand_world_landmarks: List[landmark_module.Landmark]
right_hand_landmarks: List[landmark_module.NormalizedLandmark]
right_hand_world_landmarks: List[landmark_module.Landmark]
face_blendshapes: Optional[List[category_module.Category]] = None
segmentation_mask: Optional[image_module.Image] = None
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _HolisticResultProto
) -> 'HolisticLandmarkerResult':
"""Creates a `HolisticLandmarkerResult` object from the given protobuf object."""
face_blendshapes = None
if hasattr(pb2_obj, 'face_blendshapes'):
face_blendshapes = [
category_module.Category(
score=classification.score,
index=classification.index,
category_name=classification.label,
display_name=classification.display_name,
)
for classification in pb2_obj.face_blendshapes.classification
]
return HolisticLandmarkerResult(
face_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.face_landmarks.landmark
],
pose_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.pose_landmarks.landmark
],
pose_world_landmarks=[
landmark_module.Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.pose_world_landmarks.landmark
],
left_hand_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.left_hand_landmarks.landmark
],
left_hand_world_landmarks=[],
right_hand_landmarks=[
landmark_module.NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.right_hand_landmarks.landmark
],
right_hand_world_landmarks=[],
face_blendshapes=face_blendshapes,
segmentation_mask=None,
)
def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]
) -> HolisticLandmarkerResult:
"""Constructs a `HolisticLandmarksDetectionResult` from output packets."""
holistic_landmarker_result = HolisticLandmarkerResult(
[], [], [], [], [], [], []
)
face_landmarks_proto_list = packet_getter.get_proto(
output_packets[_FACE_LANDMARKS_STREAM_NAME]
)
pose_landmarks_proto_list = packet_getter.get_proto(
output_packets[_POSE_LANDMARKS_STREAM_NAME]
)
pose_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_POSE_WORLD_LANDMARKS_STREAM_NAME]
)
left_hand_landmarks_proto_list = packet_getter.get_proto(
output_packets[_LEFT_HAND_LANDMARKS_STREAM_NAME]
)
left_hand_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME]
)
right_hand_landmarks_proto_list = packet_getter.get_proto(
output_packets[_RIGHT_HAND_LANDMARKS_STREAM_NAME]
)
right_hand_world_landmarks_proto_list = packet_getter.get_proto(
output_packets[_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME]
)
face_landmarks = landmark_pb2.NormalizedLandmarkList()
face_landmarks.MergeFrom(face_landmarks_proto_list)
for face_landmark in face_landmarks.landmark:
holistic_landmarker_result.face_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(face_landmark)
)
pose_landmarks = landmark_pb2.NormalizedLandmarkList()
pose_landmarks.MergeFrom(pose_landmarks_proto_list)
for pose_landmark in pose_landmarks.landmark:
holistic_landmarker_result.pose_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(pose_landmark)
)
pose_world_landmarks = landmark_pb2.LandmarkList()
pose_world_landmarks.MergeFrom(pose_world_landmarks_proto_list)
for pose_world_landmark in pose_world_landmarks.landmark:
holistic_landmarker_result.pose_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(pose_world_landmark)
)
left_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
left_hand_landmarks.MergeFrom(left_hand_landmarks_proto_list)
for hand_landmark in left_hand_landmarks.landmark:
holistic_landmarker_result.left_hand_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
left_hand_world_landmarks = landmark_pb2.LandmarkList()
left_hand_world_landmarks.MergeFrom(left_hand_world_landmarks_proto_list)
for left_hand_world_landmark in left_hand_world_landmarks.landmark:
holistic_landmarker_result.left_hand_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(left_hand_world_landmark)
)
right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
right_hand_landmarks.MergeFrom(right_hand_landmarks_proto_list)
for hand_landmark in right_hand_landmarks.landmark:
holistic_landmarker_result.right_hand_landmarks.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
)
right_hand_world_landmarks = landmark_pb2.LandmarkList()
right_hand_world_landmarks.MergeFrom(right_hand_world_landmarks_proto_list)
for right_hand_world_landmark in right_hand_world_landmarks.landmark:
holistic_landmarker_result.right_hand_world_landmarks.append(
landmark_module.Landmark.create_from_pb2(right_hand_world_landmark)
)
if _FACE_BLENDSHAPES_STREAM_NAME in output_packets:
face_blendshapes_proto_list = packet_getter.get_proto(
output_packets[_FACE_BLENDSHAPES_STREAM_NAME]
)
face_blendshapes_classifications = classification_pb2.ClassificationList()
face_blendshapes_classifications.MergeFrom(face_blendshapes_proto_list)
holistic_landmarker_result.face_blendshapes = []
for face_blendshapes in face_blendshapes_classifications.classification:
holistic_landmarker_result.face_blendshapes.append(
category_module.Category(
index=face_blendshapes.index,
score=face_blendshapes.score,
display_name=face_blendshapes.display_name,
category_name=face_blendshapes.label,
)
)
if _POSE_SEGMENTATION_MASK_STREAM_NAME in output_packets:
holistic_landmarker_result.segmentation_mask = packet_getter.get_image(
output_packets[_POSE_SEGMENTATION_MASK_STREAM_NAME]
)
return holistic_landmarker_result
@dataclasses.dataclass
class HolisticLandmarkerOptions:
"""Options for the holistic landmarker task.
Attributes:
base_options: Base options for the holistic landmarker task.
running_mode: The running mode of the task. Default to the image mode.
HolisticLandmarker has three running modes: 1) The image mode for
detecting holistic landmarks on single image inputs. 2) The video mode for
detecting holistic landmarks on the decoded frames of a video. 3) The live
stream mode for detecting holistic landmarks on the live stream of input
data, such as from camera. In this mode, the "result_callback" below must
be specified to receive the detection results asynchronously.
min_face_detection_confidence: The minimum confidence score for the face
detection to be considered successful.
min_face_suppression_threshold: The minimum non-maximum-suppression
threshold for face detection to be considered overlapped.
min_face_landmarks_confidence: The minimum confidence score for the face
landmark detection to be considered successful.
min_pose_detection_confidence: The minimum confidence score for the pose
detection to be considered successful.
min_pose_suppression_threshold: The minimum non-maximum-suppression
threshold for pose detection to be considered overlapped.
min_pose_landmarks_confidence: The minimum confidence score for the pose
landmark detection to be considered successful.
min_hand_landmarks_confidence: The minimum confidence score for the hand
landmark detection to be considered successful.
output_face_blendshapes: Whether HolisticLandmarker outputs face blendshapes
classification. Face blendshapes are used for rendering the 3D face model.
output_segmentation_mask: whether to output segmentation masks.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
min_face_detection_confidence: float = 0.5
min_face_suppression_threshold: float = 0.5
min_face_landmarks_confidence: float = 0.5
min_pose_detection_confidence: float = 0.5
min_pose_suppression_threshold: float = 0.5
min_pose_landmarks_confidence: float = 0.5
min_hand_landmarks_confidence: float = 0.5
output_face_blendshapes: bool = False
output_segmentation_mask: bool = False
result_callback: Optional[
Callable[[HolisticLandmarkerResult, image_module.Image, int], None]
] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _HolisticLandmarkerGraphOptionsProto:
"""Generates an HolisticLandmarkerGraphOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True
)
# Initialize the holistic landmarker options from base options.
holistic_landmarker_options_proto = _HolisticLandmarkerGraphOptionsProto(
base_options=base_options_proto
)
# Configure face detector and face landmarks detector options.
holistic_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = (
self.min_face_detection_confidence
)
holistic_landmarker_options_proto.face_detector_graph_options.min_suppression_threshold = (
self.min_face_suppression_threshold
)
holistic_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = (
self.min_face_landmarks_confidence
)
# Configure pose detector and pose landmarks detector options.
holistic_landmarker_options_proto.pose_detector_graph_options.min_detection_confidence = (
self.min_pose_detection_confidence
)
holistic_landmarker_options_proto.pose_detector_graph_options.min_suppression_threshold = (
self.min_pose_suppression_threshold
)
holistic_landmarker_options_proto.pose_landmarks_detector_graph_options.min_detection_confidence = (
self.min_pose_landmarks_confidence
)
# Configure hand landmarks detector options.
holistic_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
self.min_hand_landmarks_confidence
)
return holistic_landmarker_options_proto
class HolisticLandmarker(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs holistic landmarks detection on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'HolisticLandmarker':
"""Creates an `HolisticLandmarker` object from a TensorFlow Lite model and the default `HolisticLandmarkerOptions`.
Note that the created `HolisticLandmarker` instance is in image mode, for
detecting holistic landmarks on single image inputs.
Args:
model_path: Path to the model.
Returns:
`HolisticLandmarker` object that's created from the model file and the
default `HolisticLandmarkerOptions`.
Raises:
ValueError: If failed to create `HolisticLandmarker` object from the
provided file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = HolisticLandmarkerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE
)
return cls.create_from_options(options)
@classmethod
def create_from_options(
cls, options: HolisticLandmarkerOptions
) -> 'HolisticLandmarker':
"""Creates the `HolisticLandmarker` object from holistic landmarker options.
Args:
options: Options for the holistic landmarker task.
Returns:
`HolisticLandmarker` object that's created from `options`.
Raises:
ValueError: If failed to create `HolisticLandmarker` object from
`HolisticLandmarkerOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_FACE_LANDMARKS_STREAM_NAME]
options.result_callback(
HolisticLandmarkerResult([], [], [], [], [], [], []),
image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
return
holistic_landmarks_detection_result = _build_landmarker_result(
output_packets
)
timestamp = output_packets[_FACE_LANDMARKS_STREAM_NAME].timestamp
options.result_callback(
holistic_landmarks_detection_result,
image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
)
output_streams = [
':'.join([_FACE_LANDMARKS_TAG, _FACE_LANDMARKS_STREAM_NAME]),
':'.join([_POSE_LANDMARKS_TAG_NAME, _POSE_LANDMARKS_STREAM_NAME]),
':'.join(
[_POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME]
),
':'.join([_LEFT_HAND_LANDMARKS_TAG, _LEFT_HAND_LANDMARKS_STREAM_NAME]),
':'.join([
_LEFT_HAND_WORLD_LANDMARKS_TAG,
_LEFT_HAND_WORLD_LANDMARKS_STREAM_NAME,
]),
':'.join(
[_RIGHT_HAND_LANDMARKS_TAG, _RIGHT_HAND_LANDMARKS_STREAM_NAME]
),
':'.join([
_RIGHT_HAND_WORLD_LANDMARKS_TAG,
_RIGHT_HAND_WORLD_LANDMARKS_STREAM_NAME,
]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_segmentation_mask:
output_streams.append(
':'.join(
[_POSE_SEGMENTATION_MASK_TAG, _POSE_SEGMENTATION_MASK_STREAM_NAME]
)
)
if options.output_face_blendshapes:
output_streams.append(
':'.join([_FACE_BLENDSHAPES_TAG, _FACE_BLENDSHAPES_STREAM_NAME])
)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
],
output_streams=output_streams,
task_options=options,
)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode
== _RunningMode.LIVE_STREAM
),
options.running_mode,
packets_callback if options.result_callback else None,
)
def detect(
self,
image: image_module.Image,
) -> HolisticLandmarkerResult:
"""Performs holistic landmarks detection on the given image.
Only use this method when the HolisticLandmarker is created with the image
running mode.
The image can be of any size with format RGB or RGBA.
Args:
image: MediaPipe Image.
Returns:
The holistic landmarks detection results.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If holistic landmarker detection failed to run.
"""
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
})
if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty():
return HolisticLandmarkerResult([], [], [], [], [], [], [])
return _build_landmarker_result(output_packets)
def detect_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
) -> HolisticLandmarkerResult:
"""Performs holistic landmarks detection on the provided video frame.
Only use this method when the HolisticLandmarker is created with the video
running mode.
Only use this method when the HolisticLandmarker is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
Returns:
The holistic landmarks detection results.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If holistic landmarker detection failed to run.
"""
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
})
if output_packets[_FACE_LANDMARKS_STREAM_NAME].is_empty():
return HolisticLandmarkerResult([], [], [], [], [], [], [])
return _build_landmarker_result(output_packets)
def detect_async(
self,
image: image_module.Image,
timestamp_ms: int,
) -> None:
"""Sends live image data to perform holistic landmarks detection.
The results will be available via the "result_callback" provided in the
HolisticLandmarkerOptions. Only use this method when the HolisticLandmarker
is
created with the live stream running mode.
Only use this method when the HolisticLandmarker is created with the live
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input image is accepted. The results will be available via the
`result_callback` provided in the `HolisticLandmarkerOptions`. The
`detect_async` method is designed to process live stream data such as
camera input. To lower the overall latency, holistic landmarker may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` provides:
- The holistic landmarks detection results.
- The input image that the holistic landmarker runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
Raises:
ValueError: If the current input timestamp is smaller than what the
holistic landmarker has already processed.
"""
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND
),
})

View File

@ -48,7 +48,6 @@ mediapipe_files(srcs = [
"face_landmark.tflite", "face_landmark.tflite",
"face_landmarker.task", "face_landmarker.task",
"face_landmarker_v2.task", "face_landmarker_v2.task",
"face_landmarker_v2_with_blendshapes.task",
"face_stylizer_color_ink.task", "face_stylizer_color_ink.task",
"fist.jpg", "fist.jpg",
"fist.png", "fist.png",
@ -58,11 +57,9 @@ mediapipe_files(srcs = [
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"hand_landmarker.task", "hand_landmarker.task",
"handrecrop_2020_07_21_v0.f16.tflite", "handrecrop_2020_07_21_v0.f16.tflite",
"holistic_landmarker.task",
"left_hands.jpg", "left_hands.jpg",
"left_hands_rotated.jpg", "left_hands_rotated.jpg",
"leopard_bg_removal_result_512x512.png", "leopard_bg_removal_result_512x512.png",
"male_full_height_hands.jpg",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_metadata_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite",
@ -144,7 +141,6 @@ filegroup(
"left_hands.jpg", "left_hands.jpg",
"left_hands_rotated.jpg", "left_hands_rotated.jpg",
"leopard_bg_removal_result_512x512.png", "leopard_bg_removal_result_512x512.png",
"male_full_height_hands.jpg",
"mozart_square.jpg", "mozart_square.jpg",
"multi_objects.jpg", "multi_objects.jpg",
"multi_objects_rotated.jpg", "multi_objects_rotated.jpg",
@ -189,7 +185,6 @@ filegroup(
"face_detection_short_range.tflite", "face_detection_short_range.tflite",
"face_landmarker.task", "face_landmarker.task",
"face_landmarker_v2.task", "face_landmarker_v2.task",
"face_landmarker_v2_with_blendshapes.task",
"face_stylizer_color_ink.task", "face_stylizer_color_ink.task",
"gesture_recognizer.task", "gesture_recognizer.task",
"hair_segmentation.tflite", "hair_segmentation.tflite",
@ -197,7 +192,6 @@ filegroup(
"hand_landmark_lite.tflite", "hand_landmark_lite.tflite",
"hand_landmarker.task", "hand_landmarker.task",
"handrecrop_2020_07_21_v0.f16.tflite", "handrecrop_2020_07_21_v0.f16.tflite",
"holistic_landmarker.task",
"mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v1_0.25_224_1_metadata_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite",

View File

@ -71,15 +71,6 @@ export interface MediapipeTasksFake {
/** An map of field paths to values */ /** An map of field paths to values */
export type FieldPathToValue = [string[] | string, unknown]; export type FieldPathToValue = [string[] | string, unknown];
type JsonObject = Record<string, unknown>;
/**
* The function to convert a binary proto to a JsonObject.
* For example, the deserializer of HolisticLandmarkerOptions's binary proto is
* HolisticLandmarkerOptions.deserializeBinary(binaryProto).toObject().
*/
export type Deserializer = (binaryProto: Uint8Array) => JsonObject;
/** /**
* Verifies that the graph has been initialized and that it contains the * Verifies that the graph has been initialized and that it contains the
* provided options. * provided options.
@ -88,7 +79,6 @@ export function verifyGraph(
tasksFake: MediapipeTasksFake, tasksFake: MediapipeTasksFake,
expectedCalculatorOptions?: FieldPathToValue, expectedCalculatorOptions?: FieldPathToValue,
expectedBaseOptions?: FieldPathToValue, expectedBaseOptions?: FieldPathToValue,
deserializer?: Deserializer,
): void { ): void {
expect(tasksFake.graph).toBeDefined(); expect(tasksFake.graph).toBeDefined();
// Our graphs should have at least one node in them for processing, and // Our graphs should have at least one node in them for processing, and
@ -99,31 +89,22 @@ export function verifyGraph(
expect(node).toEqual( expect(node).toEqual(
jasmine.objectContaining({calculator: tasksFake.calculatorName})); jasmine.objectContaining({calculator: tasksFake.calculatorName}));
let proto;
if (deserializer) {
const binaryProto =
tasksFake.graph!.getNodeList()[0].getNodeOptionsList()[0].getValue() as
Uint8Array;
proto = deserializer(binaryProto);
} else {
proto = (node.options as {ext: unknown}).ext;
}
if (expectedBaseOptions) { if (expectedBaseOptions) {
const [fieldPath, value] = expectedBaseOptions; const [fieldPath, value] = expectedBaseOptions;
let baseOptions = (proto as {baseOptions: unknown}).baseOptions; let proto = (node.options as {ext: {baseOptions: unknown}}).ext.baseOptions;
for (const fieldName of ( for (const fieldName of (
Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
baseOptions = ((baseOptions ?? {}) as JsonObject)[fieldName]; proto = ((proto ?? {}) as Record<string, unknown>)[fieldName];
} }
expect(baseOptions).toEqual(value); expect(proto).toEqual(value);
} }
if (expectedCalculatorOptions) { if (expectedCalculatorOptions) {
const [fieldPath, value] = expectedCalculatorOptions; const [fieldPath, value] = expectedCalculatorOptions;
let proto = (node.options as {ext: unknown}).ext;
for (const fieldName of ( for (const fieldName of (
Array.isArray(fieldPath) ? fieldPath : [fieldPath])) { Array.isArray(fieldPath) ? fieldPath : [fieldPath])) {
proto = ((proto ?? {}) as JsonObject)[fieldName]; proto = ((proto ?? {}) as Record<string, unknown>)[fieldName];
} }
expect(proto).toEqual(value); expect(proto).toEqual(value);
} }

View File

@ -27,7 +27,6 @@ VISION_LIBS = [
"//mediapipe/tasks/web/vision/face_stylizer", "//mediapipe/tasks/web/vision/face_stylizer",
"//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/gesture_recognizer",
"//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/hand_landmarker",
"//mediapipe/tasks/web/vision/holistic_landmarker",
"//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/image_classifier",
"//mediapipe/tasks/web/vision/image_embedder", "//mediapipe/tasks/web/vision/image_embedder",
"//mediapipe/tasks/web/vision/image_segmenter", "//mediapipe/tasks/web/vision/image_segmenter",

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {assertExists, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner';
/** /**
@ -159,13 +159,13 @@ export class CategoryMaskShaderContext extends MPImageShaderContext {
protected override setupShaders(): void { protected override setupShaders(): void {
super.setupShaders(); super.setupShaders();
const gl = this.gl!; const gl = this.gl!;
this.backgroundTextureUniform = assertExists( this.backgroundTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'backgroundTexture'), gl.getUniformLocation(this.program!, 'backgroundTexture'),
'Uniform location'); 'Uniform location');
this.colorMappingTextureUniform = assertExists( this.colorMappingTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'colorMappingTexture'), gl.getUniformLocation(this.program!, 'colorMappingTexture'),
'Uniform location'); 'Uniform location');
this.maskTextureUniform = assertExists( this.maskTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'maskTexture'), gl.getUniformLocation(this.program!, 'maskTexture'),
'Uniform location'); 'Uniform location');
} }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {assertExists, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner';
/** /**
@ -60,13 +60,13 @@ export class ConfidenceMaskShaderContext extends MPImageShaderContext {
protected override setupShaders(): void { protected override setupShaders(): void {
super.setupShaders(); super.setupShaders();
const gl = this.gl!; const gl = this.gl!;
this.defaultTextureUniform = assertExists( this.defaultTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'defaultTexture'), gl.getUniformLocation(this.program!, 'defaultTexture'),
'Uniform location'); 'Uniform location');
this.overlayTextureUniform = assertExists( this.overlayTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'overlayTexture'), gl.getUniformLocation(this.program!, 'overlayTexture'),
'Uniform location'); 'Uniform location');
this.maskTextureUniform = assertExists( this.maskTextureUniform = assertNotNull(
gl.getUniformLocation(this.program!, 'maskTexture'), gl.getUniformLocation(this.program!, 'maskTexture'),
'Uniform location'); 'Uniform location');
} }

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {assertExists, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
/** Number of instances a user can keep alive before we raise a warning. */ /** Number of instances a user can keep alive before we raise a warning. */
const INSTANCE_COUNT_WARNING_THRESHOLD = 250; const INSTANCE_COUNT_WARNING_THRESHOLD = 250;
@ -249,8 +249,8 @@ export class MPImage {
'is passed when iniitializing the image.'); 'is passed when iniitializing the image.');
} }
if (!this.gl) { if (!this.gl) {
this.gl = assertExists( this.gl = assertNotNull(
this.canvas.getContext('webgl2') as WebGL2RenderingContext, this.canvas.getContext('webgl2'),
'You cannot use a canvas that is already bound to a different ' + 'You cannot use a canvas that is already bound to a different ' +
'type of rendering context.'); 'type of rendering context.');
} }

View File

@ -32,9 +32,9 @@ const FRAGMENT_SHADER = `
} }
`; `;
/** Helper to assert that `value` is not null or undefined. */ /** Helper to assert that `value` is not null. */
export function assertExists<T>(value: T, msg: string): NonNullable<T> { export function assertNotNull<T>(value: T|null, msg: string): T {
if (!value) { if (value === null) {
throw new Error(`Unable to obtain required WebGL resource: ${msg}`); throw new Error(`Unable to obtain required WebGL resource: ${msg}`);
} }
return value; return value;
@ -105,7 +105,7 @@ export class MPImageShaderContext {
private compileShader(source: string, type: number): WebGLShader { private compileShader(source: string, type: number): WebGLShader {
const gl = this.gl!; const gl = this.gl!;
const shader = const shader =
assertExists(gl.createShader(type), 'Failed to create WebGL shader'); assertNotNull(gl.createShader(type), 'Failed to create WebGL shader');
gl.shaderSource(shader, source); gl.shaderSource(shader, source);
gl.compileShader(shader); gl.compileShader(shader);
if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) { if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
@ -119,7 +119,7 @@ export class MPImageShaderContext {
protected setupShaders(): void { protected setupShaders(): void {
const gl = this.gl!; const gl = this.gl!;
this.program = this.program =
assertExists(gl.createProgram()!, 'Failed to create WebGL program'); assertNotNull(gl.createProgram()!, 'Failed to create WebGL program');
this.vertexShader = this.vertexShader =
this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER); this.compileShader(this.getVertexShader(), gl.VERTEX_SHADER);
@ -144,11 +144,11 @@ export class MPImageShaderContext {
private createBuffers(flipVertically: boolean): MPImageShaderBuffers { private createBuffers(flipVertically: boolean): MPImageShaderBuffers {
const gl = this.gl!; const gl = this.gl!;
const vertexArrayObject = const vertexArrayObject =
assertExists(gl.createVertexArray(), 'Failed to create vertex array'); assertNotNull(gl.createVertexArray(), 'Failed to create vertex array');
gl.bindVertexArray(vertexArrayObject); gl.bindVertexArray(vertexArrayObject);
const vertexBuffer = const vertexBuffer =
assertExists(gl.createBuffer(), 'Failed to create buffer'); assertNotNull(gl.createBuffer(), 'Failed to create buffer');
gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
gl.enableVertexAttribArray(this.aVertex!); gl.enableVertexAttribArray(this.aVertex!);
gl.vertexAttribPointer(this.aVertex!, 2, gl.FLOAT, false, 0, 0); gl.vertexAttribPointer(this.aVertex!, 2, gl.FLOAT, false, 0, 0);
@ -157,7 +157,7 @@ export class MPImageShaderContext {
gl.STATIC_DRAW); gl.STATIC_DRAW);
const textureBuffer = const textureBuffer =
assertExists(gl.createBuffer(), 'Failed to create buffer'); assertNotNull(gl.createBuffer(), 'Failed to create buffer');
gl.bindBuffer(gl.ARRAY_BUFFER, textureBuffer); gl.bindBuffer(gl.ARRAY_BUFFER, textureBuffer);
gl.enableVertexAttribArray(this.aTex!); gl.enableVertexAttribArray(this.aTex!);
gl.vertexAttribPointer(this.aTex!, 2, gl.FLOAT, false, 0, 0); gl.vertexAttribPointer(this.aTex!, 2, gl.FLOAT, false, 0, 0);
@ -232,7 +232,7 @@ export class MPImageShaderContext {
WebGLTexture { WebGLTexture {
this.maybeInitGL(gl); this.maybeInitGL(gl);
const texture = const texture =
assertExists(gl.createTexture(), 'Failed to create texture'); assertNotNull(gl.createTexture(), 'Failed to create texture');
gl.bindTexture(gl.TEXTURE_2D, texture); gl.bindTexture(gl.TEXTURE_2D, texture);
gl.texParameteri( gl.texParameteri(
gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE); gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, wrapping ?? gl.CLAMP_TO_EDGE);
@ -252,7 +252,7 @@ export class MPImageShaderContext {
this.maybeInitGL(gl); this.maybeInitGL(gl);
if (!this.framebuffer) { if (!this.framebuffer) {
this.framebuffer = this.framebuffer =
assertExists(gl.createFramebuffer(), 'Failed to create framebuffe.'); assertNotNull(gl.createFramebuffer(), 'Failed to create framebuffe.');
} }
gl.bindFramebuffer(gl.FRAMEBUFFER, this.framebuffer); gl.bindFramebuffer(gl.FRAMEBUFFER, this.framebuffer);
gl.framebufferTexture2D( gl.framebufferTexture2D(

View File

@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {assertExists, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {assertNotNull, MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {isIOS} from '../../../../web/graph_runner/platform_utils'; import {isIOS} from '../../../../web/graph_runner/platform_utils';
/** Number of instances a user can keep alive before we raise a warning. */ /** Number of instances a user can keep alive before we raise a warning. */
@ -270,8 +270,8 @@ export class MPMask {
'is passed when initializing the image.'); 'is passed when initializing the image.');
} }
if (!this.gl) { if (!this.gl) {
this.gl = assertExists( this.gl = assertNotNull(
this.canvas.getContext('webgl2') as WebGL2RenderingContext, this.canvas.getContext('webgl2'),
'You cannot use a canvas that is already bound to a different ' + 'You cannot use a canvas that is already bound to a different ' +
'type of rendering context.'); 'type of rendering context.');
} }

View File

@ -21,7 +21,7 @@ import {MPImage} from '../../../../tasks/web/vision/core/image';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context'; import {MPImageShaderContext} from '../../../../tasks/web/vision/core/image_shader_context';
import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {getImageSourceSize, GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner'; import {GraphRunner, ImageSource, WasmMediaPipeConstructor} from '../../../../web/graph_runner/graph_runner';
import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {isWebKit} from '../../../../web/graph_runner/platform_utils'; import {isWebKit} from '../../../../web/graph_runner/platform_utils';
import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service'; import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service';
@ -134,6 +134,22 @@ export abstract class VisionTaskRunner extends TaskRunner {
this.process(imageFrame, imageProcessingOptions, timestamp); this.process(imageFrame, imageProcessingOptions, timestamp);
} }
private getImageSourceSize(imageSource: ImageSource): [number, number] {
if ((imageSource as HTMLVideoElement).videoWidth !== undefined) {
return [
(imageSource as HTMLVideoElement).videoWidth,
(imageSource as HTMLVideoElement).videoHeight
];
} else if ((imageSource as HTMLImageElement).naturalWidth !== undefined) {
return [
(imageSource as HTMLImageElement).naturalWidth,
(imageSource as HTMLImageElement).naturalHeight
];
} else {
return [imageSource.width, imageSource.height];
}
}
private convertToNormalizedRect( private convertToNormalizedRect(
imageSource: ImageSource, imageSource: ImageSource,
imageProcessingOptions?: ImageProcessingOptions): NormalizedRect { imageProcessingOptions?: ImageProcessingOptions): NormalizedRect {
@ -183,7 +199,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
// uses this for cropping, // uses this for cropping,
// - then finally rotates this back. // - then finally rotates this back.
if (imageProcessingOptions?.rotationDegrees % 180 !== 0) { if (imageProcessingOptions?.rotationDegrees % 180 !== 0) {
const [imageWidth, imageHeight] = getImageSourceSize(imageSource); const [imageWidth, imageHeight] = this.getImageSourceSize(imageSource);
// tslint:disable:no-unnecessary-type-assertion // tslint:disable:no-unnecessary-type-assertion
const width = normalizedRect.getHeight()! * imageHeight / imageWidth; const width = normalizedRect.getHeight()! * imageHeight / imageWidth;
const height = normalizedRect.getWidth()! * imageWidth / imageHeight; const height = normalizedRect.getWidth()! * imageWidth / imageHeight;

View File

@ -1,94 +0,0 @@
# This contains the MediaPipe Hand Landmarker Task.
#
# This task takes video frames and outputs synchronized frames along with
# the detection results for one or more holistic categories, using Hand Landmarker.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_ts_library(
name = "holistic_landmarker",
srcs = ["holistic_landmarker.ts"],
visibility = ["//visibility:public"],
deps = [
":holistic_landmarker_connections",
":holistic_landmarker_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_roi_refinement_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/components/processors:classifier_result",
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)
mediapipe_ts_declaration(
name = "holistic_landmarker_types",
srcs = [
"holistic_landmarker_options.d.ts",
"holistic_landmarker_result.d.ts",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/containers:landmark",
"//mediapipe/tasks/web/components/containers:matrix",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "holistic_landmarker_test_lib",
testonly = True,
srcs = [
"holistic_landmarker_test.ts",
],
deps = [
":holistic_landmarker",
":holistic_landmarker_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/vision/holistic_landmarker/proto:holistic_landmarker_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/processors:landmark_result",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
],
)
jasmine_node_test(
name = "holistic_landmarker_test",
tags = ["nomsan"],
deps = [":holistic_landmarker_test_lib"],
)
mediapipe_ts_library(
name = "holistic_landmarker_connections",
deps = [
"//mediapipe/tasks/web/vision/face_landmarker:face_landmarks_connections",
"//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections",
"//mediapipe/tasks/web/vision/pose_landmarker:pose_landmarks_connections",
],
)

View File

@ -1,732 +0,0 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Any} from 'google-protobuf/google/protobuf/any_pb';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {ClassificationList as ClassificationListProto} from '../../../../framework/formats/classification_pb';
import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {FaceDetectorGraphOptions} from '../../../../tasks/cc/vision/face_detector/proto/face_detector_graph_options_pb';
import {FaceLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options_pb';
import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
import {HandRoiRefinementGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_roi_refinement_graph_options_pb';
import {HolisticLandmarkerGraphOptions} from '../../../../tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options_pb';
import {PoseDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_detector/proto/pose_detector_graph_options_pb';
import {PoseLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options_pb';
import {Classifications} from '../../../../tasks/web/components/containers/classification_result';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {convertFromClassifications} from '../../../../tasks/web/components/processors/classifier_result';
import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {FACE_LANDMARKS_CONTOURS, FACE_LANDMARKS_FACE_OVAL, FACE_LANDMARKS_LEFT_EYE, FACE_LANDMARKS_LEFT_EYEBROW, FACE_LANDMARKS_LEFT_IRIS, FACE_LANDMARKS_LIPS, FACE_LANDMARKS_RIGHT_EYE, FACE_LANDMARKS_RIGHT_EYEBROW, FACE_LANDMARKS_RIGHT_IRIS, FACE_LANDMARKS_TESSELATION} from '../../../../tasks/web/vision/face_landmarker/face_landmarks_connections';
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
import {POSE_CONNECTIONS} from '../../../../tasks/web/vision/pose_landmarker/pose_landmarks_connections';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
import {HolisticLandmarkerOptions} from './holistic_landmarker_options';
import {HolisticLandmarkerResult} from './holistic_landmarker_result';
export * from './holistic_landmarker_options';
export * from './holistic_landmarker_result';
export {ImageSource};
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
const IMAGE_STREAM = 'input_frames_image';
const POSE_LANDMARKS_STREAM = 'pose_landmarks';
const POSE_WORLD_LANDMARKS_STREAM = 'pose_world_landmarks';
const POSE_SEGMENTATION_MASK_STREAM = 'pose_segmentation_mask';
const FACE_LANDMARKS_STREAM = 'face_landmarks';
const FACE_BLENDSHAPES_STREAM = 'extra_blendshapes';
const LEFT_HAND_LANDMARKS_STREAM = 'left_hand_landmarks';
const LEFT_HAND_WORLD_LANDMARKS_STREAM = 'left_hand_world_landmarks';
const RIGHT_HAND_LANDMARKS_STREAM = 'right_hand_landmarks';
const RIGHT_HAND_WORLD_LANDMARKS_STREAM = 'right_hand_world_landmarks';
const HOLISTIC_LANDMARKER_GRAPH =
'mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph';
const DEFAULT_SUPRESSION_THRESHOLD = 0.3;
const DEFAULT_SCORE_THRESHOLD = 0.5;
/**
* A callback that receives the result from the holistic landmarker detection.
* The returned result are only valid for the duration of the callback. If
* asynchronous processing is needed, the masks need to be copied before the
* callback returns.
*/
export type HolisticLandmarkerCallback = (result: HolisticLandmarkerResult) =>
void;
/** Performs holistic landmarks detection on images. */
export class HolisticLandmarker extends VisionTaskRunner {
private result: HolisticLandmarkerResult = {
faceLandmarks: [],
faceBlendshapes: [],
poseLandmarks: [],
poseWorldLandmarks: [],
poseSegmentationMasks: [],
leftHandLandmarks: [],
leftHandWorldLandmarks: [],
rightHandLandmarks: [],
rightHandWorldLandmarks: []
};
private outputFaceBlendshapes = false;
private outputPoseSegmentationMasks = false;
private userCallback?: HolisticLandmarkerCallback;
private readonly options: HolisticLandmarkerGraphOptions;
private readonly handLandmarksDetectorGraphOptions:
HandLandmarksDetectorGraphOptions;
private readonly handRoiRefinementGraphOptions: HandRoiRefinementGraphOptions;
private readonly faceDetectorGraphOptions: FaceDetectorGraphOptions;
private readonly faceLandmarksDetectorGraphOptions:
FaceLandmarksDetectorGraphOptions;
private readonly poseDetectorGraphOptions: PoseDetectorGraphOptions;
private readonly poseLandmarksDetectorGraphOptions:
PoseLandmarksDetectorGraphOptions;
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
* @export
* @nocollapse
*/
static HAND_CONNECTIONS = HAND_CONNECTIONS;
/**
* An array containing the pairs of pose landmark indices to be rendered with
* connections.
* @export
* @nocollapse
*/
static POSE_CONNECTIONS = POSE_CONNECTIONS;
/**
* Landmark connections to draw the connection between a face's lips.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_LIPS = FACE_LANDMARKS_LIPS;
/**
* Landmark connections to draw the connection between a face's left eye.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_LEFT_EYE = FACE_LANDMARKS_LEFT_EYE;
/**
* Landmark connections to draw the connection between a face's left eyebrow.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_LEFT_EYEBROW = FACE_LANDMARKS_LEFT_EYEBROW;
/**
* Landmark connections to draw the connection between a face's left iris.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_LEFT_IRIS = FACE_LANDMARKS_LEFT_IRIS;
/**
* Landmark connections to draw the connection between a face's right eye.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_RIGHT_EYE = FACE_LANDMARKS_RIGHT_EYE;
/**
* Landmark connections to draw the connection between a face's right
* eyebrow.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_RIGHT_EYEBROW = FACE_LANDMARKS_RIGHT_EYEBROW;
/**
* Landmark connections to draw the connection between a face's right iris.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_RIGHT_IRIS = FACE_LANDMARKS_RIGHT_IRIS;
/**
* Landmark connections to draw the face's oval.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_FACE_OVAL = FACE_LANDMARKS_FACE_OVAL;
/**
* Landmark connections to draw the face's contour.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_CONTOURS = FACE_LANDMARKS_CONTOURS;
/**
* Landmark connections to draw the face's tesselation.
* @export
* @nocollapse
*/
static FACE_LANDMARKS_TESSELATION = FACE_LANDMARKS_TESSELATION;
/**
* Initializes the Wasm runtime and creates a new `HolisticLandmarker` from
* the provided options.
* @export
* @param wasmFileset A configuration object that provides the location of the
* Wasm binary and its loader.
* @param holisticLandmarkerOptions The options for the HolisticLandmarker.
* Note that either a path to the model asset or a model buffer needs to
* be provided (via `baseOptions`).
*/
static createFromOptions(
wasmFileset: WasmFileset,
holisticLandmarkerOptions: HolisticLandmarkerOptions):
Promise<HolisticLandmarker> {
return VisionTaskRunner.createVisionInstance(
HolisticLandmarker, wasmFileset, holisticLandmarkerOptions);
}
/**
* Initializes the Wasm runtime and creates a new `HolisticLandmarker` based
* on the provided model asset buffer.
* @export
* @param wasmFileset A configuration object that provides the location of the
* Wasm binary and its loader.
* @param modelAssetBuffer A binary representation of the model.
*/
static createFromModelBuffer(
wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<HolisticLandmarker> {
return VisionTaskRunner.createVisionInstance(
HolisticLandmarker, wasmFileset, {baseOptions: {modelAssetBuffer}});
}
/**
* Initializes the Wasm runtime and creates a new `HolisticLandmarker` based
* on the path to the model asset.
* @export
* @param wasmFileset A configuration object that provides the location of the
* Wasm binary and its loader.
* @param modelAssetPath The path to the model asset.
*/
static createFromModelPath(
wasmFileset: WasmFileset,
modelAssetPath: string): Promise<HolisticLandmarker> {
return VisionTaskRunner.createVisionInstance(
HolisticLandmarker, wasmFileset, {baseOptions: {modelAssetPath}});
}
/** @hideconstructor */
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
/* normRectStream= */ null, /* roiAllowed= */ false);
this.options = new HolisticLandmarkerGraphOptions();
this.options.setBaseOptions(new BaseOptionsProto());
this.handLandmarksDetectorGraphOptions =
new HandLandmarksDetectorGraphOptions();
this.options.setHandLandmarksDetectorGraphOptions(
this.handLandmarksDetectorGraphOptions);
this.handRoiRefinementGraphOptions = new HandRoiRefinementGraphOptions();
this.options.setHandRoiRefinementGraphOptions(
this.handRoiRefinementGraphOptions);
this.faceDetectorGraphOptions = new FaceDetectorGraphOptions();
this.options.setFaceDetectorGraphOptions(this.faceDetectorGraphOptions);
this.faceLandmarksDetectorGraphOptions =
new FaceLandmarksDetectorGraphOptions();
this.options.setFaceLandmarksDetectorGraphOptions(
this.faceLandmarksDetectorGraphOptions);
this.poseDetectorGraphOptions = new PoseDetectorGraphOptions();
this.options.setPoseDetectorGraphOptions(this.poseDetectorGraphOptions);
this.poseLandmarksDetectorGraphOptions =
new PoseLandmarksDetectorGraphOptions();
this.options.setPoseLandmarksDetectorGraphOptions(
this.poseLandmarksDetectorGraphOptions);
this.initDefaults();
}
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto);
}
/**
* Sets new options for this `HolisticLandmarker`.
*
* Calling `setOptions()` with a subset of options only affects those options.
* You can reset an option back to its default value by explicitly setting it
* to `undefined`.
*
* @export
* @param options The options for the holistic landmarker.
*/
override setOptions(options: HolisticLandmarkerOptions): Promise<void> {
// Configure face detector options.
if ('minFaceDetectionConfidence' in options) {
this.faceDetectorGraphOptions.setMinDetectionConfidence(
options.minFaceDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if ('minFaceSuppressionThreshold' in options) {
this.faceDetectorGraphOptions.setMinSuppressionThreshold(
options.minFaceSuppressionThreshold ?? DEFAULT_SUPRESSION_THRESHOLD);
}
// Configure face landmark detector options.
if ('minFacePresenceConfidence' in options) {
this.faceLandmarksDetectorGraphOptions.setMinDetectionConfidence(
options.minFacePresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if ('outputFaceBlendshapes' in options) {
this.outputFaceBlendshapes = !!options.outputFaceBlendshapes;
}
// Configure pose detector options.
if ('minPoseDetectionConfidence' in options) {
this.poseDetectorGraphOptions.setMinDetectionConfidence(
options.minPoseDetectionConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if ('minPoseSuppressionThreshold' in options) {
this.poseDetectorGraphOptions.setMinSuppressionThreshold(
options.minPoseSuppressionThreshold ?? DEFAULT_SUPRESSION_THRESHOLD);
}
// Configure pose landmark detector options.
if ('minPosePresenceConfidence' in options) {
this.poseLandmarksDetectorGraphOptions.setMinDetectionConfidence(
options.minPosePresenceConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
if ('outputPoseSegmentationMasks' in options) {
this.outputPoseSegmentationMasks = !!options.outputPoseSegmentationMasks;
}
// Configure hand detector options.
if ('minHandLandmarksConfidence' in options) {
this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
options.minHandLandmarksConfidence ?? DEFAULT_SCORE_THRESHOLD);
}
return this.applyOptions(options);
}
/**
* Performs holistic landmarks detection on the provided single image and
* invokes the callback with the response. The method returns synchronously
* once the callback returns. Only use this method when the HolisticLandmarker
* is created with running mode `image`.
*
* @export
* @param image An image to process.
* @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of
* the callback.
*/
detect(image: ImageSource, callback: HolisticLandmarkerCallback): void;
/**
* Performs holistic landmarks detection on the provided single image and
* invokes the callback with the response. The method returns synchronously
* once the callback returns. Only use this method when the HolisticLandmarker
* is created with running mode `image`.
*
* @export
* @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of
* the callback.
*/
detect(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: HolisticLandmarkerCallback): void;
/**
* Performs holistic landmarks detection on the provided single image and
* waits synchronously for the response. This method creates a copy of the
* resulting masks and should not be used in high-throughput applications.
* Only use this method when the HolisticLandmarker is created with running
* mode `image`.
*
* @export
* @param image An image to process.
* @return The landmarker result. Any masks are copied to avoid lifetime
* limits.
* @return The detected pose landmarks.
*/
detect(image: ImageSource): HolisticLandmarkerResult;
/**
* Performs holistic landmarks detection on the provided single image and
* waits synchronously for the response. This method creates a copy of the
* resulting masks and should not be used in high-throughput applications.
* Only use this method when the HolisticLandmarker is created with running
* mode `image`.
*
* @export
* @param image An image to process.
* @return The landmarker result. Any masks are copied to avoid lifetime
* limits.
* @return The detected pose landmarks.
*/
detect(image: ImageSource, imageProcessingOptions: ImageProcessingOptions):
HolisticLandmarkerResult;
/** @export */
detect(
image: ImageSource,
imageProcessingOptionsOrCallback?: ImageProcessingOptions|
HolisticLandmarkerCallback,
callback?: HolisticLandmarkerCallback): HolisticLandmarkerResult|void {
const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback!;
this.resetResults();
this.processImageData(image, imageProcessingOptions);
return this.processResults();
}
/**
* Performs holistic landmarks detection on the provided video frame and
* invokes the callback with the response. The method returns synchronously
* once the callback returns. Only use this method when the HolisticLandmarker
* is created with running mode `video`.
*
* @export
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of
* the callback.
*/
detectForVideo(
videoFrame: ImageSource, timestamp: number,
callback: HolisticLandmarkerCallback): void;
/**
* Performs holistic landmarks detection on the provided video frame and
* invokes the callback with the response. The method returns synchronously
* once the callback returns. Only use this method when the holisticLandmarker
* is created with running mode `video`.
*
* @export
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of
* the callback.
*/
detectForVideo(
videoFrame: ImageSource, timestamp: number,
imageProcessingOptions: ImageProcessingOptions,
callback: HolisticLandmarkerCallback): void;
/**
* Performs holistic landmarks detection on the provided video frame and
* returns the result. This method creates a copy of the resulting masks and
* should not be used in high-throughput applications. Only use this method
* when the HolisticLandmarker is created with running mode `video`.
*
* @export
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @return The landmarker result. Any masks are copied to extend the
* lifetime of the returned data.
*/
detectForVideo(videoFrame: ImageSource, timestamp: number):
HolisticLandmarkerResult;
/**
* Performs holistic landmarks detection on the provided video frame and waits
* synchronously for the response. Only use this method when the
* HolisticLandmarker is created with running mode `video`.
*
* @export
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @return The detected holistic landmarks.
*/
detectForVideo(
videoFrame: ImageSource, timestamp: number,
imageProcessingOptions: ImageProcessingOptions): HolisticLandmarkerResult;
/** @export */
detectForVideo(
videoFrame: ImageSource, timestamp: number,
imageProcessingOptionsOrCallback?: ImageProcessingOptions|
HolisticLandmarkerCallback,
callback?: HolisticLandmarkerCallback): HolisticLandmarkerResult|void {
const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback;
this.resetResults();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
return this.processResults();
}
private resetResults(): void {
this.result = {
faceLandmarks: [],
faceBlendshapes: [],
poseLandmarks: [],
poseWorldLandmarks: [],
poseSegmentationMasks: [],
leftHandLandmarks: [],
leftHandWorldLandmarks: [],
rightHandLandmarks: [],
rightHandWorldLandmarks: []
};
}
private processResults(): HolisticLandmarkerResult|void {
try {
if (this.userCallback) {
this.userCallback(this.result);
} else {
return this.result;
}
} finally {
// Free the image memory, now that we've finished our callback.
this.freeKeepaliveStreams();
}
}
/** Sets the default values for the graph. */
private initDefaults(): void {
this.faceDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
this.faceDetectorGraphOptions.setMinSuppressionThreshold(
DEFAULT_SUPRESSION_THRESHOLD);
this.faceLandmarksDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
this.poseDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
this.poseDetectorGraphOptions.setMinSuppressionThreshold(
DEFAULT_SUPRESSION_THRESHOLD);
this.poseLandmarksDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
this.handLandmarksDetectorGraphOptions.setMinDetectionConfidence(
DEFAULT_SCORE_THRESHOLD);
}
/** Converts raw data into a landmark, and adds it to our landmarks list. */
private addJsLandmarks(data: Uint8Array, outputList: NormalizedLandmark[][]):
void {
const landmarksProto = NormalizedLandmarkList.deserializeBinary(data);
outputList.push(convertToLandmarks(landmarksProto));
}
/**
* Converts raw data into a world landmark, and adds it to our worldLandmarks
* list.
*/
private addJsWorldLandmarks(data: Uint8Array, outputList: Landmark[][]):
void {
const worldLandmarksProto = LandmarkList.deserializeBinary(data);
outputList.push(convertToWorldLandmarks(worldLandmarksProto));
}
/** Adds new blendshapes from the given proto. */
private addBlenshape(data: Uint8Array, outputList: Classifications[]): void {
if (!this.outputFaceBlendshapes) {
return;
}
const classificationList = ClassificationListProto.deserializeBinary(data);
outputList.push(convertFromClassifications(
classificationList.getClassificationList() ?? []));
}
/** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addOutputStream(POSE_LANDMARKS_STREAM);
graphConfig.addOutputStream(POSE_WORLD_LANDMARKS_STREAM);
graphConfig.addOutputStream(FACE_LANDMARKS_STREAM);
graphConfig.addOutputStream(LEFT_HAND_LANDMARKS_STREAM);
graphConfig.addOutputStream(LEFT_HAND_WORLD_LANDMARKS_STREAM);
graphConfig.addOutputStream(RIGHT_HAND_LANDMARKS_STREAM);
graphConfig.addOutputStream(RIGHT_HAND_WORLD_LANDMARKS_STREAM);
const calculatorOptions = new CalculatorOptions();
const optionsProto = new Any();
optionsProto.setTypeUrl(
'type.googleapis.com/mediapipe.tasks.vision.holistic_landmarker.proto.HolisticLandmarkerGraphOptions');
optionsProto.setValue(this.options.serializeBinary());
const landmarkerNode = new CalculatorGraphConfig.Node();
landmarkerNode.setCalculator(HOLISTIC_LANDMARKER_GRAPH);
landmarkerNode.addNodeOptions(optionsProto);
landmarkerNode.addInputStream('IMAGE:' + IMAGE_STREAM);
landmarkerNode.addOutputStream('POSE_LANDMARKS:' + POSE_LANDMARKS_STREAM);
landmarkerNode.addOutputStream(
'POSE_WORLD_LANDMARKS:' + POSE_WORLD_LANDMARKS_STREAM);
landmarkerNode.addOutputStream('FACE_LANDMARKS:' + FACE_LANDMARKS_STREAM);
landmarkerNode.addOutputStream(
'LEFT_HAND_LANDMARKS:' + LEFT_HAND_LANDMARKS_STREAM);
landmarkerNode.addOutputStream(
'LEFT_HAND_WORLD_LANDMARKS:' + LEFT_HAND_WORLD_LANDMARKS_STREAM);
landmarkerNode.addOutputStream(
'RIGHT_HAND_LANDMARKS:' + RIGHT_HAND_LANDMARKS_STREAM);
landmarkerNode.addOutputStream(
'RIGHT_HAND_WORLD_LANDMARKS:' + RIGHT_HAND_WORLD_LANDMARKS_STREAM);
landmarkerNode.setOptions(calculatorOptions);
graphConfig.addNode(landmarkerNode);
// We only need to keep alive the image stream, since the protos are being
// deep-copied anyways via serialization+deserialization.
this.addKeepaliveNode(graphConfig);
this.graphRunner.attachProtoListener(
POSE_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto, this.result.poseLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
POSE_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachProtoListener(
POSE_WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsWorldLandmarks(binaryProto, this.result.poseWorldLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
POSE_WORLD_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
if (this.outputPoseSegmentationMasks) {
landmarkerNode.addOutputStream(
'POSE_SEGMENTATION_MASK:' + POSE_SEGMENTATION_MASK_STREAM);
this.keepStreamAlive(POSE_SEGMENTATION_MASK_STREAM);
this.graphRunner.attachImageListener(
POSE_SEGMENTATION_MASK_STREAM, (mask, timestamp) => {
this.result.poseSegmentationMasks = [this.convertToMPMask(
mask, /* interpolateValues= */ true,
/* shouldCopyData= */ !this.userCallback)];
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
POSE_SEGMENTATION_MASK_STREAM, timestamp => {
this.result.poseSegmentationMasks = [];
this.setLatestOutputTimestamp(timestamp);
});
}
this.graphRunner.attachProtoListener(
FACE_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto, this.result.faceLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
FACE_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
if (this.outputFaceBlendshapes) {
graphConfig.addOutputStream(FACE_BLENDSHAPES_STREAM);
landmarkerNode.addOutputStream(
'FACE_BLENDSHAPES:' + FACE_BLENDSHAPES_STREAM);
this.graphRunner.attachProtoListener(
FACE_BLENDSHAPES_STREAM, (binaryProto, timestamp) => {
this.addBlenshape(binaryProto, this.result.faceBlendshapes);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
FACE_BLENDSHAPES_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
this.graphRunner.attachProtoListener(
LEFT_HAND_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto, this.result.leftHandLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
LEFT_HAND_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachProtoListener(
LEFT_HAND_WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsWorldLandmarks(
binaryProto, this.result.leftHandWorldLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
LEFT_HAND_WORLD_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachProtoListener(
RIGHT_HAND_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsLandmarks(binaryProto, this.result.rightHandLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
RIGHT_HAND_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachProtoListener(
RIGHT_HAND_WORLD_LANDMARKS_STREAM, (binaryProto, timestamp) => {
this.addJsWorldLandmarks(
binaryProto, this.result.rightHandWorldLandmarks);
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
RIGHT_HAND_WORLD_LANDMARKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}
}

View File

@ -1,71 +0,0 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
/** Options to configure the MediaPipe HolisticLandmarker Task */
export declare interface HolisticLandmarkerOptions extends VisionTaskOptions {
/**
* The minimum confidence score for the face detection to be considered
* successful. Defaults to 0.5.
*/
minFaceDetectionConfidence?: number|undefined;
/**
* The minimum non-maximum-suppression threshold for face detection to be
* considered overlapped. Defaults to 0.3.
*/
minFaceSuppressionThreshold?: number|undefined;
/**
* The minimum confidence score of face presence score in the face landmarks
* detection. Defaults to 0.5.
*/
minFacePresenceConfidence?: number|undefined;
/**
* Whether FaceLandmarker outputs face blendshapes classification. Face
* blendshapes are used for rendering the 3D face model.
*/
outputFaceBlendshapes?: boolean|undefined;
/**
* The minimum confidence score for the pose detection to be considered
* successful. Defaults to 0.5.
*/
minPoseDetectionConfidence?: number|undefined;
/**
* The minimum non-maximum-suppression threshold for pose detection to be
* considered overlapped. Defaults to 0.3.
*/
minPoseSuppressionThreshold?: number|undefined;
/**
* The minimum confidence score of pose presence score in the pose landmarks
* detection. Defaults to 0.5.
*/
minPosePresenceConfidence?: number|undefined;
/** Whether to output segmentation masks. Defaults to false. */
outputPoseSegmentationMasks?: boolean|undefined;
/**
* The minimum confidence score of hand presence score in the hand landmarks
* detection. Defaults to 0.5.
*/
minHandLandmarksConfidence?: number|undefined;
}

View File

@ -1,55 +0,0 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Category} from '../../../../tasks/web/components/containers/category';
import {Classifications} from '../../../../tasks/web/components/containers/classification_result';
import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {MPMask} from '../../../../tasks/web/vision/core/mask';
export {Category, Landmark, NormalizedLandmark};
/**
* Represents the holistic landmarks detection results generated by
* `HolisticLandmarker`.
*/
export declare interface HolisticLandmarkerResult {
/** Detected face landmarks in normalized image coordinates. */
faceLandmarks: NormalizedLandmark[][];
/** Optional face blendshapes results. */
faceBlendshapes: Classifications[];
/** Detected pose landmarks in normalized image coordinates. */
poseLandmarks: NormalizedLandmark[][];
/** Pose landmarks in world coordinates of detected poses. */
poseWorldLandmarks: Landmark[][];
/** Optional segmentation mask for the detected pose. */
poseSegmentationMasks: MPMask[];
/** Left hand landmarks of detected left hands. */
leftHandLandmarks: NormalizedLandmark[][];
/** Left hand landmarks in world coordinates of detected left hands. */
leftHandWorldLandmarks: Landmark[][];
/** Right hand landmarks of detected right hands. */
rightHandLandmarks: NormalizedLandmark[][];
/** Right hand landmarks in world coordinates of detected right hands. */
rightHandWorldLandmarks: Landmark[][];
}

View File

@ -1,403 +0,0 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb';
import {HolisticLandmarkerGraphOptions} from '../../../../tasks/cc/vision/holistic_landmarker/proto/holistic_landmarker_graph_options_pb';
import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, Deserializer, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {HolisticLandmarker} from './holistic_landmarker';
import {HolisticLandmarkerOptions} from './holistic_landmarker_options';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
type ProtoListener = ((binaryProtos: Uint8Array, timestamp: number) => void);
const holisticLandmarkerDeserializer =
(binaryProto =>
HolisticLandmarkerGraphOptions.deserializeBinary(binaryProto)
.toObject()) as Deserializer;
function createBlendshapes(): ClassificationList {
const blendshapesProto = new ClassificationList();
const classification = new Classification();
classification.setScore(0.1);
classification.setIndex(1);
classification.setLabel('face_label');
classification.setDisplayName('face_display_name');
blendshapesProto.addClassification(classification);
return blendshapesProto;
}
class HolisticLandmarkerFake extends HolisticLandmarker implements
MediapipeTasksFake {
calculatorName =
'mediapipe.tasks.vision.holistic_landmarker.HolisticLandmarkerGraph';
attachListenerSpies: jasmine.Spy[] = [];
graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule;
listeners = new Map<string, ProtoListener>();
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] =
spyOn(this.graphRunner, 'attachProtoListener')
.and.callFake((stream, listener) => {
expect(stream).toMatch(
/(pose_landmarks|pose_world_landmarks|pose_segmentation_mask|face_landmarks|extra_blendshapes|left_hand_landmarks|left_hand_world_landmarks|right_hand_landmarks|right_hand_world_landmarks)/);
this.listeners.set(stream, listener);
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
spyOn(this.graphRunner, 'addProtoToStream');
}
getGraphRunner(): VisionGraphRunner {
return this.graphRunner;
}
}
describe('HolisticLandmarker', () => {
let holisticLandmarker: HolisticLandmarkerFake;
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
holisticLandmarker = new HolisticLandmarkerFake();
await holisticLandmarker.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
afterEach(() => {
holisticLandmarker.close();
});
it('initializes graph', async () => {
verifyGraph(holisticLandmarker);
verifyGraph(
holisticLandmarker, undefined, undefined,
holisticLandmarkerDeserializer);
});
it('reloads graph when settings are changed', async () => {
verifyListenersRegistered(holisticLandmarker);
await holisticLandmarker.setOptions({minFaceDetectionConfidence: 0.6});
verifyGraph(
holisticLandmarker,
[['faceDetectorGraphOptions', 'minDetectionConfidence'], 0.6],
undefined, holisticLandmarkerDeserializer);
verifyListenersRegistered(holisticLandmarker);
await holisticLandmarker.setOptions({minFaceDetectionConfidence: 0.7});
verifyGraph(
holisticLandmarker,
[['faceDetectorGraphOptions', 'minDetectionConfidence'], 0.7],
undefined, holisticLandmarkerDeserializer);
verifyListenersRegistered(holisticLandmarker);
});
it('merges options', async () => {
await holisticLandmarker.setOptions({minFaceDetectionConfidence: 0.5});
await holisticLandmarker.setOptions({minFaceSuppressionThreshold: 0.5});
await holisticLandmarker.setOptions({minFacePresenceConfidence: 0.5});
await holisticLandmarker.setOptions({minPoseDetectionConfidence: 0.5});
await holisticLandmarker.setOptions({minPoseSuppressionThreshold: 0.5});
await holisticLandmarker.setOptions({minPosePresenceConfidence: 0.5});
await holisticLandmarker.setOptions({minHandLandmarksConfidence: 0.5});
verifyGraph(
holisticLandmarker,
[
'faceDetectorGraphOptions', {
baseOptions: undefined,
minDetectionConfidence: 0.5,
minSuppressionThreshold: 0.5,
numFaces: undefined
}
],
undefined, holisticLandmarkerDeserializer);
verifyGraph(
holisticLandmarker,
[
'faceLandmarksDetectorGraphOptions', {
baseOptions: undefined,
minDetectionConfidence: 0.5,
smoothLandmarks: undefined,
faceBlendshapesGraphOptions: undefined
}
],
undefined, holisticLandmarkerDeserializer);
verifyGraph(
holisticLandmarker,
[
'poseDetectorGraphOptions', {
baseOptions: undefined,
minDetectionConfidence: 0.5,
minSuppressionThreshold: 0.5,
numPoses: undefined
}
],
undefined, holisticLandmarkerDeserializer);
verifyGraph(
holisticLandmarker,
[
'poseLandmarksDetectorGraphOptions', {
baseOptions: undefined,
minDetectionConfidence: 0.5,
smoothLandmarks: undefined
}
],
undefined, holisticLandmarkerDeserializer);
verifyGraph(
holisticLandmarker,
[
'handLandmarksDetectorGraphOptions',
{baseOptions: undefined, minDetectionConfidence: 0.5}
],
undefined, holisticLandmarkerDeserializer);
});
describe('setOptions()', () => {
interface TestCase {
optionPath: [keyof HolisticLandmarkerOptions, ...string[]];
fieldPath: string[];
customValue: unknown;
defaultValue: unknown;
}
const testCases: TestCase[] = [
{
optionPath: ['minFaceDetectionConfidence'],
fieldPath: ['faceDetectorGraphOptions', 'minDetectionConfidence'],
customValue: 0.1,
defaultValue: 0.5
},
{
optionPath: ['minFaceSuppressionThreshold'],
fieldPath: ['faceDetectorGraphOptions', 'minSuppressionThreshold'],
customValue: 0.2,
defaultValue: 0.3
},
{
optionPath: ['minFacePresenceConfidence'],
fieldPath:
['faceLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
customValue: 0.2,
defaultValue: 0.5
},
{
optionPath: ['minPoseDetectionConfidence'],
fieldPath: ['poseDetectorGraphOptions', 'minDetectionConfidence'],
customValue: 0.1,
defaultValue: 0.5
},
{
optionPath: ['minPoseSuppressionThreshold'],
fieldPath: ['poseDetectorGraphOptions', 'minSuppressionThreshold'],
customValue: 0.2,
defaultValue: 0.3
},
{
optionPath: ['minPosePresenceConfidence'],
fieldPath:
['poseLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
customValue: 0.2,
defaultValue: 0.5
},
{
optionPath: ['minHandLandmarksConfidence'],
fieldPath:
['handLandmarksDetectorGraphOptions', 'minDetectionConfidence'],
customValue: 0.1,
defaultValue: 0.5
},
];
/** Creates an options object that can be passed to setOptions() */
function createOptions(
path: string[], value: unknown): HolisticLandmarkerOptions {
const options: Record<string, unknown> = {};
let currentLevel = options;
for (const element of path.slice(0, -1)) {
currentLevel[element] = {};
currentLevel = currentLevel[element] as Record<string, unknown>;
}
currentLevel[path[path.length - 1]] = value;
return options;
}
for (const testCase of testCases) {
it(`uses default value for ${testCase.optionPath[0]}`, async () => {
verifyGraph(
holisticLandmarker, [testCase.fieldPath, testCase.defaultValue],
undefined, holisticLandmarkerDeserializer);
});
it(`can set ${testCase.optionPath[0]}`, async () => {
await holisticLandmarker.setOptions(
createOptions(testCase.optionPath, testCase.customValue));
verifyGraph(
holisticLandmarker, [testCase.fieldPath, testCase.customValue],
undefined, holisticLandmarkerDeserializer);
});
it(`can clear ${testCase.optionPath[0]}`, async () => {
await holisticLandmarker.setOptions(
createOptions(testCase.optionPath, testCase.customValue));
verifyGraph(
holisticLandmarker, [testCase.fieldPath, testCase.customValue],
undefined, holisticLandmarkerDeserializer);
await holisticLandmarker.setOptions(
createOptions(testCase.optionPath, undefined));
verifyGraph(
holisticLandmarker, [testCase.fieldPath, testCase.defaultValue],
undefined, holisticLandmarkerDeserializer);
});
}
});
it('supports outputFaceBlendshapes', async () => {
const stream = 'extra_blendshapes';
await holisticLandmarker.setOptions({});
expect(holisticLandmarker.graph!.getOutputStreamList())
.not.toContain(stream);
await holisticLandmarker.setOptions({outputFaceBlendshapes: false});
expect(holisticLandmarker.graph!.getOutputStreamList())
.not.toContain(stream);
await holisticLandmarker.setOptions({outputFaceBlendshapes: true});
expect(holisticLandmarker.graph!.getOutputStreamList()).toContain(stream);
});
it('transforms results', async () => {
const faceLandmarksProto = createLandmarks().serializeBinary();
const blendshapesProto = createBlendshapes().serializeBinary();
const poseLandmarksProto = createLandmarks().serializeBinary();
const poseWorldLandmarksProto = createWorldLandmarks().serializeBinary();
const leftHandLandmarksProto = createLandmarks().serializeBinary();
const leftHandWorldLandmarksProto =
createWorldLandmarks().serializeBinary();
const rightHandLandmarksProto = createLandmarks().serializeBinary();
const rightHandWorldLandmarksProto =
createWorldLandmarks().serializeBinary();
await holisticLandmarker.setOptions(
{outputFaceBlendshapes: true, outputPoseSegmentationMasks: false});
// Pass the test data to our listener
holisticLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(holisticLandmarker);
holisticLandmarker.listeners.get('face_landmarks')!
(faceLandmarksProto, 1337);
holisticLandmarker.listeners.get('extra_blendshapes')!
(blendshapesProto, 1337);
holisticLandmarker.listeners.get('pose_landmarks')!
(poseLandmarksProto, 1337);
holisticLandmarker.listeners.get('pose_world_landmarks')!
(poseWorldLandmarksProto, 1337);
holisticLandmarker.listeners.get('left_hand_landmarks')!
(leftHandLandmarksProto, 1337);
holisticLandmarker.listeners.get('left_hand_world_landmarks')!
(leftHandWorldLandmarksProto, 1337);
holisticLandmarker.listeners.get('right_hand_landmarks')!
(rightHandLandmarksProto, 1337);
holisticLandmarker.listeners.get('right_hand_world_landmarks')!
(rightHandWorldLandmarksProto, 1337);
});
// Invoke the holistic landmarker
const landmarks = holisticLandmarker.detect({} as HTMLImageElement);
expect(holisticLandmarker.getGraphRunner().addGpuBufferAsImageToStream)
.toHaveBeenCalledTimes(1);
expect(holisticLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(landmarks).toEqual({
faceLandmarks: [[{x: 0, y: 0, z: 0}]],
faceBlendshapes: [{
categories: [{
index: 1,
score: 0.1,
categoryName: 'face_label',
displayName: 'face_display_name'
}],
headIndex: -1,
headName: ''
}],
poseLandmarks: [[{x: 0, y: 0, z: 0}]],
poseWorldLandmarks: [[{x: 0, y: 0, z: 0}]],
poseSegmentationMasks: [],
leftHandLandmarks: [[{x: 0, y: 0, z: 0}]],
leftHandWorldLandmarks: [[{x: 0, y: 0, z: 0}]],
rightHandLandmarks: [[{x: 0, y: 0, z: 0}]],
rightHandWorldLandmarks: [[{x: 0, y: 0, z: 0}]]
});
});
it('clears results between invoations', async () => {
const faceLandmarksProto = createLandmarks().serializeBinary();
const poseLandmarksProto = createLandmarks().serializeBinary();
const poseWorldLandmarksProto = createWorldLandmarks().serializeBinary();
const leftHandLandmarksProto = createLandmarks().serializeBinary();
const leftHandWorldLandmarksProto =
createWorldLandmarks().serializeBinary();
const rightHandLandmarksProto = createLandmarks().serializeBinary();
const rightHandWorldLandmarksProto =
createWorldLandmarks().serializeBinary();
// Pass the test data to our listener
holisticLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => {
holisticLandmarker.listeners.get('face_landmarks')!
(faceLandmarksProto, 1337);
holisticLandmarker.listeners.get('pose_landmarks')!
(poseLandmarksProto, 1337);
holisticLandmarker.listeners.get('pose_world_landmarks')!
(poseWorldLandmarksProto, 1337);
holisticLandmarker.listeners.get('left_hand_landmarks')!
(leftHandLandmarksProto, 1337);
holisticLandmarker.listeners.get('left_hand_world_landmarks')!
(leftHandWorldLandmarksProto, 1337);
holisticLandmarker.listeners.get('right_hand_landmarks')!
(rightHandLandmarksProto, 1337);
holisticLandmarker.listeners.get('right_hand_world_landmarks')!
(rightHandWorldLandmarksProto, 1337);
});
// Invoke the holistic landmarker twice
const landmarks1 = holisticLandmarker.detect({} as HTMLImageElement);
const landmarks2 = holisticLandmarker.detect({} as HTMLImageElement);
// Verify that landmarks2 is not a concatenation of all previously returned
// hands.
expect(landmarks1).toEqual(landmarks2);
});
});

View File

@ -23,7 +23,6 @@ import {FaceLandmarker as FaceLandmarkerImpl} from '../../../tasks/web/vision/fa
import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
import {HolisticLandmarker as HolisticLandmarkerImpl} from '../../../tasks/web/vision/holistic_landmarker/holistic_landmarker';
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier'; import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder'; import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter'; import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
@ -42,7 +41,6 @@ const FaceLandmarker = FaceLandmarkerImpl;
const FaceStylizer = FaceStylizerImpl; const FaceStylizer = FaceStylizerImpl;
const GestureRecognizer = GestureRecognizerImpl; const GestureRecognizer = GestureRecognizerImpl;
const HandLandmarker = HandLandmarkerImpl; const HandLandmarker = HandLandmarkerImpl;
const HolisticLandmarker = HolisticLandmarkerImpl;
const ImageClassifier = ImageClassifierImpl; const ImageClassifier = ImageClassifierImpl;
const ImageEmbedder = ImageEmbedderImpl; const ImageEmbedder = ImageEmbedderImpl;
const ImageSegmenter = ImageSegementerImpl; const ImageSegmenter = ImageSegementerImpl;
@ -60,7 +58,6 @@ export {
FaceStylizer, FaceStylizer,
GestureRecognizer, GestureRecognizer,
HandLandmarker, HandLandmarker,
HolisticLandmarker,
ImageClassifier, ImageClassifier,
ImageEmbedder, ImageEmbedder,
ImageSegmenter, ImageSegmenter,

View File

@ -19,7 +19,6 @@ mediapipe_ts_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":pose_landmarker_types", ":pose_landmarker_types",
":pose_landmarks_connections",
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto",
@ -33,6 +32,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:mask", "//mediapipe/tasks/web/vision/core:mask",
"//mediapipe/tasks/web/vision/core:types",
"//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
], ],
@ -67,9 +67,3 @@ jasmine_node_test(
tags = ["nomsan"], tags = ["nomsan"],
deps = [":pose_landmarker_test_lib"], deps = [":pose_landmarker_test_lib"],
) )
mediapipe_ts_library(
name = "pose_landmarks_connections",
srcs = ["pose_landmarks_connections.ts"],
deps = ["//mediapipe/tasks/web/vision/core:types"],
)

View File

@ -26,13 +26,13 @@ import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {MPMask} from '../../../../tasks/web/vision/core/mask'; import {MPMask} from '../../../../tasks/web/vision/core/mask';
import {convertToConnections} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {PoseLandmarkerOptions} from './pose_landmarker_options'; import {PoseLandmarkerOptions} from './pose_landmarker_options';
import {PoseLandmarkerResult} from './pose_landmarker_result'; import {PoseLandmarkerResult} from './pose_landmarker_result';
import {POSE_CONNECTIONS} from './pose_landmarks_connections';
export * from './pose_landmarker_options'; export * from './pose_landmarker_options';
export * from './pose_landmarker_result'; export * from './pose_landmarker_result';
@ -79,7 +79,12 @@ export class PoseLandmarker extends VisionTaskRunner {
* @export * @export
* @nocollapse * @nocollapse
*/ */
static POSE_CONNECTIONS = POSE_CONNECTIONS; static POSE_CONNECTIONS = convertToConnections(
[0, 1], [1, 2], [2, 3], [3, 7], [0, 4], [4, 5], [5, 6], [6, 8], [9, 10],
[11, 12], [11, 13], [13, 15], [15, 17], [15, 19], [15, 21], [17, 19],
[12, 14], [14, 16], [16, 18], [16, 20], [16, 22], [18, 20], [11, 23],
[12, 24], [23, 24], [23, 25], [24, 26], [25, 27], [26, 28], [27, 29],
[28, 30], [29, 31], [30, 32], [27, 31], [28, 32]);
/** /**
* Initializes the Wasm runtime and creates a new `PoseLandmarker` from the * Initializes the Wasm runtime and creates a new `PoseLandmarker` from the
@ -201,7 +206,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* callback returns. Only use this method when the PoseLandmarker is created * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `image`. * with running mode `image`.
* *
* @export
* @param image An image to process. * @param image An image to process.
* @param callback The callback that is invoked with the result. The * @param callback The callback that is invoked with the result. The
* lifetime of the returned masks is only guaranteed for the duration of * lifetime of the returned masks is only guaranteed for the duration of
@ -214,7 +218,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* callback returns. Only use this method when the PoseLandmarker is created * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `image`. * with running mode `image`.
* *
* @export
* @param image An image to process. * @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference. * to process the input image before running inference.
@ -232,7 +235,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* use this method when the PoseLandmarker is created with running mode * use this method when the PoseLandmarker is created with running mode
* `image`. * `image`.
* *
* @export
* @param image An image to process. * @param image An image to process.
* @return The landmarker result. Any masks are copied to avoid lifetime * @return The landmarker result. Any masks are copied to avoid lifetime
* limits. * limits.
@ -246,7 +248,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* use this method when the PoseLandmarker is created with running mode * use this method when the PoseLandmarker is created with running mode
* `image`. * `image`.
* *
* @export
* @param image An image to process. * @param image An image to process.
* @return The landmarker result. Any masks are copied to avoid lifetime * @return The landmarker result. Any masks are copied to avoid lifetime
* limits. * limits.
@ -279,7 +280,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* callback returns. Only use this method when the PoseLandmarker is created * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `video`. * with running mode `video`.
* *
* @export
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the result. The * @param callback The callback that is invoked with the result. The
@ -295,7 +295,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* callback returns. Only use this method when the PoseLandmarker is created * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `video`. * with running mode `video`.
* *
* @export
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
@ -314,7 +313,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* in high-throughput applications. Only use this method when the * in high-throughput applications. Only use this method when the
* PoseLandmarker is created with running mode `video`. * PoseLandmarker is created with running mode `video`.
* *
* @export
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @return The landmarker result. Any masks are copied to extend the * @return The landmarker result. Any masks are copied to extend the
@ -329,7 +327,6 @@ export class PoseLandmarker extends VisionTaskRunner {
* callback returns. Only use this method when the PoseLandmarker is created * callback returns. Only use this method when the PoseLandmarker is created
* with running mode `video`. * with running mode `video`.
* *
* @export
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how

View File

@ -1,28 +0,0 @@
/**
* Copyright 2023 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {convertToConnections} from '../../../../tasks/web/vision/core/types';
/**
* An array containing the pairs of pose landmark indices to be rendered with
* connections.
*/
export const POSE_CONNECTIONS = convertToConnections(
[0, 1], [1, 2], [2, 3], [3, 7], [0, 4], [4, 5], [5, 6], [6, 8], [9, 10],
[11, 12], [11, 13], [13, 15], [15, 17], [15, 19], [15, 21], [17, 19],
[12, 14], [14, 16], [16, 18], [16, 20], [16, 22], [18, 20], [11, 23],
[12, 24], [23, 24], [23, 25], [24, 26], [25, 27], [26, 28], [27, 29],
[28, 30], [29, 31], [30, 32], [27, 31], [28, 32]);

View File

@ -22,7 +22,6 @@ export * from '../../../tasks/web/vision/face_detector/face_detector';
export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker';
export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer';
export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
export * from '../../../tasks/web/vision/holistic_landmarker/holistic_landmarker';
export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
export * from '../../../tasks/web/vision/image_classifier/image_classifier'; export * from '../../../tasks/web/vision/image_classifier/image_classifier';
export * from '../../../tasks/web/vision/image_embedder/image_embedder'; export * from '../../../tasks/web/vision/image_embedder/image_embedder';

View File

@ -90,7 +90,7 @@ class BasePacketProcessor {
// CorrectPtsForRollover(2^33 - 1) -> 2^33 - 1 // CorrectPtsForRollover(2^33 - 1) -> 2^33 - 1
// CorrectPtsForRollover(0) -> 2^33 // PTS in media rolls over, corrected. // CorrectPtsForRollover(0) -> 2^33 // PTS in media rolls over, corrected.
// CorrectPtsForRollover(1) -> 2^33 + 1 // CorrectPtsForRollover(1) -> 2^33 + 1
int64_t CorrectPtsForRollover(int64_t media_pts); int64 CorrectPtsForRollover(int64 media_pts);
AVCodecContext* avcodec_ctx_ = nullptr; AVCodecContext* avcodec_ctx_ = nullptr;
const AVCodec* avcodec_ = nullptr; const AVCodec* avcodec_ = nullptr;
@ -113,7 +113,7 @@ class BasePacketProcessor {
AVRational source_frame_rate_; AVRational source_frame_rate_;
// The number of frames that were successfully processed. // The number of frames that were successfully processed.
int64_t num_frames_processed_ = 0; int64 num_frames_processed_ = 0;
int bytes_per_sample_ = 0; int bytes_per_sample_ = 0;
@ -121,7 +121,7 @@ class BasePacketProcessor {
bool last_frame_time_regression_detected_ = false; bool last_frame_time_regression_detected_ = false;
// The last rollover corrected PTS returned by CorrectPtsForRollover. // The last rollover corrected PTS returned by CorrectPtsForRollover.
int64_t rollover_corrected_last_pts_ = AV_NOPTS_VALUE; int64 rollover_corrected_last_pts_ = AV_NOPTS_VALUE;
// The buffer of current frames. // The buffer of current frames.
std::deque<Packet> buffer_; std::deque<Packet> buffer_;
@ -141,16 +141,16 @@ class AudioPacketProcessor : public BasePacketProcessor {
private: private:
// Appends audio in buffer(s) to the output buffer (buffer_). // Appends audio in buffer(s) to the output buffer (buffer_).
absl::Status AddAudioDataToBuffer(const Timestamp output_timestamp, absl::Status AddAudioDataToBuffer(const Timestamp output_timestamp,
uint8_t* const* raw_audio, uint8* const* raw_audio,
int buf_size_bytes); int buf_size_bytes);
// Converts a number of samples into an approximate stream timestamp value. // Converts a number of samples into an approximate stream timestamp value.
int64_t SampleNumberToTimestamp(const int64_t sample_number); int64 SampleNumberToTimestamp(const int64 sample_number);
int64_t TimestampToSampleNumber(const int64_t timestamp); int64 TimestampToSampleNumber(const int64 timestamp);
// Converts a timestamp/sample number to microseconds. // Converts a timestamp/sample number to microseconds.
int64_t TimestampToMicroseconds(const int64_t timestamp); int64 TimestampToMicroseconds(const int64 timestamp);
int64_t SampleNumberToMicroseconds(const int64_t sample_number); int64 SampleNumberToMicroseconds(const int64 sample_number);
// Returns an error if the sample format in avformat_ctx_.sample_format // Returns an error if the sample format in avformat_ctx_.sample_format
// is not supported. // is not supported.
@ -161,7 +161,7 @@ class AudioPacketProcessor : public BasePacketProcessor {
absl::Status ProcessDecodedFrame(const AVPacket& packet) override; absl::Status ProcessDecodedFrame(const AVPacket& packet) override;
// Corrects PTS for rollover if correction is enabled. // Corrects PTS for rollover if correction is enabled.
int64_t MaybeCorrectPtsForRollover(int64_t media_pts); int64 MaybeCorrectPtsForRollover(int64 media_pts);
// Number of channels to output. This value might be different from // Number of channels to output. This value might be different from
// the actual number of channels for the current AVPacket, found in // the actual number of channels for the current AVPacket, found in
@ -171,7 +171,7 @@ class AudioPacketProcessor : public BasePacketProcessor {
// Sample rate of the data to output. This value might be different // Sample rate of the data to output. This value might be different
// from the actual sample rate for the current AVPacket, found in // from the actual sample rate for the current AVPacket, found in
// avcodec_ctx_->sample_rate. // avcodec_ctx_->sample_rate.
int64_t sample_rate_ = -1; int64 sample_rate_ = -1;
// The time base of audio samples (i.e. the reciprocal of the sample rate). // The time base of audio samples (i.e. the reciprocal of the sample rate).
AVRational sample_time_base_; AVRational sample_time_base_;
@ -180,7 +180,7 @@ class AudioPacketProcessor : public BasePacketProcessor {
Timestamp last_timestamp_; Timestamp last_timestamp_;
// The expected sample number based on counting samples. // The expected sample number based on counting samples.
int64_t expected_sample_number_ = 0; int64 expected_sample_number_ = 0;
// Options for the processor. // Options for the processor.
AudioStreamOptions options_; AudioStreamOptions options_;

View File

@ -74,14 +74,14 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image,
// values used are those from ITU-R BT.601 (which are the same as ITU-R // values used are those from ITU-R BT.601 (which are the same as ITU-R
// BT.709). The conversion values are taken from wikipedia and cross // BT.709). The conversion values are taken from wikipedia and cross
// checked with other sources. // checked with other sources.
void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b, // void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b, //
uint8_t* y, uint8_t* cb, uint8_t* cr); uint8* y, uint8* cb, uint8* cr);
// Convert MPEG YCbCr values into sRGB values. See the SrgbToMpegYCbCr() // Convert MPEG YCbCr values into sRGB values. See the SrgbToMpegYCbCr()
// for more notes. Many MPEG YCbCr values do not correspond directly // for more notes. Many MPEG YCbCr values do not correspond directly
// to an sRGB value. If the value is invalid it will be clipped to the // to an sRGB value. If the value is invalid it will be clipped to the
// closest valid value on a per channel basis. // closest valid value on a per channel basis.
void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr, // void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, //
uint8_t* r, uint8_t* g, uint8_t* b); uint8* r, uint8* g, uint8* b);
// Conversion functions to and from srgb and linear RGB in 16 bits-per-pixel // Conversion functions to and from srgb and linear RGB in 16 bits-per-pixel
// channel. // channel.

View File

@ -27,7 +27,7 @@ namespace mediapipe {
// both expected to contain one label per line. // both expected to contain one label per line.
// Returns an error e.g. if there's a mismatch between the number of labels and // Returns an error e.g. if there's a mismatch between the number of labels and
// display names. // display names.
absl::StatusOr<proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>> absl::StatusOr<proto_ns::Map<int64, ::mediapipe::LabelMapItem>>
BuildLabelMapFromFiles(absl::string_view labels_file_contents, BuildLabelMapFromFiles(absl::string_view labels_file_contents,
absl::string_view display_names_file); absl::string_view display_names_file);

View File

@ -292,7 +292,7 @@ class TimeSeriesCalculatorTest : public ::testing::Test {
// Overload to allow explicit conversion from int64 to Timestamp // Overload to allow explicit conversion from int64 to Timestamp
template <typename T> template <typename T>
void AppendInputPacket(const T* payload, const int64_t timestamp, void AppendInputPacket(const T* payload, const int64 timestamp,
const size_t input_index = 0) { const size_t input_index = 0) {
AppendInputPacket(payload, Timestamp(timestamp), input_index); AppendInputPacket(payload, Timestamp(timestamp), input_index);
} }
@ -305,7 +305,7 @@ class TimeSeriesCalculatorTest : public ::testing::Test {
} }
template <typename T> template <typename T>
void AppendInputPacket(const T* payload, const int64_t timestamp, void AppendInputPacket(const T* payload, const int64 timestamp,
const std::string& input_tag) { const std::string& input_tag) {
AppendInputPacket(payload, Timestamp(timestamp), input_tag); AppendInputPacket(payload, Timestamp(timestamp), input_tag);
} }
@ -450,7 +450,7 @@ class MultiStreamTimeSeriesCalculatorTest
// Overload to allow explicit conversion from int64 to Timestamp // Overload to allow explicit conversion from int64 to Timestamp
void AppendInputPacket(const std::vector<Matrix>* input_vector, void AppendInputPacket(const std::vector<Matrix>* input_vector,
const int64_t timestamp) { const int64 timestamp) {
AppendInputPacket(input_vector, Timestamp(timestamp)); AppendInputPacket(input_vector, Timestamp(timestamp));
} }

View File

@ -39,7 +39,7 @@ namespace time_series_util {
// function. // function.
bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp,
const Timestamp& initial_timestamp, const Timestamp& initial_timestamp,
int64_t cumulative_samples, int64 cumulative_samples,
double sample_rate); double sample_rate);
// Returns absl::Status::OK if the header is valid. Otherwise, returns a // Returns absl::Status::OK if the header is valid. Otherwise, returns a
@ -109,11 +109,11 @@ void SetExtensionInHeader(const TimeSeriesHeaderExtensionClass& extension,
} }
// Converts from a time_in_seconds to an integer number of samples. // Converts from a time_in_seconds to an integer number of samples.
int64_t SecondsToSamples(double time_in_seconds, double sample_rate); int64 SecondsToSamples(double time_in_seconds, double sample_rate);
// Converts from an integer number of samples to a time duration in seconds // Converts from an integer number of samples to a time duration in seconds
// spanned by the samples. // spanned by the samples.
double SamplesToSeconds(int64_t num_samples, double sample_rate); double SamplesToSeconds(int64 num_samples, double sample_rate);
} // namespace time_series_util } // namespace time_series_util
} // namespace mediapipe } // namespace mediapipe

View File

@ -52,26 +52,6 @@ declare global {
*/ */
declare function importScripts(...urls: Array<string|URL>): void; declare function importScripts(...urls: Array<string|URL>): void;
/**
* Detects image source size.
*/
export function getImageSourceSize(imageSource: ImageSource): [number, number] {
if ((imageSource as HTMLVideoElement).videoWidth !== undefined) {
const videoElement = imageSource as HTMLVideoElement;
return [videoElement.videoWidth, videoElement.videoHeight];
} else if ((imageSource as HTMLImageElement).naturalWidth !== undefined) {
// TODO: Ensure this works with SVG images
const imageElement = imageSource as HTMLImageElement;
return [imageElement.naturalWidth, imageElement.naturalHeight];
} else if ((imageSource as VideoFrame).displayWidth !== undefined) {
const videoFrame = imageSource as VideoFrame;
return [videoFrame.displayWidth, videoFrame.displayHeight];
} else {
const notVideoFrame = imageSource as Exclude<ImageSource, VideoFrame>;
return [notVideoFrame.width, notVideoFrame.height];
}
}
/** /**
* Simple class to run an arbitrary image-in/image-out MediaPipe graph (i.e. * Simple class to run an arbitrary image-in/image-out MediaPipe graph (i.e.
* as created by wasm_mediapipe_demo BUILD macro), and either render results * as created by wasm_mediapipe_demo BUILD macro), and either render results
@ -84,7 +64,7 @@ export class GraphRunner implements GraphRunnerApi {
// should be somewhat fixed when we create our .d.ts files. // should be somewhat fixed when we create our .d.ts files.
readonly wasmModule: WasmModule; readonly wasmModule: WasmModule;
readonly hasMultiStreamSupport: boolean; readonly hasMultiStreamSupport: boolean;
autoResizeCanvas = true; autoResizeCanvas: boolean = true;
audioPtr: number|null; audioPtr: number|null;
audioSize: number; audioSize: number;
@ -216,7 +196,18 @@ export class GraphRunner implements GraphRunnerApi {
gl.pixelStorei(gl.UNPACK_FLIP_Y_WEBGL, false); gl.pixelStorei(gl.UNPACK_FLIP_Y_WEBGL, false);
} }
const [width, height] = getImageSourceSize(imageSource); let width, height;
if ((imageSource as HTMLVideoElement).videoWidth) {
width = (imageSource as HTMLVideoElement).videoWidth;
height = (imageSource as HTMLVideoElement).videoHeight;
} else if ((imageSource as HTMLImageElement).naturalWidth) {
// TODO: Ensure this works with SVG images
width = (imageSource as HTMLImageElement).naturalWidth;
height = (imageSource as HTMLImageElement).naturalHeight;
} else {
width = imageSource.width;
height = imageSource.height;
}
if (this.autoResizeCanvas && if (this.autoResizeCanvas &&
(width !== this.wasmModule.canvas.width || (width !== this.wasmModule.canvas.width ||
@ -304,7 +295,7 @@ export class GraphRunner implements GraphRunnerApi {
* format). * format).
* *
* Consumers must deserialize the binary representation themselves as this * Consumers must deserialize the binary representation themselves as this
* avoids adding a direct dependency on the Protobuf JSPB target in the graph * avoids addding a direct dependency on the Protobuf JSPB target in the graph
* library. * library.
*/ */
getCalculatorGraphConfig( getCalculatorGraphConfig(

View File

@ -26,8 +26,8 @@ export {
/** /**
* Valid types of image sources which we can run our GraphRunner over. * Valid types of image sources which we can run our GraphRunner over.
*/ */
export type ImageSource = HTMLCanvasElement|HTMLVideoElement|HTMLImageElement| export type ImageSource =
ImageData|ImageBitmap|VideoFrame; HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap;
/** /**
* Simple interface for a class to run an arbitrary MediaPipe graph on web, and * Simple interface for a class to run an arbitrary MediaPipe graph on web, and

View File

@ -5,7 +5,7 @@
"devDependencies": { "devDependencies": {
"@bazel/jasmine": "^5.7.2", "@bazel/jasmine": "^5.7.2",
"@bazel/rollup": "^5.7.1", "@bazel/rollup": "^5.7.1",
"@bazel/typescript": "^5.8.1", "@bazel/typescript": "^5.7.1",
"@rollup/plugin-commonjs": "^23.0.2", "@rollup/plugin-commonjs": "^23.0.2",
"@rollup/plugin-node-resolve": "^15.0.1", "@rollup/plugin-node-resolve": "^15.0.1",
"@rollup/plugin-terser": "^0.1.0", "@rollup/plugin-terser": "^0.1.0",
@ -20,6 +20,6 @@
"protobufjs-cli": "^1.0.2", "protobufjs-cli": "^1.0.2",
"rollup": "^2.3.0", "rollup": "^2.3.0",
"ts-protoc-gen": "^0.15.0", "ts-protoc-gen": "^0.15.0",
"typescript": "^5.3.3" "typescript": "^4.8.4"
} }
} }

View File

@ -2,4 +2,4 @@
# The next version of MediaPipe (e.g. the version that is currently in development). # The next version of MediaPipe (e.g. the version that is currently in development).
# This version should be bumped after every release. # This version should be bumped after every release.
MEDIAPIPE_FULL_VERSION = "0.10.10" MEDIAPIPE_FULL_VERSION = "0.10.9"

1766
yarn.lock

File diff suppressed because it is too large Load Diff