diff --git a/mediapipe/calculators/image_style/BUILD b/mediapipe/calculators/image_style/BUILD index c4f4325f2..fd8e84d39 100644 --- a/mediapipe/calculators/image_style/BUILD +++ b/mediapipe/calculators/image_style/BUILD @@ -18,14 +18,24 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +mediapipe_proto_library( + name = "fast_utils_calculator_proto", + srcs = ["fast_utils_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + cc_library( name = "fast_utils_calculator", srcs = ["fast_utils_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":fast_utils_calculator_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -38,8 +48,32 @@ cc_library( "//mediapipe/framework/port:opencv_highgui", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", - "//mediapipe/util:annotation_renderer", - "//mediapipe/util:render_data_cc_proto", + ], + alwayslink = 1, +) + +cc_library( + name = "apply_mask_calculator", + srcs = ["apply_mask_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "@com_google_absl//absl/strings", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_highgui", 
+ "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:file_helpers", + "//mediapipe/util:resource_util", ], alwayslink = 1, ) @@ -49,3 +83,5 @@ cc_library( + + diff --git a/mediapipe/calculators/image_style/apply_mask_calculator.cc b/mediapipe/calculators/image_style/apply_mask_calculator.cc new file mode 100644 index 000000000..35f96f3b7 --- /dev/null +++ b/mediapipe/calculators/image_style/apply_mask_calculator.cc @@ -0,0 +1,305 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/video_stream_header.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/vector.h" + +using namespace std; +namespace mediapipe +{ + namespace + { + static const std::vector FFHQ_NORM_LM = { + {638.68525475 / 1024, 486.24604922 / 1024}, + {389.31496114 / 1024, 485.8921848 / 1024}, + {513.67979275 / 1024, 620.8915371 / 1024}, + {405.50932642 / 1024, 756.52797927 / 1024}, + {622.55630397 / 1024, 756.15509499 / 1024}}; + + constexpr char kImageFrameTag[] = "IMAGE"; + constexpr char kFakeBgTag[] = "FAKE_BG"; + constexpr char kLmMaskTag[] = "LM_MASK"; + + inline bool HasImageTag(mediapipe::CalculatorContext *cc) { return false; } + + cv::Mat blend_mask(cv::Mat mask_face, cv::Mat mask_bbox, int kernel_size = 33, int reduce_size = 128) + { + int k_sz = kernel_size; + auto [width, height] = mask_face.size(); + + cv::Mat mask_face_0 = mask_face.clone(); + + double K = (double)reduce_size / std::min(height, width); + + cv::resize(mask_face, mask_face, {(int)(width * K), (int)(height * K)}); + mask_face.convertTo(mask_face, CV_32F); + + cv::GaussianBlur(mask_face, mask_face, {k_sz, k_sz}, 0); + mask_face *= 2; + cv::threshold(mask_face, mask_face, 1, 255, CV_THRESH_TRUNC); + + cv::resize(mask_bbox, mask_bbox, {(int)(width * K), (int)(height * K)}); + + mask_bbox.convertTo(mask_bbox, CV_32F); + 
cv::GaussianBlur(mask_bbox, mask_bbox, {k_sz, k_sz}, 0); + + cv::Mat mask_bbox_3ch; + cv::merge(std::vector{mask_bbox, mask_bbox, mask_bbox}, mask_bbox_3ch); + + cv::Mat mask = mask_bbox_3ch.mul(mask_face); + + cv::Mat img_out; + cv::resize(mask, img_out, {width, height}); + + for (int i = 1; i < mask_face_0.rows; i++) + { + for (int j = 1; j < mask_face_0.cols; j++) + { + if (mask_face_0.at(i, j) > 0) + img_out.at(i, j) = 1; + } + } + + return img_out; + } + } // namespace + + class ApplyMaskCalculator : public CalculatorBase + { + public: + ApplyMaskCalculator() = default; + ~ApplyMaskCalculator() override = default; + + static absl::Status GetContract(CalculatorContract *cc); + + // From Calculator. + absl::Status Open(CalculatorContext *cc) override; + absl::Status Process(CalculatorContext *cc) override; + absl::Status Close(CalculatorContext *cc) override; + + private: + absl::Status CreateRenderTargetCpu(CalculatorContext *cc, + std::unique_ptr &image_mat, + std::string_view tag, + ImageFormat::Format *target_format); + + absl::Status RenderToCpu( + CalculatorContext *cc, const ImageFormat::Format &target_format, + uchar *data_image, std::unique_ptr &image_mat); + + // Indicates if image frame is available as input. 
+ bool image_frame_available_ = false; + int image_width_; + int image_height_; + }; + REGISTER_CALCULATOR(ApplyMaskCalculator); + + absl::Status ApplyMaskCalculator::GetContract(CalculatorContract *cc) + { + CHECK_GE(cc->Inputs().NumEntries(), 1); + + if (cc->Inputs().HasTag(kImageFrameTag)) + { + cc->Inputs().Tag(kImageFrameTag).Set(); + CHECK(cc->Outputs().HasTag(kImageFrameTag)); + } + + if (cc->Inputs().HasTag(kFakeBgTag)) + { + cc->Inputs().Tag(kFakeBgTag).Set(); + } + if (cc->Inputs().HasTag(kLmMaskTag)) + { + cc->Inputs().Tag(kLmMaskTag).Set(); + } + if (cc->Outputs().HasTag(kImageFrameTag)) + { + cc->Outputs().Tag(kImageFrameTag).Set(); + } + + return absl::OkStatus(); + } + + absl::Status ApplyMaskCalculator::Open(CalculatorContext *cc) + { + cc->SetOffset(TimestampDiff(0)); + + if (cc->Inputs().HasTag(kImageFrameTag) || HasImageTag(cc)) + { + image_frame_available_ = true; + } + + // Set the output header based on the input header (if present). + const char *tag = kImageFrameTag; + if (image_frame_available_ && !cc->Inputs().Tag(tag).Header().IsEmpty()) + { + const auto &input_header = + cc->Inputs().Tag(tag).Header().Get(); + auto *output_video_header = new VideoHeader(input_header); + cc->Outputs().Tag(tag).SetHeader(Adopt(output_video_header)); + } + + return absl::OkStatus(); + } + + absl::Status ApplyMaskCalculator::Process(CalculatorContext *cc) + { + if (cc->Inputs().HasTag(kImageFrameTag) && + cc->Inputs().Tag(kImageFrameTag).IsEmpty()) + { + return absl::OkStatus(); + } + // Initialize render target, drawn with OpenCV. + ImageFormat::Format target_format; + std::unique_ptr image_mat; + MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, kImageFrameTag, &target_format)); + + if (((cc->Inputs().HasTag(kFakeBgTag) && + !cc->Inputs().Tag(kFakeBgTag).IsEmpty())) && + ((cc->Inputs().HasTag(kLmMaskTag) && + !cc->Inputs().Tag(kLmMaskTag).IsEmpty()))) + { + // Initialize render target, drawn with OpenCV. 
+ std::unique_ptr fake_bg; + std::unique_ptr lm_mask_ptr; + + MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, fake_bg, kFakeBgTag, &target_format)); + MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, lm_mask_ptr, kLmMaskTag, &target_format)); + + cv::Mat mat_fake_bg_ = *fake_bg.get(); + cv::Mat mat_image_ = *image_mat.get(); + cv::Mat lm_mask = *lm_mask_ptr.get(); + + image_width_ = image_mat->cols; + image_height_ = image_mat->rows; + + cv::Mat roi_mask = mat_image_.clone(); + + cv::transform(roi_mask, roi_mask, cv::Matx13f(1, 1, 1)); + cv::threshold(roi_mask, roi_mask, 1, 255, CV_THRESH_TRUNC); + + cv::Mat mask = blend_mask(lm_mask, roi_mask, 33); + + mat_image_.convertTo(mat_image_, CV_32F); + mat_fake_bg_.convertTo(mat_fake_bg_, CV_32F); + cv::resize(mat_fake_bg_, mat_fake_bg_, {image_width_, image_height_}); + + cv::Mat im_out = mat_fake_bg_.mul(cv::Scalar::all(1) - mask) + mat_image_.mul(mask); + + im_out.convertTo(*image_mat, CV_8U); + } + uchar *image_mat_ptr = image_mat->data; + MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr, image_mat)); + + return absl::OkStatus(); + } + + absl::Status ApplyMaskCalculator::Close(CalculatorContext *cc) + { + return absl::OkStatus(); + } + + absl::Status ApplyMaskCalculator::RenderToCpu( + CalculatorContext *cc, const ImageFormat::Format &target_format, + uchar *data_image, std::unique_ptr &image_mat) + { + auto output_frame = absl::make_unique( + target_format, image_mat->cols, image_mat->rows); + + output_frame->CopyPixelData(target_format, image_mat->cols, image_mat->rows, data_image, + ImageFrame::kDefaultAlignmentBoundary); + + if (cc->Outputs().HasTag(kImageFrameTag)) + { + cc->Outputs() + .Tag(kImageFrameTag) + .Add(output_frame.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); + } + + absl::Status ApplyMaskCalculator::CreateRenderTargetCpu( + CalculatorContext *cc, std::unique_ptr &image_mat, std::string_view tag, + ImageFormat::Format *target_format) + { + if (image_frame_available_) + { 
+ const auto &input_frame = + cc->Inputs().Tag(tag).Get(); + + int target_mat_type; + switch (input_frame.Format()) + { + case ImageFormat::SRGBA: + *target_format = ImageFormat::SRGBA; + target_mat_type = CV_8UC4; + break; + case ImageFormat::SRGB: + *target_format = ImageFormat::SRGB; + target_mat_type = CV_8UC3; + break; + case ImageFormat::GRAY8: + *target_format = ImageFormat::SRGB; + target_mat_type = CV_8UC3; + break; + default: + return absl::UnknownError("Unexpected image frame format."); + break; + } + + image_mat = absl::make_unique( + input_frame.Height(), input_frame.Width(), target_mat_type); + + auto input_mat = formats::MatView(&input_frame); + + if (input_frame.Format() == ImageFormat::GRAY8) + { + cv::Mat rgb_mat; + cv::cvtColor(input_mat, rgb_mat, CV_GRAY2RGB); + rgb_mat.copyTo(*image_mat); + } + else + { + input_mat.copyTo(*image_mat); + } + } + else + { + image_mat = absl::make_unique( + 1920, 1080, CV_8UC4, + cv::Scalar::all(255)); + *target_format = ImageFormat::SRGBA; + } + + return absl::OkStatus(); + } +} // namespace mediapipe diff --git a/mediapipe/calculators/image_style/fast_utils_calculator.cc b/mediapipe/calculators/image_style/fast_utils_calculator.cc index 5a995d817..8039c43f6 100644 --- a/mediapipe/calculators/image_style/fast_utils_calculator.cc +++ b/mediapipe/calculators/image_style/fast_utils_calculator.cc @@ -22,6 +22,7 @@ #include #include "absl/strings/str_cat.h" +#include "mediapipe/calculators/image_style/fast_utils_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_options.pb.h" #include "mediapipe/framework/formats/image_format.pb.h" @@ -41,7 +42,7 @@ namespace mediapipe { namespace { - static const std::vector FFHQ_NORM_LM = { + const std::vector FFHQ_NORM_LM = { {638.68525475 / 1024, 486.24604922 / 1024}, {389.31496114 / 1024, 485.8921848 / 1024}, {513.67979275 / 1024, 620.8915371 / 1024}, @@ -52,16 +53,8 @@ namespace mediapipe constexpr char kVectorTag[] = 
"VECTOR"; constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; - - std::tuple _normalized_to_pixel_coordinates(float normalized_x, - float normalized_y, int image_width, int image_height) - { - // Converts normalized value pair to pixel coordinates - int x_px = std::min(floor(normalized_x * image_width), image_width - 1); - int y_px = std::min(floor(normalized_y * image_height), image_height - 1); - - return {x_px, y_px}; - }; + constexpr char kLmMaskTag[] = "LM_MASK"; + constexpr char kSizeTag[] = "SIZE"; static const std::vector FACEMESH_FACE_OVAL{ {10, 338}, {338, 297}, {297, 332}, {332, 284}, {284, 251}, {251, 389}, {389, 356}, {356, 454}, {454, 323}, {323, 361}, {361, 288}, {288, 397}, {397, 365}, {365, 379}, {379, 378}, {378, 400}, {400, 377}, {377, 152}, {152, 148}, {148, 176}, {176, 149}, {149, 150}, {150, 136}, {136, 172}, {172, 58}, {58, 132}, {132, 93}, {93, 234}, {234, 127}, {127, 162}, {162, 21}, {21, 54}, {54, 103}, {103, 67}, {67, 109}, {109, 10}}; @@ -113,16 +106,15 @@ namespace mediapipe cv::Mat &source, cv::Mat &target, float eps = 1e-7) { - cv::Mat source_mean_mat, target_mean_mat, source1ch, target1ch; - + cv::Mat source_mean_mat, target_mean_mat; cv::reduce(source, source_mean_mat, 0, CV_REDUCE_AVG, CV_32F); cv::reduce(target, target_mean_mat, 0, CV_REDUCE_AVG, CV_32F); source -= {source_mean_mat.at(0, 0), source_mean_mat.at(0, 1)}; target -= {target_mean_mat.at(0, 0), target_mean_mat.at(0, 1)}; - source1ch = source.reshape(1, 5); - target1ch = target.reshape(1, 5); + cv::Mat source1ch = source.reshape(1, 5); + cv::Mat target1ch = target.reshape(1, 5); cv::Mat source_std_mat, target_std_mat; cv::meanStdDev(source1ch, cv::noArray(), source_std_mat); @@ -136,21 +128,21 @@ namespace mediapipe source /= source_std + eps; target /= target_std + eps; - cv::Mat u, vt, rotation, w; - source1ch = source.reshape(1, 5); target1ch = target.reshape(1, 5); + cv::Mat u, vt, w; cv::SVD::compute(source1ch.t() * 
target1ch, w, u, vt); - rotation = (u * vt).t(); + cv::Mat rotation = (u * vt).t(); + + float scale = target_std / (source_std + eps); - float scale = target_std / source_std + eps; cv::Mat translation; + cv::subtract(target_mean_mat.reshape(1, 2), + scale * rotation * source_mean_mat.reshape(1, 2), translation); - cv::subtract(target_mean_mat.reshape(1, 2), scale * rotation * source_mean_mat.reshape(1, 2), translation); - - return std::make_tuple(scale, rotation, translation); + return {scale, rotation, translation}; } std::tuple Crop( @@ -210,6 +202,9 @@ namespace mediapipe absl::Status Process(CalculatorContext *cc) override; absl::Status Close(CalculatorContext *cc) override; + protected: + mediapipe::FastUtilsCalculatorOptions options_; + private: absl::Status CreateRenderTargetCpu(CalculatorContext *cc, std::unique_ptr &image_mat, @@ -217,7 +212,7 @@ namespace mediapipe absl::Status RenderToCpu( CalculatorContext *cc, const ImageFormat::Format &target_format, - uchar *data_image, std::unique_ptr &image_mat); + uchar *data_image, std::unique_ptr &image_mat, std::string_view tag); absl::Status Call(CalculatorContext *cc, std::unique_ptr &image_mat, @@ -231,19 +226,20 @@ namespace mediapipe // Indicates if image frame is available as input. 
bool image_frame_available_ = false; - std::vector>> index_dict = { + + const std::vector>> index_dict = { {"leftEye", {384, 385, 386, 387, 388, 390, 263, 362, 398, 466, 373, 374, 249, 380, 381, 382}}, {"rightEye", {160, 33, 161, 163, 133, 7, 173, 144, 145, 246, 153, 154, 155, 157, 158, 159}}, {"nose", {4}}, - //{"lips", {0, 13, 14, 17, 84}}, {"leftLips", {61, 146}}, {"rightLips", {291, 375}}, - }; - - std::unique_ptr image_mat; + }; + cv::Mat mat_image_; + cv::Mat lm_mask; int image_width_; int image_height_; + bool back_to_im; }; REGISTER_CALCULATOR(FastUtilsCalculator); @@ -273,18 +269,29 @@ namespace mediapipe { cc->Inputs().Tag(kNormLandmarksTag).Set>(); } + if (cc->Inputs().HasTag(kSizeTag)) + { + cc->Inputs().Tag(kSizeTag).Set>(); + } if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs().Tag(kImageFrameTag).Set(); } + if (cc->Outputs().HasTag(kLmMaskTag)) + { + cc->Outputs().Tag(kLmMaskTag).Set(); + } + return absl::OkStatus(); } absl::Status FastUtilsCalculator::Open(CalculatorContext *cc) { cc->SetOffset(TimestampDiff(0)); + options_ = cc->Options(); + back_to_im = options_.back_to_image(); if (cc->Inputs().HasTag(kImageFrameTag) || HasImageTag(cc)) { @@ -313,7 +320,9 @@ namespace mediapipe } // Initialize render target, drawn with OpenCV. 
+ std::unique_ptr image_mat; ImageFormat::Format target_format; + ImageFormat::Format target_format2; std::vector> lms_out; MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); @@ -326,14 +335,42 @@ namespace mediapipe { MP_RETURN_IF_ERROR(Call(cc, image_mat, target_format, lms_out)); - cv::Mat source_lm = cv::Mat(lms_out[0]); + if (cc->Outputs().HasTag(kLmMaskTag)) + { + lm_mask.convertTo(lm_mask, CV_8U); - MP_RETURN_IF_ERROR(Align(image_mat, source_lm)); + std::unique_ptr lm_mask_ptr = absl::make_unique( + mat_image_.size(), lm_mask.type()); + + lm_mask.copyTo(*lm_mask_ptr); + + target_format2 = ImageFormat::GRAY8; + uchar *lm_mask_pt = lm_mask_ptr->data; + + MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format2, lm_mask_pt, lm_mask_ptr, kLmMaskTag)); + } + + if (!back_to_im) + { + MP_RETURN_IF_ERROR(Align(image_mat, cv::Mat(lms_out[0]))); + } + else + { + const auto &size = + cc->Inputs().Tag(kSizeTag).Get>(); + cv::Mat tar = cv::Mat(FFHQ_NORM_LM) * 256; + + MP_RETURN_IF_ERROR(Align(image_mat, tar, + cv::Mat(lms_out[0]), {size.first, size.second})); + } + uchar *image_mat_ptr = image_mat->data; + MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr, image_mat, kImageFrameTag)); + } + else + { + uchar *image_mat_ptr = image_mat->data; + MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr, image_mat, kImageFrameTag)); } - - uchar *image_mat_ptr = image_mat->data; - MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr, image_mat)); - return absl::OkStatus(); } @@ -344,7 +381,7 @@ namespace mediapipe absl::Status FastUtilsCalculator::RenderToCpu( CalculatorContext *cc, const ImageFormat::Format &target_format, - uchar *data_image, std::unique_ptr &image_mat) + uchar *data_image, std::unique_ptr &image_mat, std::string_view tag) { auto output_frame = absl::make_unique( target_format, image_mat->cols, image_mat->rows); @@ -352,10 +389,10 @@ namespace mediapipe output_frame->CopyPixelData(target_format, 
image_mat->cols, image_mat->rows, data_image, ImageFrame::kDefaultAlignmentBoundary); - if (cc->Outputs().HasTag(kImageFrameTag)) + if (cc->Outputs().HasTag(tag)) { cc->Outputs() - .Tag(kImageFrameTag) + .Tag(tag) .Add(output_frame.release(), cc->InputTimestamp()); } @@ -410,9 +447,8 @@ namespace mediapipe else { image_mat = absl::make_unique( - 150, 150, CV_8UC4, - cv::Scalar(255, 255, - 255)); + 1920, 1080, CV_8UC4, + cv::Scalar::all(255)); *target_format = ImageFormat::SRGBA; } @@ -424,8 +460,6 @@ namespace mediapipe ImageFormat::Format &target_format, std::vector> &lms_out) { - std::vector kps, landmarks; - if (cc->Inputs().HasTag(kNormLandmarksTag)) { const std::vector &landmarkslist = @@ -434,11 +468,12 @@ namespace mediapipe std::vector point_array; for (const auto &face : landmarkslist) { + std::vector landmarks = {}; for (const auto &[key, value] : index_dict) { + std::vector kps = {}; for (auto order : value) { - const NormalizedLandmark &landmark = face.landmark(order); if (!IsLandmarkVisibleAndPresent( @@ -449,11 +484,13 @@ namespace mediapipe continue; } + const auto &size = + cc->Inputs().Tag(kSizeTag).Get>(); const auto &point = landmark; int x = -1; int y = -1; - CHECK(NormalizedtoPixelCoordinates(point.x(), point.y(), image_width_, - image_height_, &x, &y)); + CHECK(NormalizedtoPixelCoordinates(point.x(), point.y(), size.first, + size.second, &x, &y)); kps.push_back(cv::Point2f(x, y)); } @@ -461,12 +498,29 @@ namespace mediapipe cv::reduce(kps, mean, 1, CV_REDUCE_AVG, CV_32F); landmarks.push_back({mean.at(0, 0), mean.at(0, 1)}); - - kps.clear(); } lms_out.push_back(landmarks); + } + if (cc->Outputs().HasTag(kLmMaskTag)) + { + std::vector kpsint = {}; + for (auto &ix : FACEMESH_FACE_OVAL) + { + auto i = ix.x; - landmarks.clear(); + const NormalizedLandmark &landmark = landmarkslist[0].landmark(i); + + const auto &point = landmark; + int x = -1; + int y = -1; + CHECK(NormalizedtoPixelCoordinates(point.x(), point.y(), image_width_, + image_height_, 
&x, &y)); + kpsint.push_back(cv::Point(x, y)); + } + std::vector> pts; + pts.push_back(kpsint); + lm_mask = cv::Mat::zeros(image_mat->size(), CV_32FC1); + cv::fillPoly(lm_mask, pts, cv::Scalar::all(1), cv::LINE_AA); } } @@ -478,6 +532,8 @@ namespace mediapipe cv::Mat target_lm, cv::Size size, float extend, std::tuple roi) { + cv::Mat mat_image_ = *image_mat.get(); + cv::Mat source, target; source_lm.convertTo(source, CV_32F); target_lm.convertTo(target, CV_32F); diff --git a/mediapipe/calculators/image_style/fast_utils_calculator.proto b/mediapipe/calculators/image_style/fast_utils_calculator.proto new file mode 100644 index 000000000..65b003f7e --- /dev/null +++ b/mediapipe/calculators/image_style/fast_utils_calculator.proto @@ -0,0 +1,27 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message FastUtilsCalculatorOptions { + extend CalculatorOptions { + optional FastUtilsCalculatorOptions ext = 251431399; + } + // When true, Process() calls Align() with the FFHQ template landmarks (scaled by 256) and the detected landmarks, using the SIZE input stream as the output size; when false (default), Align() receives only the detected landmarks. 
+ optional bool back_to_image = 1 [default = false]; +} diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 586fb0dd3..a77968e8c 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -893,3 +893,34 @@ cc_library( }), alwayslink = 1, ) + +cc_library( + name = "tensors_to_image_calculator", + srcs = ["tensors_to_image_calculator.cc"], + copts = select({ + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_opencv", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:port", + "//mediapipe/util:resource_util", + "@org_tensorflow//tensorflow/lite:framework", + "//mediapipe/framework/port:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/calculators/tensor/tensors_to_image_calculator.cc b/mediapipe/calculators/tensor/tensors_to_image_calculator.cc new file mode 100644 index 000000000..3221bae9b --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_to_image_calculator.cc @@ -0,0 +1,201 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/util/resource_util.h" +#include "tensorflow/lite/interpreter.h" + +namespace +{ + constexpr char kTensorsTag[] = "TENSORS"; + constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; + constexpr char kImageTag[] = "IMAGE"; + + absl::StatusOr> GetHwcFromDims( + const std::vector &dims) + { + if (dims.size() == 3) + { + return std::make_tuple(dims[0], dims[1], dims[2]); + } + else if (dims.size() == 4) + { + // BHWC format check B == 1 + RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; + return std::make_tuple(dims[1], dims[2], dims[3]); + } + else + { + RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); + } + } +} // namespace + +namespace mediapipe +{ + + // Converts Tensors from a tflite segmentation model to an image. + // + // Performs optional upscale to OUTPUT_SIZE dimensions if provided, + // otherwise the image is the same size as input tensor. 
+ // + // + // + // Inputs: + // One of the following TENSORS tags: + // TENSORS: Vector of Tensor, + // Only the first tensor is read; its shape must be HWC, or BHWC with B == 1. + // OUTPUT_SIZE(optional): std::pair, + // If provided, the size to resize the output image to. + // + // Output: + // IMAGE: An Image output, SRGB. + // + // + // Usage example: + // node { + // calculator: "TensorsToImageCalculator" + // input_stream: "TENSORS:tensors" + // input_stream: "OUTPUT_SIZE:size" + // output_stream: "IMAGE:image" + // } + // + // TODO Refactor and add support for other backends/platforms. + // + class TensorsToImageCalculator : public CalculatorBase + { + public: + static absl::Status GetContract(CalculatorContract *cc); + + absl::Status Open(CalculatorContext *cc) override; + absl::Status Process(CalculatorContext *cc) override; + absl::Status Close(CalculatorContext *cc) override; + + private: + absl::Status ProcessCpu(CalculatorContext *cc); + + }; + REGISTER_CALCULATOR(TensorsToImageCalculator); + + // static + absl::Status TensorsToImageCalculator::GetContract( + CalculatorContract *cc) + { + RET_CHECK(!cc->Inputs().GetTags().empty()); + RET_CHECK(!cc->Outputs().GetTags().empty()); + + // Inputs. + cc->Inputs().Tag(kTensorsTag).Set>(); + if (cc->Inputs().HasTag(kOutputSizeTag)) + { + cc->Inputs().Tag(kOutputSizeTag).Set>(); + } + + // Outputs. 
+ cc->Outputs().Tag(kImageTag).Set(); + + return absl::OkStatus(); + } + + absl::Status TensorsToImageCalculator::Open(CalculatorContext *cc) + { + cc->SetOffset(TimestampDiff(0)); + + return absl::OkStatus(); + } + + absl::Status TensorsToImageCalculator::Process(CalculatorContext *cc) + { + if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) + { + return absl::OkStatus(); + } + + const auto &input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + + MP_RETURN_IF_ERROR(ProcessCpu(cc)); + + return absl::OkStatus(); + } + + absl::Status TensorsToImageCalculator::Close(CalculatorContext *cc) + { + + return absl::OkStatus(); + } + + absl::Status TensorsToImageCalculator::ProcessCpu( + CalculatorContext *cc) + { + // Get input streams, and dimensions. + const auto &input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) + { + const auto &size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + + cv::Mat image_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); + + // Wrap input tensor. + auto raw_input_tensor = &input_tensors[0]; + auto raw_input_view = raw_input_tensor->GetCpuReadView(); + const float *raw_input_data = raw_input_view.buffer(); + cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), + CV_MAKETYPE(CV_32F, tensor_channels), + const_cast(raw_input_data)); + + std::vector channels(4); + cv::split(tensor_mat, channels); + for (auto ch : channels) + ch = (ch + 1) * 127.5; + + cv::merge(channels, tensor_mat); + + cv::convertScaleAbs(tensor_mat, tensor_mat); + + // Send out image as CPU packet. 
+ std::shared_ptr image_frame = std::make_shared( + ImageFormat::SRGB, output_width, output_height); + std::unique_ptr output_image = absl::make_unique(image_frame); + auto output_mat = formats::MatView(output_image.get()); + // Upsample image into output. + cv::resize(tensor_mat, *output_mat, + cv::Size(output_width, output_height)); + cc->Outputs().Tag(kImageTag).Add(output_image.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 783359faa..21f983894 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -13,7 +13,6 @@ // limitations under the License. #include -#include #include "absl/strings/str_format.h" #include "absl/types/span.h" @@ -21,10 +20,8 @@ #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/image_opencv.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port.h" -#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/statusor.h" #include "mediapipe/gpu/gpu_origin.pb.h" @@ -36,7 +33,12 @@ #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU + +#if !MEDIAPIPE_DISABLE_OPENCV +#include "mediapipe/framework/formats/image_opencv.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#endif // !MEDIAPIPE_DISABLE_OPENCV #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #include "tensorflow/lite/delegates/gpu/gl/converters/util.h" @@ -44,7 +46,7 @@ #include 
"tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #if MEDIAPIPE_METAL_ENABLED #import @@ -53,378 +55,347 @@ #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" -#endif // MEDIAPIPE_METAL_ENABLED +#endif // MEDIAPIPE_METAL_ENABLED -namespace -{ - constexpr int kWorkgroupSize = 8; // Block size for GPU shader. - enum - { - ATTRIB_VERTEX, - ATTRIB_TEXTURE_POSITION, - NUM_ATTRIBUTES - }; +namespace { +constexpr int kWorkgroupSize = 8; // Block size for GPU shader. +enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; - // Commonly used to compute the number of blocks to launch in a kernel. - int NumGroups(const int size, const int group_size) - { // NOLINT - return (size + group_size - 1) / group_size; - } +// Commonly used to compute the number of blocks to launch in a kernel. +int NumGroups(const int size, const int group_size) { // NOLINT + return (size + group_size - 1) / group_size; +} - bool CanUseGpu() - { +bool CanUseGpu() { #if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED - // TODO: Configure GPU usage policy in individual calculators. - constexpr bool kAllowGpuProcessing = true; - return kAllowGpuProcessing; + // TODO: Configure GPU usage policy in individual calculators. 
+ constexpr bool kAllowGpuProcessing = true; + return kAllowGpuProcessing; #else - return false; -#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED + return false; +#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED +} + +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; +constexpr char kMaskTag[] = "MASK"; + +absl::StatusOr> GetHwcFromDims( + const std::vector& dims) { + if (dims.size() == 3) { + return std::make_tuple(dims[0], dims[1], dims[2]); + } else if (dims.size() == 4) { + // BHWC format check B == 1 + RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; + return std::make_tuple(dims[1], dims[2], dims[3]); + } else { + RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); } +} +} // namespace - constexpr char kTensorsTag[] = "TENSORS"; - constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; - constexpr char kMaskTag[] = "MASK"; - - absl::StatusOr> GetHwcFromDims( - const std::vector &dims) - { - if (dims.size() == 3) - { - return std::make_tuple(dims[0], dims[1], dims[2]); - } - else if (dims.size() == 4) - { - // BHWC format check B == 1 - RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap"; - return std::make_tuple(dims[1], dims[2], dims[3]); - } - else - { - RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size(); - } - } -} // namespace - -namespace mediapipe -{ +namespace mediapipe { #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - using ::tflite::gpu::gl::GlProgram; - using ::tflite::gpu::gl::GlShader; -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +using ::tflite::gpu::gl::GlProgram; +using ::tflite::gpu::gl::GlShader; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - // Converts Tensors from a tflite segmentation model to an image mask. - // - // Performs optional upscale to OUTPUT_SIZE dimensions if provided, - // otherwise the mask is the same size as input tensor. 
- // - // If at least one input tensor is already on GPU, processing happens on GPU and - // the output mask is also stored on GPU. Otherwise, processing and the output - // mask are both on CPU. - // - // On GPU, the mask is an RGBA image, in both the R & A channels, scaled 0-1. - // On CPU, the mask is a ImageFormat::VEC32F1 image, with values scaled 0-1. - // - // - // Inputs: - // One of the following TENSORS tags: - // TENSORS: Vector of Tensor, - // The tensor dimensions are specified in this calculator's options. - // OUTPUT_SIZE(optional): std::pair, - // If provided, the size to upscale mask to. - // - // Output: - // MASK: An Image output mask, RGBA(GPU) / VEC32F1(CPU). - // - // Options: - // See tensors_to_segmentation_calculator.proto - // - // Usage example: - // node { - // calculator: "TensorsToSegmentationCalculator" - // input_stream: "TENSORS:tensors" - // input_stream: "OUTPUT_SIZE:size" - // output_stream: "MASK:hair_mask" - // node_options: { - // [mediapipe.TensorsToSegmentationCalculatorOptions] { - // output_layer_index: 1 - // # gpu_origin: CONVENTIONAL # or TOP_LEFT - // } - // } - // } - // - // TODO Refactor and add support for other backends/platforms. - // - class TensorsToSegmentationCalculator : public CalculatorBase - { - public: - static absl::Status GetContract(CalculatorContract *cc); +// Converts Tensors from a tflite segmentation model to an image mask. +// +// Performs optional upscale to OUTPUT_SIZE dimensions if provided, +// otherwise the mask is the same size as input tensor. +// +// If at least one input tensor is already on GPU, processing happens on GPU and +// the output mask is also stored on GPU. Otherwise, processing and the output +// mask are both on CPU. +// +// On GPU, the mask is an RGBA image, in both the R & A channels, scaled 0-1. +// On CPU, the mask is a ImageFormat::VEC32F1 image, with values scaled 0-1. 
+// +// +// Inputs: +// One of the following TENSORS tags: +// TENSORS: Vector of Tensor, +// The tensor dimensions are specified in this calculator's options. +// OUTPUT_SIZE(optional): std::pair, +// If provided, the size to upscale mask to. +// +// Output: +// MASK: An Image output mask, RGBA(GPU) / VEC32F1(CPU). +// +// Options: +// See tensors_to_segmentation_calculator.proto +// +// Usage example: +// node { +// calculator: "TensorsToSegmentationCalculator" +// input_stream: "TENSORS:tensors" +// input_stream: "OUTPUT_SIZE:size" +// output_stream: "MASK:hair_mask" +// node_options: { +// [mediapipe.TensorsToSegmentationCalculatorOptions] { +// output_layer_index: 1 +// # gpu_origin: CONVENTIONAL # or TOP_LEFT +// } +// } +// } +// +// TODO Refactor and add support for other backends/platforms. +// +class TensorsToSegmentationCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); - absl::Status Open(CalculatorContext *cc) override; - absl::Status Process(CalculatorContext *cc) override; - absl::Status Close(CalculatorContext *cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; - private: - absl::Status LoadOptions(CalculatorContext *cc); - absl::Status InitGpu(CalculatorContext *cc); - absl::Status ProcessGpu(CalculatorContext *cc); - absl::Status ProcessCpu(CalculatorContext *cc); - void GlRender(); + private: + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status ProcessGpu(CalculatorContext* cc); + absl::Status ProcessCpu(CalculatorContext* cc); + void GlRender(); - bool DoesGpuTextureStartAtBottom() - { - return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; - } - - template - absl::Status ApplyActivation(cv::Mat &tensor_mat, cv::Mat *small_mask_mat); - - ::mediapipe::TensorsToSegmentationCalculatorOptions 
options_; - -#if !MEDIAPIPE_DISABLE_GPU - mediapipe::GlCalculatorHelper gpu_helper_; - GLuint upsample_program_; -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - std::unique_ptr mask_program_31_; -#else - GLuint mask_program_20_; -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 -#if MEDIAPIPE_METAL_ENABLED - MPPMetalHelper *metal_helper_ = nullptr; - id mask_program_; -#endif // MEDIAPIPE_METAL_ENABLED -#endif // !MEDIAPIPE_DISABLE_GPU - }; - REGISTER_CALCULATOR(TensorsToSegmentationCalculator); - - // static - absl::Status TensorsToSegmentationCalculator::GetContract( - CalculatorContract *cc) - { - RET_CHECK(!cc->Inputs().GetTags().empty()); - RET_CHECK(!cc->Outputs().GetTags().empty()); - - // Inputs. - cc->Inputs().Tag(kTensorsTag).Set>(); - if (cc->Inputs().HasTag(kOutputSizeTag)) - { - cc->Inputs().Tag(kOutputSizeTag).Set>(); - } - - // Outputs. - cc->Outputs().Tag(kMaskTag).Set(); - - if (CanUseGpu()) - { -#if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#if MEDIAPIPE_METAL_ENABLED - MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); -#endif // MEDIAPIPE_METAL_ENABLED -#endif // !MEDIAPIPE_DISABLE_GPU - } - - return absl::OkStatus(); - } - - absl::Status TensorsToSegmentationCalculator::Open(CalculatorContext *cc) - { - cc->SetOffset(TimestampDiff(0)); - bool use_gpu = false; - - if (CanUseGpu()) - { -#if !MEDIAPIPE_DISABLE_GPU - use_gpu = true; - MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#if MEDIAPIPE_METAL_ENABLED - metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; - RET_CHECK(metal_helper_); -#endif // MEDIAPIPE_METAL_ENABLED -#endif // !MEDIAPIPE_DISABLE_GPU - } - - MP_RETURN_IF_ERROR(LoadOptions(cc)); - - if (use_gpu) - { -#if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(InitGpu(cc)); -#else - RET_CHECK_FAIL() << "GPU processing disabled."; -#endif // !MEDIAPIPE_DISABLE_GPU - } - - return absl::OkStatus(); - } - - absl::Status 
TensorsToSegmentationCalculator::Process(CalculatorContext *cc) - { - if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) - { - return absl::OkStatus(); - } - - const auto &input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - - bool use_gpu = false; - if (CanUseGpu()) - { - // Use GPU processing only if at least one input tensor is already on GPU. - for (const auto &tensor : input_tensors) - { - if (tensor.ready_on_gpu()) - { - use_gpu = true; - break; - } - } - } - - // Validate tensor channels and activation type. - /*{ - RET_CHECK(!input_tensors.empty()); - ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); - int tensor_channels = std::get<2>(hwc); - typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; - switch (options_.activation()) { - case Options::NONE: - RET_CHECK_EQ(tensor_channels, 1); - break; - case Options::SIGMOID: - RET_CHECK_EQ(tensor_channels, 1); - break; - case Options::SOFTMAX: - RET_CHECK_EQ(tensor_channels, 2); - break; - } - } - */ - /* if (use_gpu) - { -#if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status - { - MP_RETURN_IF_ERROR(ProcessGpu(cc)); - return absl::OkStatus(); })); -#else - RET_CHECK_FAIL() << "GPU processing disabled."; -#endif // !MEDIAPIPE_DISABLE_GPU - } - else - { */ - MP_RETURN_IF_ERROR(ProcessCpu(cc)); - //} - - return absl::OkStatus(); - } - - absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext *cc) - { - -#if !MEDIAPIPE_DISABLE_GPU - gpu_helper_.RunInGlContext([this] - { - if (upsample_program_) - glDeleteProgram(upsample_program_); - upsample_program_ = 0; -#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - mask_program_31_.reset(); -#else - if (mask_program_20_) - glDeleteProgram(mask_program_20_); - mask_program_20_ = 0; -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 -#if MEDIAPIPE_METAL_ENABLED - mask_program_ = nil; -#endif // MEDIAPIPE_METAL_ENABLED - }); -#endif // 
!MEDIAPIPE_DISABLE_GPU - - return absl::OkStatus(); - } - - absl::Status TensorsToSegmentationCalculator::ProcessCpu( - CalculatorContext *cc) - { - // Get input streams, and dimensions. - const auto &input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); - auto [tensor_height, tensor_width, tensor_channels] = hwc; - int output_width = tensor_width, output_height = tensor_height; - if (cc->Inputs().HasTag(kOutputSizeTag)) - { - const auto &size = - cc->Inputs().Tag(kOutputSizeTag).Get>(); - output_width = size.first; - output_height = size.second; - } - - // Create initial working mask. - cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); - - // Wrap input tensor. - auto raw_input_tensor = &input_tensors[0]; - auto raw_input_view = raw_input_tensor->GetCpuReadView(); - const float *raw_input_data = raw_input_view.buffer(); - cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), - CV_MAKETYPE(CV_32F, tensor_channels), - const_cast(raw_input_data)); - - // std::cout << tensor_mat.channels() << std::endl; - std::vector channels(4); - cv::split(tensor_mat, channels); - for (auto ch : channels) - ch = (ch + 1) * 127.5; - - cv::merge(channels, tensor_mat); - - cv::convertScaleAbs(tensor_mat, tensor_mat); - // std::cout << "R (numpy) = " << std::endl << cv::format(tensor_mat, cv::Formatter::FMT_NUMPY ) << std::endl << std::endl; - - // Send out image as CPU packet. - std::shared_ptr mask_frame = std::make_shared( - ImageFormat::SRGB, output_width, output_height); - std::unique_ptr output_mask = absl::make_unique(mask_frame); - auto output_mat = formats::MatView(output_mask.get()); - // Upsample small mask into output. 
- cv::resize(tensor_mat, *output_mat, - cv::Size(output_width, output_height)); - cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); - - return absl::OkStatus(); + bool DoesGpuTextureStartAtBottom() { + return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; } +#if !MEDIAPIPE_DISABLE_OPENCV template - absl::Status TensorsToSegmentationCalculator::ApplyActivation( - cv::Mat &tensor_mat, cv::Mat *small_mask_mat) + absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); +#endif // !MEDIAPIPE_DISABLE_OPENCV + ::mediapipe::TensorsToSegmentationCalculatorOptions options_; + +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; + GLuint upsample_program_; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + std::unique_ptr mask_program_31_; +#else + GLuint mask_program_20_; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#if MEDIAPIPE_METAL_ENABLED + MPPMetalHelper* metal_helper_ = nullptr; + id mask_program_; +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(TensorsToSegmentationCalculator); + +// static +absl::Status TensorsToSegmentationCalculator::GetContract( + CalculatorContract* cc) { + RET_CHECK(!cc->Inputs().GetTags().empty()); + RET_CHECK(!cc->Outputs().GetTags().empty()); + + // Inputs. + cc->Inputs().Tag(kTensorsTag).Set>(); + if (cc->Inputs().HasTag(kOutputSizeTag)) { + cc->Inputs().Tag(kOutputSizeTag).Set>(); + } + + // Outputs. 
+ cc->Outputs().Tag(kMaskTag).Set(); + + if (CanUseGpu()) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#if MEDIAPIPE_METAL_ENABLED + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + bool use_gpu = false; + + if (CanUseGpu()) { +#if !MEDIAPIPE_DISABLE_GPU + use_gpu = true; + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#if MEDIAPIPE_METAL_ENABLED + metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(metal_helper_); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + } + + MP_RETURN_IF_ERROR(LoadOptions(cc)); + + if (use_gpu) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(InitGpu(cc)); +#else + RET_CHECK_FAIL() << "GPU processing disabled."; +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) { + if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto& input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + + bool use_gpu = false; + if (CanUseGpu()) { + // Use GPU processing only if at least one input tensor is already on GPU. + for (const auto& tensor : input_tensors) { + if (tensor.ready_on_gpu()) { + use_gpu = true; + break; + } + } + } + + // Validate tensor channels and activation type. { - // Configure activation function. 
- const int output_layer_index = options_.output_layer_index(); + RET_CHECK(!input_tensors.empty()); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + int tensor_channels = std::get<2>(hwc); typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; - const auto activation_fn = [&](const cv::Vec2f &mask_value) - { - float new_mask_value = 0; - // TODO consider moving switch out of the loop, - // and also avoid float/Vec2f casting. - switch (options_.activation()) - { + switch (options_.activation()) { case Options::NONE: - { + RET_CHECK_EQ(tensor_channels, 1); + break; + case Options::SIGMOID: + RET_CHECK_EQ(tensor_channels, 1); + break; + case Options::SOFTMAX: + RET_CHECK_EQ(tensor_channels, 2); + break; + } + } + + if (use_gpu) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + MP_RETURN_IF_ERROR(ProcessGpu(cc)); + return absl::OkStatus(); + })); +#else + RET_CHECK_FAIL() << "GPU processing disabled."; +#endif // !MEDIAPIPE_DISABLE_GPU + } else { +#if !MEDIAPIPE_DISABLE_OPENCV + MP_RETURN_IF_ERROR(ProcessCpu(cc)); +#else + RET_CHECK_FAIL() << "OpenCV processing disabled."; +#endif // !MEDIAPIPE_DISABLE_OPENCV + } + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + gpu_helper_.RunInGlContext([this] { + if (upsample_program_) glDeleteProgram(upsample_program_); + upsample_program_ = 0; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + mask_program_31_.reset(); +#else + if (mask_program_20_) glDeleteProgram(mask_program_20_); + mask_program_20_ = 0; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#if MEDIAPIPE_METAL_ENABLED + mask_program_ = nil; +#endif // MEDIAPIPE_METAL_ENABLED + }); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status TensorsToSegmentationCalculator::ProcessCpu( + CalculatorContext* cc) { +#if 
!MEDIAPIPE_DISABLE_OPENCV + // Get input streams, and dimensions. + const auto& input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) { + const auto& size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + + // Create initial working mask. + cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); + + // Wrap input tensor. + auto raw_input_tensor = &input_tensors[0]; + auto raw_input_view = raw_input_tensor->GetCpuReadView(); + const float* raw_input_data = raw_input_view.buffer(); + cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), + CV_MAKETYPE(CV_32F, tensor_channels), + const_cast(raw_input_data)); + + // Process mask tensor and apply activation function. + if (tensor_channels == 2) { + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else if (tensor_channels == 1) { + RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX != + options_.activation()); // Requires 2 channels. + if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE == + options_.activation()) // Pass-through optimization. + tensor_mat.copyTo(small_mask_mat); + else + MP_RETURN_IF_ERROR(ApplyActivation(tensor_mat, &small_mask_mat)); + } else { + RET_CHECK_FAIL() << "Unsupported number of tensor channels " + << tensor_channels; + } + + // Send out image as CPU packet. + std::shared_ptr mask_frame = std::make_shared( + ImageFormat::VEC32F1, output_width, output_height); + std::unique_ptr output_mask = absl::make_unique(mask_frame); + auto output_mat = formats::MatView(output_mask.get()); + // Upsample small mask into output. 
+ cv::resize(small_mask_mat, *output_mat, + cv::Size(output_width, output_height)); + cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); +#endif // !MEDIAPIPE_DISABLE_OPENCV + + return absl::OkStatus(); +} + +#if !MEDIAPIPE_DISABLE_OPENCV +template +absl::Status TensorsToSegmentationCalculator::ApplyActivation( + cv::Mat& tensor_mat, cv::Mat* small_mask_mat) { + // Configure activation function. + const int output_layer_index = options_.output_layer_index(); + typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + const auto activation_fn = [&](const cv::Vec2f& mask_value) { + float new_mask_value = 0; + // TODO consider moving switch out of the loop, + // and also avoid float/Vec2f casting. + switch (options_.activation()) { + case Options::NONE: { new_mask_value = mask_value[0]; break; } - case Options::SIGMOID: - { + case Options::SIGMOID: { const float pixel0 = mask_value[0]; new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0); break; } - case Options::SOFTMAX: - { + case Options::SOFTMAX: { const float pixel0 = mask_value[0]; const float pixel1 = mask_value[1]; const float max_pixel = std::max(pixel0, pixel1); @@ -436,147 +407,225 @@ namespace mediapipe softmax_denom; break; } - } - return new_mask_value; - }; + } + return new_mask_value; + }; - // Process mask tensor. - for (int i = 0; i < tensor_mat.rows; ++i) - { - for (int j = 0; j < tensor_mat.cols; ++j) - { - const T &input_pix = tensor_mat.at(i, j); - const float mask_value = activation_fn(input_pix); - small_mask_mat->at(i, j) = mask_value; - } + // Process mask tensor. + for (int i = 0; i < tensor_mat.rows; ++i) { + for (int j = 0; j < tensor_mat.cols; ++j) { + const T& input_pix = tensor_mat.at(i, j); + const float mask_value = activation_fn(input_pix); + small_mask_mat->at(i, j) = mask_value; + } + } + + return absl::OkStatus(); +} +#endif // !MEDIAPIPE_DISABLE_OPENCV + +// Steps: +// 1. receive tensor +// 2. process segmentation tensor into small mask +// 3. 
upsample small mask into output mask to be same size as input image +absl::Status TensorsToSegmentationCalculator::ProcessGpu( + CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU + // Get input streams, and dimensions. + const auto& input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) { + const auto& size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + + // Create initial working mask texture. +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + tflite::gpu::gl::GlTexture small_mask_texture; +#else + mediapipe::GlTexture small_mask_texture; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + // Run shader, process mask tensor. +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + { + MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture( + tflite::gpu::DataType::UINT8, // GL_RGBA8 + {tensor_width, tensor_height}, &small_mask_texture)); + + const int output_index = 0; + glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, + GL_WRITE_ONLY, GL_RGBA8); + + auto read_view = input_tensors[0].GetOpenGlBufferReadView(); + glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name()); + + const tflite::gpu::uint3 workgroups = { + NumGroups(tensor_width, kWorkgroupSize), + NumGroups(tensor_height, kWorkgroupSize), 1}; + + glUseProgram(mask_program_31_->id()); + glUniform2i(glGetUniformLocation(mask_program_31_->id(), "out_size"), + tensor_width, tensor_height); + + MP_RETURN_IF_ERROR(mask_program_31_->Dispatch(workgroups)); + } +#elif MEDIAPIPE_METAL_ENABLED + { + id command_buffer = [metal_helper_ commandBuffer]; + command_buffer.label = @"SegmentationKernel"; + id command_encoder = + [command_buffer 
computeCommandEncoder]; + [command_encoder setComputePipelineState:mask_program_]; + + auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer); + [command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0]; + + mediapipe::GpuBuffer small_mask_buffer = [metal_helper_ + mediapipeGpuBufferWithWidth:tensor_width + height:tensor_height + format:mediapipe::GpuBufferFormat::kBGRA32]; + id small_mask_texture_metal = + [metal_helper_ metalTextureWithGpuBuffer:small_mask_buffer]; + [command_encoder setTexture:small_mask_texture_metal atIndex:1]; + + unsigned int out_size[] = {static_cast(tensor_width), + static_cast(tensor_height)}; + [command_encoder setBytes:&out_size length:sizeof(out_size) atIndex:2]; + + MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1); + MTLSize threadgroups = + MTLSizeMake(NumGroups(tensor_width, kWorkgroupSize), + NumGroups(tensor_height, kWorkgroupSize), 1); + [command_encoder dispatchThreadgroups:threadgroups + threadsPerThreadgroup:threads_per_group]; + [command_encoder endEncoding]; + [command_buffer commit]; + + small_mask_texture = gpu_helper_.CreateSourceTexture(small_mask_buffer); + } +#else + { + small_mask_texture = gpu_helper_.CreateDestinationTexture( + tensor_width, tensor_height, + mediapipe::GpuBufferFormat::kBGRA32); // actually GL_RGBA8 + + // Go through CPU if not already texture 2D (no direct conversion yet). + // Tensor::GetOpenGlTexture2dReadView() doesn't automatically convert types. 
+ if (!input_tensors[0].ready_as_opengl_texture_2d()) { + (void)input_tensors[0].GetCpuReadView(); } - return absl::OkStatus(); + auto read_view = input_tensors[0].GetOpenGlTexture2dReadView(); + + gpu_helper_.BindFramebuffer(small_mask_texture); + glActiveTexture(GL_TEXTURE1); + glBindTexture(GL_TEXTURE_2D, read_view.name()); + glUseProgram(mask_program_20_); + GlRender(); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + } +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + + // Upsample small mask into output. + mediapipe::GlTexture output_texture = gpu_helper_.CreateDestinationTexture( + output_width, output_height, + mediapipe::GpuBufferFormat::kBGRA32); // actually GL_RGBA8 + + // Run shader, upsample result. + { + gpu_helper_.BindFramebuffer(output_texture); + glActiveTexture(GL_TEXTURE1); +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + glBindTexture(GL_TEXTURE_2D, small_mask_texture.id()); +#else + glBindTexture(GL_TEXTURE_2D, small_mask_texture.name()); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + glUseProgram(upsample_program_); + GlRender(); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); } - // Steps: - // 1. receive tensor - // 2. process segmentation tensor into small mask - // 3. upsample small mask into output mask to be same size as input image - absl::Status TensorsToSegmentationCalculator::ProcessGpu( - CalculatorContext *cc) - { + // Send out image as GPU packet. 
+ auto output_image = output_texture.GetFrame(); + cc->Outputs().Tag(kMaskTag).Add(output_image.release(), cc->InputTimestamp()); + + // Cleanup + output_texture.Release(); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +void TensorsToSegmentationCalculator::GlRender() { #if !MEDIAPIPE_DISABLE_GPU + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; - // Get input streams, and dimensions. - const auto &input_tensors = - cc->Inputs().Tag(kTensorsTag).Get>(); - ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); - auto [tensor_height, tensor_width, tensor_channels] = hwc; - int output_width = tensor_width, output_height = tensor_height; - if (cc->Inputs().HasTag(kOutputSizeTag)) - { - const auto &size = - cc->Inputs().Tag(kOutputSizeTag).Get>(); - output_width = size.first; - output_height = size.second; - } + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); - // Wrap input tensor. 
- auto raw_input_tensor = &input_tensors[0]; - auto raw_input_view = raw_input_tensor->GetCpuReadView(); - const float *raw_input_data = raw_input_view.buffer(); - cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), - CV_MAKETYPE(CV_32F, tensor_channels), - const_cast(raw_input_data)); + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); - // std::cout << tensor_mat.channels() << std::endl; - std::vector channels(4); - cv::split(tensor_mat, channels); - for (auto ch : channels) - ch = (ch + 1) * 127.5; + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); - cv::merge(channels, tensor_mat); + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); - cv::convertScaleAbs(tensor_mat, tensor_mat); - // std::cout << "R (numpy) = " << std::endl << cv::format(tensor_mat, cv::Formatter::FMT_NUMPY ) << std::endl << std::endl; + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); +#endif // !MEDIAPIPE_DISABLE_GPU +} - // Send out image as CPU packet. - std::shared_ptr mask_frame = std::make_shared( - ImageFormat::SRGB, output_width, output_height); - std::unique_ptr output_mask = absl::make_unique(mask_frame); - auto output_mat = formats::MatView(output_mask.get()); - // Upsample small mask into output. 
- cv::resize(tensor_mat, *output_mat, - cv::Size(output_width, output_height)); - cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); +absl::Status TensorsToSegmentationCalculator::LoadOptions( + CalculatorContext* cc) { + // Get calculator options specified in the graph. + options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); -#endif // !MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); +} - return absl::OkStatus(); - } - - void TensorsToSegmentationCalculator::GlRender() - { +absl::Status TensorsToSegmentationCalculator::InitGpu(CalculatorContext* cc) { #if !MEDIAPIPE_DISABLE_GPU - static const GLfloat square_vertices[] = { - -1.0f, -1.0f, // bottom left - 1.0f, -1.0f, // bottom right - -1.0f, 1.0f, // top left - 1.0f, 1.0f, // top right - }; - static const GLfloat texture_vertices[] = { - 0.0f, 0.0f, // bottom left - 1.0f, 0.0f, // bottom right - 0.0f, 1.0f, // top left - 1.0f, 1.0f, // top right - }; - - // vertex storage - GLuint vbo[2]; - glGenBuffers(2, vbo); - GLuint vao; - glGenVertexArrays(1, &vao); - glBindVertexArray(vao); - - // vbo 0 - glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); - glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, - GL_STATIC_DRAW); - glEnableVertexAttribArray(ATTRIB_VERTEX); - glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); - - // vbo 1 - glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); - glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, - GL_STATIC_DRAW); - glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); - glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); - - // draw - glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); - - // cleanup - glDisableVertexAttribArray(ATTRIB_VERTEX); - glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); - glBindBuffer(GL_ARRAY_BUFFER, 0); - glBindVertexArray(0); - glDeleteVertexArrays(1, &vao); - glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU - } - - absl::Status 
TensorsToSegmentationCalculator::LoadOptions( - CalculatorContext *cc) - { - // Get calculator options specified in the graph. - options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); - - return absl::OkStatus(); - } - - absl::Status TensorsToSegmentationCalculator::InitGpu(CalculatorContext *cc) - { -#if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status - { - // A shader to process a segmentation tensor into an output mask. - // Currently uses 4 channels for output, and sets R+A channels as mask value. + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { + // A shader to process a segmentation tensor into an output mask. + // Currently uses 4 channels for output, and sets R+A channels as mask value. #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // GLES 3.1 const tflite::gpu::uint3 workgroup_size = {kWorkgroupSize, kWorkgroupSize, @@ -757,7 +806,7 @@ void main() { vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value); fragColor = out_value; })"; -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Shader defines. typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; @@ -793,7 +842,7 @@ void main() { "texture_coordinate", }; - // Main shader program & parameters + // Main shader program & parameters #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 GlShader shader_without_previous; MP_RETURN_IF_ERROR(GlShader::CompileShader( @@ -825,7 +874,7 @@ void main() { RET_CHECK(mask_program_20_) << "Problem initializing the program."; glUseProgram(mask_program_20_); glUniform1i(glGetUniformLocation(mask_program_20_, "input_texture"), 1); -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Simple pass-through program, used for hardware upsampling. 
mediapipe::GlhCreateProgram( @@ -835,10 +884,11 @@ void main() { glUseProgram(upsample_program_); glUniform1i(glGetUniformLocation(upsample_program_, "video_frame"), 1); - return absl::OkStatus(); })); -#endif // !MEDIAPIPE_DISABLE_GPU - return absl::OkStatus(); - } + })); +#endif // !MEDIAPIPE_DISABLE_GPU -} // namespace mediapipe + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/graphs/beauty/beauty_desktop_cpu.pbtxt b/mediapipe/graphs/beauty/beauty_desktop_cpu.pbtxt index c68bf8df9..ca671f07f 100644 --- a/mediapipe/graphs/beauty/beauty_desktop_cpu.pbtxt +++ b/mediapipe/graphs/beauty/beauty_desktop_cpu.pbtxt @@ -14,7 +14,7 @@ profiler_config { trace_enabled: true enable_profiler: true trace_log_interval_count: 200 - trace_log_path: "/Users/alena/Workdir/mediapipe/logs/beauty/" + trace_log_path: "/home/mslight/Work/clone/mediapipe/mediapipe/logs/beauty/" } # Throttles the images flowing downstream for flow control. It passes through diff --git a/mediapipe/graphs/deformation/calculators/face_processor_calculator.cc b/mediapipe/graphs/deformation/calculators/face_processor_calculator.cc index 57768dbf1..0bc747363 100644 --- a/mediapipe/graphs/deformation/calculators/face_processor_calculator.cc +++ b/mediapipe/graphs/deformation/calculators/face_processor_calculator.cc @@ -19,7 +19,6 @@ #include #include #include -//#include #include #include "Tensor.h" diff --git a/mediapipe/graphs/deformation/config/BUILD b/mediapipe/graphs/deformation/config/BUILD index 3ece89a85..88e7f025f 100644 --- a/mediapipe/graphs/deformation/config/BUILD +++ b/mediapipe/graphs/deformation/config/BUILD @@ -16,18 +16,6 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) -load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") - -encode_binary_proto( - name = "triangles", - input = "triangles.pbtxt", - message_type = "mediapipe.face_geometry.Mesh3d", - output = "triangles.binarypb", - deps = [ - 
"//mediapipe/modules/face_geometry/protos:mesh_3d_proto", - ], -) - exports_files( srcs = glob(["**"]), ) diff --git a/mediapipe/graphs/image_style/BUILD b/mediapipe/graphs/image_style/BUILD index 9515ca573..c7350c9b3 100644 --- a/mediapipe/graphs/image_style/BUILD +++ b/mediapipe/graphs/image_style/BUILD @@ -29,14 +29,16 @@ cc_library( "//mediapipe/calculators/tensor:tensor_converter_calculator", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/image:image_transformation_calculator", - "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator", "//mediapipe/calculators/util:to_image_calculator", "//mediapipe/calculators/util:from_image_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/modules/face_landmark:face_landmark_front_gpu", + "//mediapipe/calculators/image_style:apply_mask_calculator", "//mediapipe/calculators/image_style:fast_utils_calculator", "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", "//mediapipe/calculators/core:constant_side_packet_calculator", + "//mediapipe/calculators/tensor:tensors_to_image_calculator", ], ) @@ -48,11 +50,12 @@ cc_library( "//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:tensor_converter_calculator", - "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator", + "//mediapipe/calculators/tensor:tensors_to_image_calculator", "//mediapipe/calculators/util:to_image_calculator", "//mediapipe/calculators/util:from_image_calculator", "//mediapipe/modules/face_landmark:face_landmark_front_cpu", "//mediapipe/calculators/image_style:fast_utils_calculator", + "//mediapipe/calculators/image_style:apply_mask_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/core:constant_side_packet_calculator", ], diff --git 
a/mediapipe/graphs/image_style/image_style_cpu.pbtxt b/mediapipe/graphs/image_style/image_style_cpu.pbtxt index 6dcb48f9e..e9ec9f806 100644 --- a/mediapipe/graphs/image_style/image_style_cpu.pbtxt +++ b/mediapipe/graphs/image_style/image_style_cpu.pbtxt @@ -17,6 +17,12 @@ node { output_stream: "throttled_input_video" } +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:input_video" + output_stream: "SIZE:original_size" +} + # Defines side packets for further use in the graph. node { @@ -44,7 +50,14 @@ node { calculator: "FastUtilsCalculator" input_stream: "NORM_LANDMARKS:multi_face_landmarks" input_stream: "IMAGE:throttled_input_video" + input_stream: "SIZE:original_size" output_stream: "IMAGE:out_image_frame" + output_stream: "LM_MASK:lm_mask" + options { + [mediapipe.FastUtilsCalculatorOptions.ext] { + back_to_image: false + } + } } node: { @@ -83,26 +96,91 @@ node { } node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE:transformed_input_video" - output_stream: "SIZE:input_size" + calculator: "TensorsToImageCalculator" + input_stream: "TENSORS:output_tensor" + output_stream: "IMAGE:fake_image" +} + + +node{ + calculator: "FromImageCalculator" + input_stream: "IMAGE:fake_image" + output_stream: "IMAGE_CPU:fake_image2" +} + + +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:input_video" + output_stream: "IMAGE:transformed_input_img" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + } + } } node { - calculator: "TensorsToSegmentationCalculator" - input_stream: "TENSORS:output_tensor" - input_stream: "OUTPUT_SIZE:input_size" - output_stream: "MASK:output" + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:transformed_input_img" + output_stream: "TENSORS:input_tensor_img" + options: { + [mediapipe.TensorConverterCalculatorOptions.ext] { + zero_center: true + } + } +} + +node { + calculator: 
"InferenceCalculator" + input_stream: "TENSORS:input_tensor_img" + output_stream: "TENSORS:output_tensor_img" options: { - [mediapipe.TensorsToSegmentationCalculatorOptions.ext] { - activation: NONE + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/models/model_float32.tflite" + delegate { xnnpack {} } } - } + } +} + + +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE:transformed_input_img" + output_stream: "SIZE:input_size_img" +} + +node { + calculator: "TensorsToImageCalculator" + input_stream: "TENSORS:output_tensor_img" + input_stream: "OUTPUT_SIZE:input_size_img" + output_stream: "IMAGE:fake_bg2" } node{ calculator: "FromImageCalculator" - input_stream: "IMAGE:output" - output_stream: "IMAGE_CPU:output_video" + input_stream: "IMAGE:fake_bg2" + output_stream: "IMAGE_CPU:fake_bg" } +node { + calculator: "FastUtilsCalculator" + input_stream: "NORM_LANDMARKS:multi_face_landmarks" + input_stream: "IMAGE:fake_image2" + input_stream: "SIZE:original_size" + output_stream: "IMAGE:back_image" + options { + [mediapipe.FastUtilsCalculatorOptions.ext] { + back_to_image: true + } + } +} + +node { + calculator: "ApplyMaskCalculator" + input_stream: "IMAGE:back_image" + input_stream: "FAKE_BG:fake_bg" + input_stream: "LM_MASK:lm_mask" + output_stream: "IMAGE:output_video" +} diff --git a/mediapipe/graphs/image_style/image_style_gpu.pbtxt b/mediapipe/graphs/image_style/image_style_gpu.pbtxt index d97da1d1c..7e5a1fb40 100644 --- a/mediapipe/graphs/image_style/image_style_gpu.pbtxt +++ b/mediapipe/graphs/image_style/image_style_gpu.pbtxt @@ -36,6 +36,12 @@ node { output_stream: "throttled_input_video_cpu" } +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_CPU:throttled_input_video_cpu" + output_stream: "SIZE:original_size" +} + # Subgraph that detects faces and corresponding landmarks. 
node { calculator: "FaceLandmarkFrontGpu" @@ -48,26 +54,33 @@ node { calculator: "FastUtilsCalculator" input_stream: "NORM_LANDMARKS:multi_face_landmarks" input_stream: "IMAGE:throttled_input_video_cpu" + input_stream: "SIZE:original_size" output_stream: "IMAGE:out_image_frame" + output_stream: "LM_MASK:lm_mask" + options { + [mediapipe.FastUtilsCalculatorOptions.ext] { + back_to_image: false + } + } } -#node: { -# calculator: "ImageTransformationCalculator" -# input_stream: "IMAGE:out_image_frame" -# output_stream: "IMAGE:out_image_frame1" -# node_options: { -# [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { -# output_width: 256 -# output_height: 256 -# } -# } -#} +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:out_image_frame" + output_stream: "IMAGE:image_frame" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + } + } +} node { calculator: "TensorConverterCalculator" - input_stream: "IMAGE:out_image_frame" - output_stream: "TENSORS:input_tensors" + input_stream: "IMAGE:image_frame" + output_stream: "TENSORS:input_tensor" options: { [mediapipe.TensorConverterCalculatorOptions.ext] { zero_center: true @@ -75,35 +88,109 @@ node { } } -#node { -# calculator: "InferenceCalculator" -# input_stream: "TENSORS:input_tensors" -# output_stream: "TENSORS:output_tensors" -# options: { -# [mediapipe.InferenceCalculatorOptions.ext] { -# model_path:"mediapipe/models/model_float32.tflite" -# delegate { gpu {} } -# } -# } -#} - -# Processes the output tensors into a segmentation mask that has the same size -# as the input image into the graph. 
node { - calculator: "TensorsToSegmentationCalculator" - input_stream: "TENSORS:input_tensors" - #input_stream: "OUTPUT_SIZE:input_size" - output_stream: "MASK:mask_image" + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensor" + output_stream: "TENSORS:output_tensor" options: { - [mediapipe.TensorsToSegmentationCalculatorOptions.ext] { - activation: NONE - gpu_origin: TOP_LEFT + [mediapipe.InferenceCalculatorOptions.ext] { + model_path:"mediapipe/models/model_float32.tflite" + delegate { gpu {} } } - } + } +} + +node { + calculator: "TensorsToImageCalculator" + input_stream: "TENSORS:output_tensor" + output_stream: "IMAGE:fake_image" } node: { calculator: "FromImageCalculator" - input_stream: "IMAGE:mask_image" - output_stream: "IMAGE_GPU:output_video" + input_stream: "IMAGE:fake_image" + output_stream: "IMAGE_CPU:cpu_fake_image" +} + +node: { + calculator: "ImageTransformationCalculator" + input_stream: "IMAGE:throttled_input_video_cpu" + output_stream: "IMAGE:transformed_input_img" + node_options: { + [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { + output_width: 256 + output_height: 256 + } + } +} + +node { + calculator: "TensorConverterCalculator" + input_stream: "IMAGE:transformed_input_img" + output_stream: "TENSORS:input_tensor_img" + options: { + [mediapipe.TensorConverterCalculatorOptions.ext] { + zero_center: true + } + } +} + +node { + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensor_img" + output_stream: "TENSORS:output_tensor_img" + options: { + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/models/model_float32.tflite" + delegate { xnnpack {} } + } + } +} + + +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_CPU:transformed_input_img" + output_stream: "SIZE:input_size_img" +} + +node { + calculator: "TensorsToImageCalculator" + input_stream: "TENSORS:output_tensor_img" + input_stream: "OUTPUT_SIZE:input_size_img" + output_stream: 
"IMAGE:fake_bg2" +} + +node{ + calculator: "FromImageCalculator" + input_stream: "IMAGE:fake_bg2" + output_stream: "IMAGE_CPU:fake_bg" +} + +node { + calculator: "FastUtilsCalculator" + input_stream: "NORM_LANDMARKS:multi_face_landmarks" + input_stream: "IMAGE:cpu_fake_image" + input_stream: "SIZE:original_size" + output_stream: "IMAGE:back_image" + options { + [mediapipe.FastUtilsCalculatorOptions.ext] { + back_to_image: true + } + } +} + +node { + calculator: "ApplyMaskCalculator" + input_stream: "IMAGE:back_image" + input_stream: "FAKE_BG:fake_bg" + input_stream: "LM_MASK:lm_mask" + output_stream: "IMAGE:out_image" +} + +# Defines side packets for further use in the graph. +node { + calculator: "ImageFrameToGpuBufferCalculator" + input_stream: "out_image" + output_stream: "output_video" }