diff --git a/mediapipe/calculators/image_style/BUILD b/mediapipe/calculators/image_style/BUILD new file mode 100644 index 000000000..06ff1b455 --- /dev/null +++ b/mediapipe/calculators/image_style/BUILD @@ -0,0 +1,50 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "fast_utils_calculator", + srcs = ["fast_utils_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/util:color_cc_proto", + "@com_google_absl//absl/strings", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:video_stream_header", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + "//mediapipe/util:annotation_renderer", + "//mediapipe/util:render_data_cc_proto", + ], + alwayslink = 1, +) + + + + + + diff --git a/mediapipe/calculators/image_style/fast_utils_calculator.cc b/mediapipe/calculators/image_style/fast_utils_calculator.cc new file mode 100644 index 000000000..929b87c0f --- /dev/null +++ b/mediapipe/calculators/image_style/fast_utils_calculator.cc @@ -0,0 +1,399 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_options.pb.h"
+#include "mediapipe/framework/formats/image_format.pb.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/formats/video_stream_header.h"
+#include "mediapipe/framework/port/logging.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/port/status.h"
+#include "mediapipe/framework/port/vector.h"
+#include "mediapipe/util/annotation_renderer.h"
+#include "mediapipe/util/color.pb.h"
+#include "mediapipe/util/render_data.pb.h"
+
+namespace mediapipe
+{
+  namespace
+  {
+    // FFHQ-style face-alignment template: mean landmark positions (eye
+    // centers, nose tip, mouth corners) as fractions of the 1024px canvas.
+    static const std::vector<cv::Point2f> FFHQ_NORM_LM = {
+        {638.68525475 / 1024, 486.24604922 / 1024},
+        {389.31496114 / 1024, 485.8921848 / 1024},
+        {513.67979275 / 1024, 620.8915371 / 1024},
+        {405.50932642 / 1024, 756.52797927 / 1024},
+        {622.55630397 / 1024, 756.15509499 / 1024}};
+
+    constexpr char kImageFrameTag[] = "IMAGE";
+    constexpr char kVectorTag[] = "VECTOR";
+
+    // Converts a normalized value pair to pixel coordinates, clamped to the
+    // image bounds.
+    std::tuple<int, int> _normalized_to_pixel_coordinates(
+        float normalized_x, float normalized_y, int image_width,
+        int image_height)
+    {
+      int x_px = std::min(static_cast<int>(floor(normalized_x * image_width)),
+                          image_width - 1);
+      int y_px = std::min(static_cast<int>(floor(normalized_y * image_height)),
+                          image_height - 1);
+
+      return {x_px, y_px};
+    }
+
+    // Edge list of the face-oval contour in the MediaPipe face mesh topology.
+    static const std::vector<std::pair<int, int>> FACEMESH_FACE_OVAL =
+        {{10, 338}, {338, 297}, {297, 332}, {332, 284}, {284, 251}, {251, 389}, {389, 356}, {356, 454}, {454, 323}, {323, 361}, {361, 288}, {288, 397}, {397, 365}, {365, 379}, {379, 378}, {378, 400}, {400, 377}, {377, 152}, {152, 148}, {148, 176}, {176, 149}, {149, 150}, {150, 136}, {136, 172}, {172, 58}, {58, 132}, {132, 93}, {93, 234}, {234, 127}, {127, 162}, {162, 21}, {21, 54}, {54, 103}, {103, 67}, {67, 109}, {109, 10}};
+
+    enum
+    {
+      ATTRIB_VERTEX,
+      ATTRIB_TEXTURE_POSITION,
+      NUM_ATTRIBUTES
+    };
+
+    // Round up n to next multiple of m.
+    size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
+    inline bool HasImageTag(mediapipe::CalculatorContext *cc) { return false; }
+
+    using Point = RenderAnnotation::Point;
+
+    bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y,
+                                      int image_width, int image_height,
+                                      int *x_px, int *y_px)
+    {
+      CHECK(x_px != nullptr);
+      CHECK(y_px != nullptr);
+      CHECK_GT(image_width, 0);
+      CHECK_GT(image_height, 0);
+
+      if (normalized_x < 0 || normalized_x > 1.0 || normalized_y < 0 ||
+          normalized_y > 1.0)
+      {
+        VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0";
+      }
+
+      *x_px = static_cast<int>(round(normalized_x * image_width));
+      *y_px = static_cast<int>(round(normalized_y * image_height));
+
+      return true;
+    }
+  } // namespace
+
+  class FastUtilsCalculator : public CalculatorBase
+  {
+  public:
+    FastUtilsCalculator() = default;
+    ~FastUtilsCalculator() override = default;
+
+    static absl::Status GetContract(CalculatorContract *cc);
+
+    // From Calculator.
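+    // Open() runs once at graph start, Process() once per input packet,
+    // and Close() once after the final packet.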
+    absl::Status Open(CalculatorContext *cc) override;
+    absl::Status Process(CalculatorContext *cc) override;
+    absl::Status Close(CalculatorContext *cc) override;
+
+  private:
+    absl::Status CreateRenderTargetCpu(CalculatorContext *cc,
+                                       std::unique_ptr<cv::Mat> &image_mat,
+                                       ImageFormat::Format *target_format);
+
+    absl::Status RenderToCpu(
+        CalculatorContext *cc, const ImageFormat::Format &target_format,
+        uchar *data_image, std::unique_ptr<cv::Mat> &image_mat);
+
+    absl::Status Call(CalculatorContext *cc,
+                      std::unique_ptr<cv::Mat> &image_mat,
+                      ImageFormat::Format *target_format,
+                      const RenderData &render_data,
+                      std::unordered_map<std::string, cv::Mat> &all_masks);
+
+    // Indicates if image frame is available as input.
+    bool image_frame_available_ = false;
+    // Face-mesh landmark indices that make up each named facial region.
+    std::unordered_map<std::string, std::vector<int>> index_dict = {
+        {"leftEye", {384, 385, 386, 387, 388, 390, 263, 362, 398, 466, 373, 374, 249, 380, 381, 382}},
+        {"rightEye", {160, 33, 161, 163, 133, 7, 173, 144, 145, 246, 153, 154, 155, 157, 158, 159}},
+        {"nose", {4}},
+        {"lips", {0, 13, 14, 17, 84}},
+        {"leftLips", {61, 146}},
+        {"rightLips", {291, 375}},
+    };
+
+    int width_ = 0;
+    int height_ = 0;
+    int width_canvas_ = 0; // Size of overlay drawing texture canvas.
+    int height_canvas_ = 0;
+
+    int max_num_faces = 1;
+    bool refine_landmarks = true;
+    double min_detection_confidence = 0.5;
+    double min_tracking_confidence = 0.5;
+  };
+  REGISTER_CALCULATOR(FastUtilsCalculator);
+
+  absl::Status FastUtilsCalculator::GetContract(CalculatorContract *cc)
+  {
+    CHECK_GE(cc->Inputs().NumEntries(), 1);
+
+    if (cc->Inputs().HasTag(kImageFrameTag))
+    {
+      cc->Inputs().Tag(kImageFrameTag).Set<ImageFrame>();
+      CHECK(cc->Outputs().HasTag(kImageFrameTag));
+    }
+
+    if (cc->Outputs().HasTag(kImageFrameTag))
+    {
+      cc->Outputs().Tag(kImageFrameTag).Set<ImageFrame>();
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status FastUtilsCalculator::Open(CalculatorContext *cc)
+  {
+    cc->SetOffset(TimestampDiff(0));
+
+    if (cc->Inputs().HasTag(kImageFrameTag) || HasImageTag(cc))
+    {
+      image_frame_available_ = true;
+    }
+
+    // Set the output header based on the input header (if present).
+    const char *tag = kImageFrameTag;
+    if (image_frame_available_ && !cc->Inputs().Tag(tag).Header().IsEmpty())
+    {
+      const auto &input_header =
+          cc->Inputs().Tag(tag).Header().Get<VideoHeader>();
+      auto *output_video_header = new VideoHeader(input_header);
+      cc->Outputs().Tag(tag).SetHeader(Adopt(output_video_header));
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status FastUtilsCalculator::Process(CalculatorContext *cc)
+  {
+    if (cc->Inputs().HasTag(kImageFrameTag) &&
+        cc->Inputs().Tag(kImageFrameTag).IsEmpty())
+    {
+      return absl::OkStatus();
+    }
+
+    // Initialize render target, drawn with OpenCV.
+    std::unique_ptr<cv::Mat> image_mat;
+    ImageFormat::Format target_format;
+    std::unordered_map<std::string, cv::Mat> all_masks;
+
+    if (cc->Outputs().HasTag(kImageFrameTag))
+    {
+      MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format));
+    }
+
+    // Render streams onto render target.
+    for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId();
+         ++id)
+    {
+      auto tag_and_index = cc->Inputs().TagAndIndexFromId(id);
+      std::string tag = tag_and_index.first;
+      if (!tag.empty() && tag != kVectorTag)
+      {
+        continue;
+      }
+      if (cc->Inputs().Get(id).IsEmpty())
+      {
+        continue;
+      }
+      if (tag.empty())
+      {
+        // Empty tag defaults to accepting a single object of RenderData type.
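+        // (A VECTOR-tagged input instead carries std::vector<RenderData>;
+        // see the else-branch below.)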
+        const RenderData &render_data = cc->Inputs().Get(id).Get<RenderData>();
+        MP_RETURN_IF_ERROR(
+            Call(cc, image_mat, &target_format, render_data, all_masks));
+      }
+      else
+      {
+        RET_CHECK_EQ(kVectorTag, tag);
+        const std::vector<RenderData> &render_data_vec =
+            cc->Inputs().Get(id).Get<std::vector<RenderData>>();
+        for (const RenderData &render_data : render_data_vec)
+        {
+          MP_RETURN_IF_ERROR(
+              Call(cc, image_mat, &target_format, render_data, all_masks));
+        }
+      }
+    }
+
+    // Copy the rendered image to output.
+    uchar *image_mat_ptr = image_mat->data;
+    MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr, image_mat));
+
+    return absl::OkStatus();
+  }
+
+  absl::Status FastUtilsCalculator::Close(CalculatorContext *cc)
+  {
+    return absl::OkStatus();
+  }
+
+  absl::Status FastUtilsCalculator::RenderToCpu(
+      CalculatorContext *cc, const ImageFormat::Format &target_format,
+      uchar *data_image, std::unique_ptr<cv::Mat> &image_mat)
+  {
+    cv::Mat mat_image_ = *image_mat.get();
+
+    auto output_frame = absl::make_unique<ImageFrame>(
+        target_format, mat_image_.cols, mat_image_.rows);
+
+    output_frame->CopyPixelData(target_format, mat_image_.cols, mat_image_.rows,
+                                data_image,
+                                ImageFrame::kDefaultAlignmentBoundary);
+
+    if (cc->Outputs().HasTag(kImageFrameTag))
+    {
+      cc->Outputs()
+          .Tag(kImageFrameTag)
+          .Add(output_frame.release(), cc->InputTimestamp());
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status FastUtilsCalculator::CreateRenderTargetCpu(
+      CalculatorContext *cc, std::unique_ptr<cv::Mat> &image_mat,
+      ImageFormat::Format *target_format)
+  {
+    if (image_frame_available_)
+    {
+      const auto &input_frame =
+          cc->Inputs().Tag(kImageFrameTag).Get<ImageFrame>();
+
+      int target_mat_type;
+      switch (input_frame.Format())
+      {
+      case ImageFormat::SRGBA:
+        *target_format = ImageFormat::SRGBA;
+        target_mat_type = CV_8UC4;
+        break;
+      case ImageFormat::SRGB:
+        *target_format = ImageFormat::SRGB;
+        target_mat_type = CV_8UC3;
+        break;
+      case ImageFormat::GRAY8:
+        *target_format = ImageFormat::SRGB;
+        target_mat_type = CV_8UC3;
+        break;
+      default:
+        return absl::UnknownError("Unexpected image frame format.");
+      }
+
+      image_mat = absl::make_unique<cv::Mat>(
+          input_frame.Height(), input_frame.Width(), target_mat_type);
+
+      auto input_mat = formats::MatView(&input_frame);
+
+      if (input_frame.Format() == ImageFormat::GRAY8)
+      {
+        cv::Mat rgb_mat;
+        cv::cvtColor(input_mat, rgb_mat, CV_GRAY2RGB);
+        rgb_mat.copyTo(*image_mat);
+      }
+      else
+      {
+        input_mat.copyTo(*image_mat);
+      }
+    }
+    else
+    {
+      // No input frame available: draw on a small blank white canvas.
+      image_mat = absl::make_unique<cv::Mat>(
+          150, 150, CV_8UC4, cv::Scalar(255, 255, 255));
+      *target_format = ImageFormat::SRGBA;
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status FastUtilsCalculator::Call(
+      CalculatorContext *cc, std::unique_ptr<cv::Mat> &image_mat,
+      ImageFormat::Format *target_format, const RenderData &render_data,
+      std::unordered_map<std::string, cv::Mat> &all_masks)
+  {
+    cv::Mat mat_image_ = *image_mat.get();
+
+    int image_width_ = image_mat->cols;
+    int image_height_ = image_mat->rows;
+
+    cv::Mat mask;
+    std::vector<cv::Point2f> kps, landmarks;
+    std::vector<std::vector<cv::Point2f>> lms_out;
+
+    int c = 0;
+
+    // For each facial region, average the pixel positions of its landmarks
+    // to produce one representative keypoint.
+    for (const auto &[key, value] : index_dict)
+    {
+      for (auto order : value)
+      {
+        c = 0;
+        for (auto &annotation : render_data.render_annotations())
+        {
+          if (annotation.data_case() == RenderAnnotation::kPoint)
+          {
+            if (order == c)
+            {
+              const auto &point = annotation.point();
+              int x = -1;
+              int y = -1;
+              CHECK(NormalizedtoPixelCoordinates(point.x(), point.y(),
+                                                 image_width_, image_height_,
+                                                 &x, &y));
+              kps.push_back(cv::Point(x, y));
+            }
+            c += 1;
+          }
+        }
+      }
+      double sumx = 0, sumy = 0, meanx, meany;
+
+      for (auto p : kps)
+      {
+        sumx += p.x;
+        sumy += p.y;
+      }
+      meanx = sumx / kps.size();
+      meany = sumy / kps.size();
+
+      landmarks.push_back({meanx, meany});
+
+      kps.clear();
+    }
+
+    lms_out.push_back(landmarks);
+
+    return absl::OkStatus();
+  }
+
+} // namespace mediapipe
diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc
index a03a60189..27c18e352 100644
--- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include <vector>
+#include <chrono>
+#include <iostream>
 
 #include "absl/strings/str_format.h"
 #include "absl/types/span.h"
@@ -35,7 +37,7 @@
 #include "mediapipe/gpu/gl_simple_shaders.h"
 #include "mediapipe/gpu/gpu_buffer.h"
 #include "mediapipe/gpu/shader_util.h"
-#endif  // !MEDIAPIPE_DISABLE_GPU
+#endif // !MEDIAPIPE_DISABLE_GPU
 
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 #include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
@@ -43,7 +45,7 @@
 #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
 #include "tensorflow/lite/delegates/gpu/gl/gl_texture.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 
 #if MEDIAPIPE_METAL_ENABLED
 #import <CoreVideo/CoreVideo.h>
@@ -52,339 +54,385 @@
 
 #import "mediapipe/gpu/MPPMetalHelper.h"
 #include "mediapipe/gpu/MPPMetalUtil.h"
-#endif  // MEDIAPIPE_METAL_ENABLED
+#endif // MEDIAPIPE_METAL_ENABLED
 
-namespace {
-constexpr int kWorkgroupSize = 8;  // Block size for GPU shader.
-enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
+namespace
+{
+  constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
+  enum
+  {
+    ATTRIB_VERTEX,
+    ATTRIB_TEXTURE_POSITION,
+    NUM_ATTRIBUTES
+  };
 
-// Commonly used to compute the number of blocks to launch in a kernel.
-int NumGroups(const int size, const int group_size) {  // NOLINT
-  return (size + group_size - 1) / group_size;
-}
+  // Wall-clock markers used to time the calculator between Open() and Close().
+  std::chrono::steady_clock::time_point begin;
+  std::chrono::steady_clock::time_point end;
 
-bool CanUseGpu() {
+  // Commonly used to compute the number of blocks to launch in a kernel.
+  int NumGroups(const int size, const int group_size)
+  { // NOLINT
+    return (size + group_size - 1) / group_size;
+  }
+
+  bool CanUseGpu()
+  {
 #if !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
-  // TODO: Configure GPU usage policy in individual calculators.
-  constexpr bool kAllowGpuProcessing = true;
-  return kAllowGpuProcessing;
+    // TODO: Configure GPU usage policy in individual calculators.
+    constexpr bool kAllowGpuProcessing = true;
+    return kAllowGpuProcessing;
 #else
-  return false;
-#endif  // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
-}
-
-constexpr char kTensorsTag[] = "TENSORS";
-constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
-constexpr char kMaskTag[] = "MASK";
-
-absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
-    const std::vector<int>& dims) {
-  if (dims.size() == 3) {
-    return std::make_tuple(dims[0], dims[1], dims[2]);
-  } else if (dims.size() == 4) {
-    // BHWC format check B == 1
-    RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap";
-    return std::make_tuple(dims[1], dims[2], dims[3]);
-  } else {
-    RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size();
+    return false;
+#endif // !MEDIAPIPE_DISABLE_GPU || MEDIAPIPE_METAL_ENABLED
   }
-}
-}  // namespace
 
-namespace mediapipe {
+  constexpr char kTensorsTag[] = "TENSORS";
+  constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
+  constexpr char kMaskTag[] = "MASK";
+
+  absl::StatusOr<std::tuple<int, int, int>> GetHwcFromDims(
+      const std::vector<int> &dims)
+  {
+    if (dims.size() == 3)
+    {
+      return std::make_tuple(dims[0], dims[1], dims[2]);
+    }
+    else if (dims.size() == 4)
+    {
+      // BHWC format check B == 1
+      RET_CHECK_EQ(1, dims[0]) << "Expected batch to be 1 for BHWC heatmap";
+      return std::make_tuple(dims[1], dims[2], dims[3]);
+    }
+    else
+    {
+      RET_CHECK(false) << "Invalid shape for segmentation tensor " << dims.size();
+    }
+  }
+} // namespace
+
+namespace mediapipe
+{
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-using ::tflite::gpu::gl::GlProgram;
-using ::tflite::gpu::gl::GlShader;
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+  using ::tflite::gpu::gl::GlProgram;
+  using ::tflite::gpu::gl::GlShader;
+#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 
-// Converts Tensors from a tflite segmentation model to an image mask.
-//
-// Performs optional upscale to OUTPUT_SIZE dimensions if provided,
-// otherwise the mask is the same size as input tensor.
-//
-// If at least one input tensor is already on GPU, processing happens on GPU and
-// the output mask is also stored on GPU. Otherwise, processing and the output
-// mask are both on CPU.
-//
-// On GPU, the mask is an RGBA image, in both the R & A channels, scaled 0-1.
-// On CPU, the mask is an ImageFormat::VEC32F1 image, with values scaled 0-1.
-//
-// Inputs:
-//   One of the following TENSORS tags:
-//   TENSORS: Vector of Tensor,
-//            The tensor dimensions are specified in this calculator's options.
-//   OUTPUT_SIZE(optional): std::pair<int, int>,
-//                          If provided, the size to upscale mask to.
-//
-// Output:
-//   MASK: An Image output mask, RGBA(GPU) / VEC32F1(CPU).
-//
-// Options:
-//   See tensors_to_segmentation_calculator.proto
-//
-// Usage example:
-// node {
-//   calculator: "TensorsToSegmentationCalculator"
-//   input_stream: "TENSORS:tensors"
-//   input_stream: "OUTPUT_SIZE:size"
-//   output_stream: "MASK:hair_mask"
-//   node_options: {
-//     [mediapipe.TensorsToSegmentationCalculatorOptions] {
-//       output_layer_index: 1
-//       # gpu_origin: CONVENTIONAL # or TOP_LEFT
-//     }
-//   }
-// }
-//
-// TODO Refactor and add support for other backends/platforms.
-//
-class TensorsToSegmentationCalculator : public CalculatorBase {
- public:
-  static absl::Status GetContract(CalculatorContract* cc);
+  // Converts Tensors from a tflite segmentation model to an image mask.
+  //
+  // Performs optional upscale to OUTPUT_SIZE dimensions if provided,
+  // otherwise the mask is the same size as input tensor.
+  //
+  // If at least one input tensor is already on GPU, processing happens on GPU
+  // and the output mask is also stored on GPU. Otherwise, processing and the
+  // output mask are both on CPU.
+  //
+  // On GPU, the mask is an RGBA image, in both the R & A channels, scaled 0-1.
+  // On CPU, the mask is an ImageFormat::VEC32F1 image, with values scaled 0-1.
+  //
+  // Inputs:
+  //   One of the following TENSORS tags:
+  //   TENSORS: Vector of Tensor,
+  //            The tensor dimensions are specified in this calculator's options.
+  //   OUTPUT_SIZE(optional): std::pair<int, int>,
+  //                          If provided, the size to upscale mask to.
+  //
+  // Output:
+  //   MASK: An Image output mask, RGBA(GPU) / VEC32F1(CPU).
+  //
+  // Options:
+  //   See tensors_to_segmentation_calculator.proto
+  //
+  // Usage example:
+  //  node {
+  //    calculator: "TensorsToSegmentationCalculator"
+  //    input_stream: "TENSORS:tensors"
+  //    input_stream: "OUTPUT_SIZE:size"
+  //    output_stream: "MASK:hair_mask"
+  //    node_options: {
+  //      [mediapipe.TensorsToSegmentationCalculatorOptions] {
+  //        output_layer_index: 1
+  //        # gpu_origin: CONVENTIONAL # or TOP_LEFT
+  //      }
+  //    }
+  //  }
+  //
+  // TODO Refactor and add support for other backends/platforms.
+  //
+  class TensorsToSegmentationCalculator : public CalculatorBase
+  {
+  public:
+    static absl::Status GetContract(CalculatorContract *cc);
 
-  absl::Status Open(CalculatorContext* cc) override;
-  absl::Status Process(CalculatorContext* cc) override;
-  absl::Status Close(CalculatorContext* cc) override;
+    absl::Status Open(CalculatorContext *cc) override;
+    absl::Status Process(CalculatorContext *cc) override;
+    absl::Status Close(CalculatorContext *cc) override;
 
- private:
-  absl::Status LoadOptions(CalculatorContext* cc);
-  absl::Status InitGpu(CalculatorContext* cc);
-  absl::Status ProcessGpu(CalculatorContext* cc);
-  absl::Status ProcessCpu(CalculatorContext* cc);
-  void GlRender();
+  private:
+    absl::Status LoadOptions(CalculatorContext *cc);
+    absl::Status InitGpu(CalculatorContext *cc);
+    absl::Status ProcessGpu(CalculatorContext *cc);
+    absl::Status ProcessCpu(CalculatorContext *cc);
+    void GlRender();
 
-  bool DoesGpuTextureStartAtBottom() {
-    return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
-  }
+    bool DoesGpuTextureStartAtBottom()
+    {
+      return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
+    }
 
-  template <class T>
-  absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat);
+    template <class T>
+    absl::Status ApplyActivation(cv::Mat &tensor_mat, cv::Mat *small_mask_mat);
 
-  ::mediapipe::TensorsToSegmentationCalculatorOptions options_;
+    ::mediapipe::TensorsToSegmentationCalculatorOptions options_;
 #if !MEDIAPIPE_DISABLE_GPU
-  mediapipe::GlCalculatorHelper gpu_helper_;
-  GLuint upsample_program_;
+    mediapipe::GlCalculatorHelper gpu_helper_;
+    GLuint upsample_program_;
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-  std::unique_ptr<GlProgram> mask_program_31_;
+    std::unique_ptr<GlProgram> mask_program_31_;
 #else
-  GLuint mask_program_20_;
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+    GLuint mask_program_20_;
+#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 #if MEDIAPIPE_METAL_ENABLED
-  MPPMetalHelper* metal_helper_ = nullptr;
-  id<MTLComputePipelineState> mask_program_;
-#endif  // MEDIAPIPE_METAL_ENABLED
-#endif  // !MEDIAPIPE_DISABLE_GPU
-};
-REGISTER_CALCULATOR(TensorsToSegmentationCalculator);
+    MPPMetalHelper *metal_helper_ = nullptr;
+    id<MTLComputePipelineState> mask_program_;
+#endif // MEDIAPIPE_METAL_ENABLED
+#endif // !MEDIAPIPE_DISABLE_GPU
+  };
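+  // Registers the calculator so a CalculatorGraphConfig can reference it by
+  // its class name.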
+  REGISTER_CALCULATOR(TensorsToSegmentationCalculator);
 
-// static
-absl::Status TensorsToSegmentationCalculator::GetContract(
-    CalculatorContract* cc) {
-  RET_CHECK(!cc->Inputs().GetTags().empty());
-  RET_CHECK(!cc->Outputs().GetTags().empty());
+  // static
+  absl::Status TensorsToSegmentationCalculator::GetContract(
+      CalculatorContract *cc)
+  {
+    RET_CHECK(!cc->Inputs().GetTags().empty());
+    RET_CHECK(!cc->Outputs().GetTags().empty());
 
-  // Inputs.
-  cc->Inputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
-  if (cc->Inputs().HasTag(kOutputSizeTag)) {
-    cc->Inputs().Tag(kOutputSizeTag).Set<std::pair<int, int>>();
-  }
+    // Inputs.
+    cc->Inputs().Tag(kTensorsTag).Set<std::vector<Tensor>>();
+    if (cc->Inputs().HasTag(kOutputSizeTag))
+    {
+      cc->Inputs().Tag(kOutputSizeTag).Set<std::pair<int, int>>();
+    }
 
-  // Outputs.
-  cc->Outputs().Tag(kMaskTag).Set<Image>();
+    // Outputs.
+    cc->Outputs().Tag(kMaskTag).Set<Image>();
 
-  if (CanUseGpu()) {
+    if (CanUseGpu())
+    {
 #if !MEDIAPIPE_DISABLE_GPU
-    MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
+      MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
 #if MEDIAPIPE_METAL_ENABLED
-    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
-#endif  // MEDIAPIPE_METAL_ENABLED
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  }
+      MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
+#endif // MEDIAPIPE_METAL_ENABLED
+#endif // !MEDIAPIPE_DISABLE_GPU
+    }
 
-  return absl::OkStatus();
-}
-
-absl::Status TensorsToSegmentationCalculator::Open(CalculatorContext* cc) {
-  cc->SetOffset(TimestampDiff(0));
-  bool use_gpu = false;
-
-  if (CanUseGpu()) {
-#if !MEDIAPIPE_DISABLE_GPU
-    use_gpu = true;
-    MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
-#if MEDIAPIPE_METAL_ENABLED
-    metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
-    RET_CHECK(metal_helper_);
-#endif  // MEDIAPIPE_METAL_ENABLED
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  }
-
-  MP_RETURN_IF_ERROR(LoadOptions(cc));
-
-  if (use_gpu) {
-#if !MEDIAPIPE_DISABLE_GPU
-    MP_RETURN_IF_ERROR(InitGpu(cc));
-#else
-    RET_CHECK_FAIL() << "GPU processing disabled.";
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  }
-
-  return absl::OkStatus();
-}
-
-absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
-  if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { return absl::OkStatus(); }
-  const auto& input_tensors =
-      cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
+    return absl::OkStatus();
+  }
 
-  bool use_gpu = false;
-  if (CanUseGpu()) {
-    // Use GPU processing only if at least one input tensor is already on GPU.
+  absl::Status TensorsToSegmentationCalculator::Open(CalculatorContext *cc)
+  {
+    cc->SetOffset(TimestampDiff(0));
+    bool use_gpu = false;
+    // Record the start time so Close() can report total wall-clock run time.
+    begin = std::chrono::steady_clock::now();
+    if (CanUseGpu())
+    {
+#if !MEDIAPIPE_DISABLE_GPU
+      use_gpu = true;
+      MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
+#if MEDIAPIPE_METAL_ENABLED
+      metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
+      RET_CHECK(metal_helper_);
+#endif // MEDIAPIPE_METAL_ENABLED
+#endif // !MEDIAPIPE_DISABLE_GPU
+    }
-    for (const auto& tensor : input_tensors) {
-      if (tensor.ready_on_gpu()) {
-        use_gpu = true;
-        break;
-      }
-    }
-  }
+    MP_RETURN_IF_ERROR(LoadOptions(cc));
+
+    if (use_gpu)
+    {
+#if !MEDIAPIPE_DISABLE_GPU
+      MP_RETURN_IF_ERROR(InitGpu(cc));
+#else
+      RET_CHECK_FAIL() << "GPU processing disabled.";
+#endif // !MEDIAPIPE_DISABLE_GPU
+    }
+
+    return absl::OkStatus();
+  }
+
+  absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext *cc)
+  {
+    if (cc->Inputs().Tag(kTensorsTag).IsEmpty())
+    {
+      return absl::OkStatus();
+    }
+
+    const auto &input_tensors =
+        cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
+
+    bool use_gpu = false;
+    if (CanUseGpu())
+    {
+      // Use GPU processing only if at least one input tensor is already on GPU.
+      for (const auto &tensor : input_tensors)
+      {
+        if (tensor.ready_on_gpu())
+        {
+          use_gpu = true;
+          break;
+        }
+      }
+    }
 
-  // Validate tensor channels and activation type.
-  {
-    RET_CHECK(!input_tensors.empty());
-    ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
-    int tensor_channels = std::get<2>(hwc);
-    typedef mediapipe::TensorsToSegmentationCalculatorOptions Options;
-    switch (options_.activation()) {
-      case Options::NONE:
-        RET_CHECK_EQ(tensor_channels, 1);
-        break;
-      case Options::SIGMOID:
-        RET_CHECK_EQ(tensor_channels, 1);
-        break;
-      case Options::SOFTMAX:
-        RET_CHECK_EQ(tensor_channels, 2);
-        break;
-    }
-  }
+    // Validate tensor channels and activation type.
+    /*{
+      RET_CHECK(!input_tensors.empty());
+      ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
+      int tensor_channels = std::get<2>(hwc);
+      typedef mediapipe::TensorsToSegmentationCalculatorOptions Options;
+      switch (options_.activation()) {
+        case Options::NONE:
+          RET_CHECK_EQ(tensor_channels, 1);
+          break;
+        case Options::SIGMOID:
+          RET_CHECK_EQ(tensor_channels, 1);
+          break;
+        case Options::SOFTMAX:
+          RET_CHECK_EQ(tensor_channels, 2);
+          break;
+      }
+    }
+    */
 
-  if (use_gpu) {
-#if !MEDIAPIPE_DISABLE_GPU
-    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status {
-      MP_RETURN_IF_ERROR(ProcessGpu(cc));
-      return absl::OkStatus();
-    }));
-#else
-    RET_CHECK_FAIL() << "GPU processing disabled.";
-#endif  // !MEDIAPIPE_DISABLE_GPU
-  } else {
-    MP_RETURN_IF_ERROR(ProcessCpu(cc));
-  }
-
-  return absl::OkStatus();
-}
-
-absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) {
-#if !MEDIAPIPE_DISABLE_GPU
-  gpu_helper_.RunInGlContext([this] {
-    if (upsample_program_) glDeleteProgram(upsample_program_);
-    upsample_program_ = 0;
-#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-    mask_program_31_.reset();
-#else
-    if (mask_program_20_) glDeleteProgram(mask_program_20_);
-    mask_program_20_ = 0;
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-#if MEDIAPIPE_METAL_ENABLED
-    mask_program_ = nil;
-#endif  // MEDIAPIPE_METAL_ENABLED
-  });
-#endif  // !MEDIAPIPE_DISABLE_GPU
-
-  return absl::OkStatus();
-}
-
-absl::Status TensorsToSegmentationCalculator::ProcessCpu(
-    CalculatorContext* cc) {
-  // Get input streams, and dimensions.
-  const auto& input_tensors =
-      cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
-  ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
-  auto [tensor_height, tensor_width, tensor_channels] = hwc;
-  int output_width = tensor_width, output_height = tensor_height;
-  if (cc->Inputs().HasTag(kOutputSizeTag)) {
-    const auto& size =
-        cc->Inputs().Tag(kOutputSizeTag).Get<std::pair<int, int>>();
-    output_width = size.first;
-    output_height = size.second;
-  }
-
-  // Create initial working mask.
-  cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1);
-
-  // Wrap input tensor.
-  auto raw_input_tensor = &input_tensors[0];
-  auto raw_input_view = raw_input_tensor->GetCpuReadView();
-  const float* raw_input_data = raw_input_view.buffer<float>();
-  cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height),
-                     CV_MAKETYPE(CV_32F, tensor_channels),
-                     const_cast<float*>(raw_input_data));
-
-  // Process mask tensor and apply activation function.
-  if (tensor_channels == 2) {
-    MP_RETURN_IF_ERROR(ApplyActivation<cv::Vec2f>(tensor_mat, &small_mask_mat));
-  } else if (tensor_channels == 1) {
-    RET_CHECK(mediapipe::TensorsToSegmentationCalculatorOptions::SOFTMAX !=
-              options_.activation());  // Requires 2 channels.
-    if (mediapipe::TensorsToSegmentationCalculatorOptions::NONE ==
-        options_.activation())  // Pass-through optimization.
-      tensor_mat.copyTo(small_mask_mat);
-    else
-      MP_RETURN_IF_ERROR(ApplyActivation<float>(tensor_mat, &small_mask_mat));
-  } else {
-    RET_CHECK_FAIL() << "Unsupported number of tensor channels "
-                     << tensor_channels;
-  }
-
-  // Send out image as CPU packet.
-  std::shared_ptr<ImageFrame> mask_frame = std::make_shared<ImageFrame>(
-      ImageFormat::VEC32F1, output_width, output_height);
-  std::unique_ptr<Image> output_mask = absl::make_unique<Image>(mask_frame);
-  auto output_mat = formats::MatView(output_mask.get());
-  // Upsample small mask into output.
-  cv::resize(small_mask_mat, *output_mat,
-             cv::Size(output_width, output_height));
-  cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp());
-
-  return absl::OkStatus();
-}
+    /* if (use_gpu)
+    {
+#if !MEDIAPIPE_DISABLE_GPU
+      MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status
+      {
+        MP_RETURN_IF_ERROR(ProcessGpu(cc));
+        return absl::OkStatus(); }));
+#else
- switch (options_.activation()) { - case Options::NONE: { +#if !MEDIAPIPE_DISABLE_GPU + gpu_helper_.RunInGlContext([this] + { + if (upsample_program_) + glDeleteProgram(upsample_program_); + upsample_program_ = 0; +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + mask_program_31_.reset(); +#else + if (mask_program_20_) + glDeleteProgram(mask_program_20_); + mask_program_20_ = 0; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#if MEDIAPIPE_METAL_ENABLED + mask_program_ = nil; +#endif // MEDIAPIPE_METAL_ENABLED + }); +#endif // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); + } + + absl::Status TensorsToSegmentationCalculator::ProcessCpu( + CalculatorContext *cc) + { + // Get input streams, and dimensions. + const auto &input_tensors = + cc->Inputs().Tag(kTensorsTag).Get>(); + ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims)); + auto [tensor_height, tensor_width, tensor_channels] = hwc; + int output_width = tensor_width, output_height = tensor_height; + if (cc->Inputs().HasTag(kOutputSizeTag)) + { + const auto &size = + cc->Inputs().Tag(kOutputSizeTag).Get>(); + output_width = size.first; + output_height = size.second; + } + + // Create initial working mask. + cv::Mat small_mask_mat(cv::Size(tensor_width, tensor_height), CV_32FC1); + + // Wrap input tensor. + auto raw_input_tensor = &input_tensors[0]; + auto raw_input_view = raw_input_tensor->GetCpuReadView(); + const float *raw_input_data = raw_input_view.buffer(); + cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height), + CV_MAKETYPE(CV_32F, tensor_channels), + const_cast(raw_input_data)); + + // std::cout << tensor_mat.channels() << std::endl; + std::vector channels(4); + cv::split(tensor_mat, channels); + for (auto ch : channels) + ch = (ch + 1) * 127.5; + + cv::merge(channels, tensor_mat); + + cv::convertScaleAbs(tensor_mat, tensor_mat); + // std::cout << "R (numpy) = " << std::endl << cv::format(tensor_mat, cv::Formatter::FMT_NUMPY ) << std::endl << std::endl; + + // Send out image as CPU packet. + std::shared_ptr mask_frame = std::make_shared( + ImageFormat::SRGB, output_width, output_height); + std::unique_ptr output_mask = absl::make_unique(mask_frame); + auto output_mat = formats::MatView(output_mask.get()); + // Upsample small mask into output. + cv::resize(tensor_mat, *output_mat, + cv::Size(output_width, output_height)); + cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } + + template + absl::Status TensorsToSegmentationCalculator::ApplyActivation( + cv::Mat &tensor_mat, cv::Mat *small_mask_mat) + { + // Configure activation function. + const int output_layer_index = options_.output_layer_index(); + typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; + const auto activation_fn = [&](const cv::Vec2f &mask_value) + { + float new_mask_value = 0; + // TODO consider moving switch out of the loop, + // and also avoid float/Vec2f casting. 
+      switch (options_.activation())
+      {
+      case Options::NONE:
+      {
         new_mask_value = mask_value[0];
         break;
       }
-      case Options::SIGMOID: {
+      case Options::SIGMOID:
+      {
         const float pixel0 = mask_value[0];
         new_mask_value = 1.0 / (std::exp(-pixel0) + 1.0);
         break;
       }
-      case Options::SOFTMAX: {
+      case Options::SOFTMAX:
+      {
         const float pixel0 = mask_value[0];
         const float pixel1 = mask_value[1];
         const float max_pixel = std::max(pixel0, pixel1);
@@ -396,224 +444,147 @@ absl::Status TensorsToSegmentationCalculator::ApplyActivation(
             softmax_denom;
         break;
       }
-    }
-    return new_mask_value;
-  };
+      }
+      return new_mask_value;
+    };
 
-  // Process mask tensor.
-  for (int i = 0; i < tensor_mat.rows; ++i) {
-    for (int j = 0; j < tensor_mat.cols; ++j) {
-      const T& input_pix = tensor_mat.at<T>(i, j);
-      const float mask_value = activation_fn(input_pix);
-      small_mask_mat->at<float>(i, j) = mask_value;
-    }
-  }
-
-  return absl::OkStatus();
-}
-
-// Steps:
-// 1. receive tensor
-// 2. process segmentation tensor into small mask
-// 3. upsample small mask into output mask to be same size as input image
-absl::Status TensorsToSegmentationCalculator::ProcessGpu(
-    CalculatorContext* cc) {
-#if !MEDIAPIPE_DISABLE_GPU
-  // Get input streams, and dimensions.
-  const auto& input_tensors =
-      cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
-  ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
-  auto [tensor_height, tensor_width, tensor_channels] = hwc;
-  int output_width = tensor_width, output_height = tensor_height;
-  if (cc->Inputs().HasTag(kOutputSizeTag)) {
-    const auto& size =
-        cc->Inputs().Tag(kOutputSizeTag).Get<std::pair<int, int>>();
-    output_width = size.first;
-    output_height = size.second;
-  }
-
-  // Create initial working mask texture.
-#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-  tflite::gpu::gl::GlTexture small_mask_texture;
-#else
-  mediapipe::GlTexture small_mask_texture;
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-
-  // Run shader, process mask tensor.
-#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-  {
-    MP_RETURN_IF_ERROR(CreateReadWriteRgbaImageTexture(
-        tflite::gpu::DataType::UINT8,  // GL_RGBA8
-        {tensor_width, tensor_height}, &small_mask_texture));
-
-    const int output_index = 0;
-    glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0,
-                       GL_WRITE_ONLY, GL_RGBA8);
-
-    auto read_view = input_tensors[0].GetOpenGlBufferReadView();
-    glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, read_view.name());
-
-    const tflite::gpu::uint3 workgroups = {
-        NumGroups(tensor_width, kWorkgroupSize),
-        NumGroups(tensor_height, kWorkgroupSize), 1};
-
-    glUseProgram(mask_program_31_->id());
-    glUniform2i(glGetUniformLocation(mask_program_31_->id(), "out_size"),
-                tensor_width, tensor_height);
-
-    MP_RETURN_IF_ERROR(mask_program_31_->Dispatch(workgroups));
-  }
-#elif MEDIAPIPE_METAL_ENABLED
-  {
-    id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
-    command_buffer.label = @"SegmentationKernel";
-    id<MTLComputeCommandEncoder> command_encoder =
-        [command_buffer computeCommandEncoder];
-    [command_encoder setComputePipelineState:mask_program_];
-
-    auto read_view = input_tensors[0].GetMtlBufferReadView(command_buffer);
-    [command_encoder setBuffer:read_view.buffer() offset:0 atIndex:0];
-
-    mediapipe::GpuBuffer small_mask_buffer = [metal_helper_
-        mediapipeGpuBufferWithWidth:tensor_width
-                             height:tensor_height
-                             format:mediapipe::GpuBufferFormat::kBGRA32];
-    id<MTLTexture> small_mask_texture_metal =
-        [metal_helper_ metalTextureWithGpuBuffer:small_mask_buffer];
-    [command_encoder setTexture:small_mask_texture_metal atIndex:1];
-
-    unsigned int out_size[] = {static_cast<unsigned int>(tensor_width),
-                               static_cast<unsigned int>(tensor_height)};
-    [command_encoder setBytes:&out_size length:sizeof(out_size) atIndex:2];
-
-    MTLSize threads_per_group = MTLSizeMake(kWorkgroupSize, kWorkgroupSize, 1);
-    MTLSize threadgroups =
-        MTLSizeMake(NumGroups(tensor_width, kWorkgroupSize),
-                    NumGroups(tensor_height, kWorkgroupSize), 1);
-    [command_encoder dispatchThreadgroups:threadgroups
-                    threadsPerThreadgroup:threads_per_group];
-    [command_encoder endEncoding];
-    [command_buffer commit];
-
-    small_mask_texture = gpu_helper_.CreateSourceTexture(small_mask_buffer);
-  }
-#else
-  {
-    small_mask_texture = gpu_helper_.CreateDestinationTexture(
-        tensor_width, tensor_height,
-        mediapipe::GpuBufferFormat::kBGRA32);  // actually GL_RGBA8
-
-    // Go through CPU if not already texture 2D (no direct conversion yet).
-    // Tensor::GetOpenGlTexture2dReadView() doesn't automatically convert types.
-    if (!input_tensors[0].ready_as_opengl_texture_2d()) {
-      (void)input_tensors[0].GetCpuReadView();
-    }
-
-    auto read_view = input_tensors[0].GetOpenGlTexture2dReadView();
-
-    gpu_helper_.BindFramebuffer(small_mask_texture);
-    glActiveTexture(GL_TEXTURE1);
-    glBindTexture(GL_TEXTURE_2D, read_view.name());
-    glUseProgram(mask_program_20_);
-    GlRender();
-    glBindTexture(GL_TEXTURE_2D, 0);
-    glFlush();
-  }
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
+    // Process mask tensor.
+    for (int i = 0; i < tensor_mat.rows; ++i)
+    {
+      for (int j = 0; j < tensor_mat.cols; ++j)
+      {
+        const T &input_pix = tensor_mat.at<T>(i, j);
+        const float mask_value = activation_fn(input_pix);
+        small_mask_mat->at<float>(i, j) = mask_value;
+      }
+    }
 
-  // Upsample small mask into output.
-  mediapipe::GlTexture output_texture = gpu_helper_.CreateDestinationTexture(
-      output_width, output_height,
-      mediapipe::GpuBufferFormat::kBGRA32);  // actually GL_RGBA8
-
-  // Run shader, upsample result.
-  {
-    gpu_helper_.BindFramebuffer(output_texture);
-    glActiveTexture(GL_TEXTURE1);
-#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-    glBindTexture(GL_TEXTURE_2D, small_mask_texture.id());
-#else
-    glBindTexture(GL_TEXTURE_2D, small_mask_texture.name());
-#endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
-    glUseProgram(upsample_program_);
-    GlRender();
-    glBindTexture(GL_TEXTURE_2D, 0);
-    glFlush();
-  }
+    return absl::OkStatus();
+  }
+
+  // Steps:
+  // 1. receive tensor
+  // 2. process segmentation tensor into small mask
+  // 3. upsample small mask into output mask to be same size as input image
+  absl::Status TensorsToSegmentationCalculator::ProcessGpu(
+      CalculatorContext *cc)
+  {
+#if !MEDIAPIPE_DISABLE_GPU
+
+    // Get input streams, and dimensions.
+    const auto &input_tensors =
+        cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
+    ASSIGN_OR_RETURN(auto hwc, GetHwcFromDims(input_tensors[0].shape().dims));
+    auto [tensor_height, tensor_width, tensor_channels] = hwc;
+    int output_width = tensor_width, output_height = tensor_height;
+    if (cc->Inputs().HasTag(kOutputSizeTag))
+    {
+      const auto &size =
+          cc->Inputs().Tag(kOutputSizeTag).Get<std::pair<int, int>>();
+      output_width = size.first;
+      output_height = size.second;
+    }
+
+    // Wrap input tensor.
+    auto raw_input_tensor = &input_tensors[0];
+    auto raw_input_view = raw_input_tensor->GetCpuReadView();
+    const float *raw_input_data = raw_input_view.buffer<float>();
+    cv::Mat tensor_mat(cv::Size(tensor_width, tensor_height),
+                       CV_MAKETYPE(CV_32F, tensor_channels),
+                       const_cast<float *>(raw_input_data));
+
+    // Map each channel from the model's [-1, 1] output range to [0, 255].
+    // std::cout << tensor_mat.channels() << std::endl;
+    std::vector<cv::Mat> channels(4);
+    cv::split(tensor_mat, channels);
+    for (auto &ch : channels)
+      ch = (ch + 1) * 127.5;
+
+    cv::merge(channels, tensor_mat);
+
+    cv::convertScaleAbs(tensor_mat, tensor_mat);
+    // std::cout << "R (numpy) = " << std::endl << cv::format(tensor_mat, cv::Formatter::FMT_NUMPY) << std::endl << std::endl;
+
+    // Send out image as CPU packet.
+    std::shared_ptr<ImageFrame> mask_frame = std::make_shared<ImageFrame>(
+        ImageFormat::SRGB, output_width, output_height);
+    std::unique_ptr<Image> output_mask = absl::make_unique<Image>(mask_frame);
+    auto output_mat = formats::MatView(output_mask.get());
+    // Upsample small mask into output.
+    cv::resize(tensor_mat, *output_mat,
+               cv::Size(output_width, output_height));
+    cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp());
+
+#endif // !MEDIAPIPE_DISABLE_GPU
+
+    return absl::OkStatus();
+  }
 
-  // Send out image as GPU packet.
- auto output_image = output_texture.GetFrame(); - cc->Outputs().Tag(kMaskTag).Add(output_image.release(), cc->InputTimestamp()); - - // Cleanup - output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU - - return absl::OkStatus(); -} - -void TensorsToSegmentationCalculator::GlRender() { + void TensorsToSegmentationCalculator::GlRender() + { #if !MEDIAPIPE_DISABLE_GPU - static const GLfloat square_vertices[] = { - -1.0f, -1.0f, // bottom left - 1.0f, -1.0f, // bottom right - -1.0f, 1.0f, // top left - 1.0f, 1.0f, // top right - }; - static const GLfloat texture_vertices[] = { - 0.0f, 0.0f, // bottom left - 1.0f, 0.0f, // bottom right - 0.0f, 1.0f, // top left - 1.0f, 1.0f, // top right - }; + static const GLfloat square_vertices[] = { + -1.0f, -1.0f, // bottom left + 1.0f, -1.0f, // bottom right + -1.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; + static const GLfloat texture_vertices[] = { + 0.0f, 0.0f, // bottom left + 1.0f, 0.0f, // bottom right + 0.0f, 1.0f, // top left + 1.0f, 1.0f, // top right + }; - // vertex storage - GLuint vbo[2]; - glGenBuffers(2, vbo); - GLuint vao; - glGenVertexArrays(1, &vao); - glBindVertexArray(vao); + // vertex storage + GLuint vbo[2]; + glGenBuffers(2, vbo); + GLuint vao; + glGenVertexArrays(1, &vao); + glBindVertexArray(vao); - // vbo 0 - glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); - glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, - GL_STATIC_DRAW); - glEnableVertexAttribArray(ATTRIB_VERTEX); - glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); + // vbo 0 + glBindBuffer(GL_ARRAY_BUFFER, vbo[0]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), square_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_VERTEX); + glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, nullptr); - // vbo 1 - glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); - glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, - GL_STATIC_DRAW); - glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); - glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); + // vbo 1 + glBindBuffer(GL_ARRAY_BUFFER, vbo[1]); + glBufferData(GL_ARRAY_BUFFER, 4 * 2 * sizeof(GLfloat), texture_vertices, + GL_STATIC_DRAW); + glEnableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glVertexAttribPointer(ATTRIB_TEXTURE_POSITION, 2, GL_FLOAT, 0, 0, nullptr); - // draw - glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); + // draw + glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); - // cleanup - glDisableVertexAttribArray(ATTRIB_VERTEX); - glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); - glBindBuffer(GL_ARRAY_BUFFER, 0); - glBindVertexArray(0); - glDeleteVertexArrays(1, &vao); - glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU -} + // cleanup + glDisableVertexAttribArray(ATTRIB_VERTEX); + glDisableVertexAttribArray(ATTRIB_TEXTURE_POSITION); + glBindBuffer(GL_ARRAY_BUFFER, 0); + glBindVertexArray(0); + glDeleteVertexArrays(1, &vao); + glDeleteBuffers(2, vbo); +#endif // !MEDIAPIPE_DISABLE_GPU + } -absl::Status TensorsToSegmentationCalculator::LoadOptions( - CalculatorContext* cc) { - // Get calculator options specified in the graph. - options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); + absl::Status TensorsToSegmentationCalculator::LoadOptions( + CalculatorContext *cc) + { + // Get calculator options specified in the graph. 
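+    // (These are the values given under node_options for this calculator in
+    // the graph .pbtxt.)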
+ options_ = cc->Options<::mediapipe::TensorsToSegmentationCalculatorOptions>(); - return absl::OkStatus(); -} + return absl::OkStatus(); + } -absl::Status TensorsToSegmentationCalculator::InitGpu(CalculatorContext* cc) { + absl::Status TensorsToSegmentationCalculator::InitGpu(CalculatorContext *cc) + { #if !MEDIAPIPE_DISABLE_GPU - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { - // A shader to process a segmentation tensor into an output mask. - // Currently uses 4 channels for output, and sets R+A channels as mask value. + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status + { + // A shader to process a segmentation tensor into an output mask. + // Currently uses 4 channels for output, and sets R+A channels as mask value. #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // GLES 3.1 const tflite::gpu::uint3 workgroup_size = {kWorkgroupSize, kWorkgroupSize, @@ -794,7 +765,7 @@ void main() { vec4 out_value = vec4(new_mask_value, 0.0, 0.0, new_mask_value); fragColor = out_value; })"; -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Shader defines. typedef mediapipe::TensorsToSegmentationCalculatorOptions Options; @@ -830,7 +801,7 @@ void main() { "texture_coordinate", }; - // Main shader program & parameters + // Main shader program & parameters #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 GlShader shader_without_previous; MP_RETURN_IF_ERROR(GlShader::CompileShader( @@ -862,7 +833,7 @@ void main() { RET_CHECK(mask_program_20_) << "Problem initializing the program."; glUseProgram(mask_program_20_); glUniform1i(glGetUniformLocation(mask_program_20_, "input_texture"), 1); -#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Simple pass-through program, used for hardware upsampling. mediapipe::GlhCreateProgram( @@ -872,11 +843,10 @@ void main() { glUseProgram(upsample_program_); glUniform1i(glGetUniformLocation(upsample_program_, "video_frame"), 1); + return absl::OkStatus(); })); +#endif // !MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU + } - return absl::OkStatus(); -} - -} // namespace mediapipe +} // namespace mediapipe diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/imagestylegpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/imagestylegpu/BUILD new file mode 100644 index 000000000..9e50d2e74 --- /dev/null +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/imagestylegpu/BUILD @@ -0,0 +1,60 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) + +package(default_visibility = ["//visibility:private"]) + +cc_binary( + name = "libmediapipe_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/graphs/image_style:mobile_calculators", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + ], +) + +cc_library( + name = "mediapipe_jni_lib", + srcs = [":libmediapipe_jni.so"], + alwayslink = 1, +) + +android_binary( + name = "imagestylegpu", + srcs = glob(["*.java"]), + assets = [ + "//mediapipe/graphs/image_style:mobile_gpu.binarypb", + "//mediapipe/models:model_float32.tflite", + ], + assets_dir = "", + manifest = "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.apps.imagestylegpu", + "appName": "Image Style", + "mainActivity": "com.google.mediapipe.apps.basic.MainActivity", + "cameraFacingFront": "True", + "binaryGraphName": "mobile_gpu.binarypb", + "inputVideoStreamName": "input_video", + "outputVideoStreamName": "output_video", + "flipFramesVertically": "True", + "converterNumBuffers": "2", + }, + multidex = "native", + deps = [ + ":mediapipe_jni_lib", + "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", + ], +) diff --git a/mediapipe/graphs/image_style/BUILD b/mediapipe/graphs/image_style/BUILD index 90e8d9346..27ed8b823 100644 --- a/mediapipe/graphs/image_style/BUILD +++ b/mediapipe/graphs/image_style/BUILD @@ -24,22 +24,14 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "mobile_calculators", deps = [ - "//mediapipe/calculators/tensorflow:tensor_to_image_frame_calculator", - "//mediapipe/calculators/tensorflow:vector_float_to_tensor_calculator", - "//mediapipe/calculators/tensor:tensors_to_floats_calculator", - "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator", - "//mediapipe/calculators/util:from_image_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", - "//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", - "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_segmentation_calculator", - "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", ], ) @@ -47,18 +39,17 @@ cc_library( name = "desktop_calculators", deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/image:image_transformation_calculator", - "//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_gpuimage_calculator", - 
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_segmentation_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_segmentation_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/calculators/util:from_image_calculator", ], ) mediapipe_binary_graph( name = "mobile_gpu_binary_graph", - graph = "image_style.pbtxt", + graph = "image_style_gpu.pbtxt", output_name = "mobile_gpu.binarypb", deps = [":mobile_calculators"], ) diff --git a/mediapipe/graphs/image_style/image_style.pbtxt b/mediapipe/graphs/image_style/image_style.pbtxt index a1860d14f..6ff9318f0 100644 --- a/mediapipe/graphs/image_style/image_style.pbtxt +++ b/mediapipe/graphs/image_style/image_style.pbtxt @@ -42,8 +42,8 @@ node { options { [mediapipe.TfLiteConverterCalculatorOptions.ext] { output_tensor_float_range { - min: 0 - max: 255 + min: -1 + max: 1 } } } diff --git a/mediapipe/graphs/image_style/image_style_cpu (copy).pbtxt b/mediapipe/graphs/image_style/image_style_cpu (copy).pbtxt index 6d9d64318..5ba1ab286 100644 --- a/mediapipe/graphs/image_style/image_style_cpu (copy).pbtxt +++ b/mediapipe/graphs/image_style/image_style_cpu (copy).pbtxt @@ -1,19 +1,30 @@ -# MediaPipe graph that performs object detection on desktop with TensorFlow Lite -# on CPU. -# Used in the example in -# mediapipe/examples/desktop/object_detection:object_detection_tflite. +# MediaPipe graph that performs face mesh with TensorFlow Lite on CPU. -# max_queue_size limits the number of packets enqueued on any input stream -# by throttling inputs to the graph. This makes the graph only process one -# frame per time. -max_queue_size: 1 +# Input image. (ImageFrame) +input_stream: "input_video" -# Decodes an input video file into images and a video header. +# Output image with rendered results. (ImageFrame) +output_stream: "output_video" + +# Throttles the images flowing downstream for flow control. It passes through +# the very first incoming image unaltered, and waits for downstream nodes +# (calculators and subgraphs) in the graph to finish their tasks before it +# passes through another image. All images that come in while waiting are +# dropped, limiting the number of in-flight images in most part of the graph to +# 1. This prevents the downstream nodes from queuing up incoming images and data +# excessively, which leads to increased latency and memory usage, unwanted in +# real-time mobile applications. It also eliminates unnecessarily computation, +# e.g., the output produced by a node may get dropped downstream if the +# subsequent nodes are still busy processing previous inputs. node { - calculator: "OpenCvVideoDecoderCalculator" - input_side_packet: "INPUT_FILE_PATH:input_video_path" - output_stream: "VIDEO:input_video" - output_stream: "VIDEO_PRESTREAM:input_video_header" + calculator: "FlowLimiterCalculator" + input_stream: "input_video" + input_stream: "FINISHED:output_video" + input_stream_info: { + tag_index: "FINISHED" + back_edge: true + } + output_stream: "throttled_input_video" } # Transforms the input image on CPU to a 320x320 image. To scale the image, by @@ -23,12 +34,12 @@ node { # detection model used in this graph is agnostic to that deformation. 
node: { calculator: "ImageTransformationCalculator" - input_stream: "IMAGE:input_video" + input_stream: "IMAGE:throttled_input_video" output_stream: "IMAGE:transformed_input_video" node_options: { [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { - output_width: 512 - output_height: 512 + output_width: 256 + output_height: 256 } } } @@ -39,58 +50,45 @@ node: { node { calculator: "TfLiteConverterCalculator" input_stream: "IMAGE:transformed_input_video" - output_stream: "TENSORS:image_tensor" - node_options: { - [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] { - zero_center: true - } + output_stream: "TENSORS:input_tensors" + options { + [mediapipe.TfLiteConverterCalculatorOptions.ext] { + zero_center: false + max_num_channels: 3 + output_tensor_float_range { + min: 0.0 + max: 255.0 + } + } } } + # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a # vector of tensors representing, for instance, detection boxes/keypoints and # scores. node { calculator: "TfLiteInferenceCalculator" - input_stream: "TENSORS:image_tensor" - output_stream: "TENSORS:stylized_tensor" + input_stream: "TENSORS:input_tensors" + output_stream: "TENSORS:output_tensors" node_options: { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "mediapipe/models/metaf-512-mobile3.tflite" + model_path: "mediapipe/models/model_float32.tflite" } } } -node { - calculator: "TfliteTensorsToGpuImageCalculator" - input_stream: "TENSORS:stylized_tensor" - output_stream: "IMAGE:image" -} -#node { -# calculator: "TfLiteTensorsToSegmentationCalculator" -# input_stream: "TENSORS:stylized_tensor" -# output_stream: "MASK:mask_image" -# node_options: { -# [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { -# tensor_width: 512 -# tensor_height: 512 -# tensor_channels: 3 -# } -# } -#} - -# Encodes the annotated images into a video file, adopting properties specified -# in the input video header, e.g., video framerate. node { - calculator: "OpenCvVideoEncoderCalculator" - input_stream: "VIDEO:image" - input_stream: "VIDEO_PRESTREAM:input_video_header" - input_side_packet: "OUTPUT_FILE_PATH:output_video_path" + calculator: "TfLiteTensorsToSegmentationCalculator" + input_stream: "TENSORS:output_tensors" + output_stream: "MASK:output_video" node_options: { - [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: { - codec: "avc1" - video_format: "mp4" - } - } + [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { + tensor_width: 256 + tensor_height: 256 + tensor_channels: 3 + } + } } + diff --git a/mediapipe/graphs/image_style/image_style_cpu.pbtxt b/mediapipe/graphs/image_style/image_style_cpu.pbtxt index 1a78cf6c0..24074ac84 100644 --- a/mediapipe/graphs/image_style/image_style_cpu.pbtxt +++ b/mediapipe/graphs/image_style/image_style_cpu.pbtxt @@ -6,16 +6,7 @@ input_stream: "input_video" # Output image with rendered results. (ImageFrame) output_stream: "output_video" -# Throttles the images flowing downstream for flow control. It passes through -# the very first incoming image unaltered, and waits for downstream nodes -# (calculators and subgraphs) in the graph to finish their tasks before it -# passes through another image. All images that come in while waiting are -# dropped, limiting the number of in-flight images in most part of the graph to -# 1. 
This prevents the downstream nodes from queuing up incoming images and data -# excessively, which leads to increased latency and memory usage, unwanted in -# real-time mobile applications. It also eliminates unnecessarily computation, -# e.g., the output produced by a node may get dropped downstream if the -# subsequent nodes are still busy processing previous inputs. + node { calculator: "FlowLimiterCalculator" input_stream: "input_video" @@ -27,67 +18,59 @@ node { output_stream: "throttled_input_video" } -# Transforms the input image on CPU to a 320x320 image. To scale the image, by -# default it uses the STRETCH scale mode that maps the entire input image to the -# entire transformed image. As a result, image aspect ratio may be changed and -# objects in the image may be deformed (stretched or squeezed), but the object -# detection model used in this graph is agnostic to that deformation. + node: { - calculator: "ImageTransformationCalculator" - input_stream: "IMAGE:throttled_input_video" - output_stream: "IMAGE:transformed_input_video" - node_options: { - [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { - output_width: 256 - output_height: 256 - } - } + calculator: "ToImageCalculator" + input_stream: "IMAGE_CPU:throttled_input_video" + output_stream: "IMAGE:image_input_video" } -# Converts the transformed input image on CPU into an image tensor as a -# TfLiteTensor. The zero_center option is set to true to normalize the -# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. node { - calculator: "TfLiteConverterCalculator" - input_stream: "IMAGE:transformed_input_video" - output_stream: "TENSORS:input_tensors" - options { - [mediapipe.TfLiteConverterCalculatorOptions.ext] { - output_tensor_float_range { - min: 0 - max: 255 - } - max_num_channels: 3 - } - } -} - - -# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a -# vector of tensors representing, for instance, detection boxes/keypoints and -# scores. 
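For context while reading the rewritten CPU pipeline below (ToImageCalculator, ImageToTensorCalculator, InferenceCalculator, TensorsToSegmentationCalculator, FromImageCalculator): such a graph is typically driven from C++ in the style of MediaPipe's desktop demos. This is a hedged sketch of that pattern; RunStyleGraph and its parameters are illustrative and not part of this patch.

#include <memory>
#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

// Loads the graph config from text, pushes a single frame into
// "input_video", and observes stylized frames on "output_video".
absl::Status RunStyleGraph(const std::string& graph_text,
                           std::unique_ptr<mediapipe::ImageFrame> frame) {
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
          graph_text);
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  MP_RETURN_IF_ERROR(graph.ObserveOutputStream(
      "output_video", [](const mediapipe::Packet& packet) {
        LOG(INFO) << "stylized frame at " << packet.Timestamp();
        return absl::OkStatus();
      }));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "input_video",
      mediapipe::Adopt(frame.release()).At(mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseInputStream("input_video"));
  return graph.WaitUntilDone();
}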
-node { - calculator: "TfLiteInferenceCalculator" - input_stream: "TENSORS:input_tensors" - output_stream: "TENSORS:output_tensors" - node_options: { - [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { - model_path: "mediapipe/models/model_float32.tflite" + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:image_input_video" + output_stream: "TENSORS:input_tensor" + options: { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: 256 + output_tensor_height: 256 + keep_aspect_ratio: true + output_tensor_float_range { + min: -1.0 + max: 1.0 + } + border_mode: BORDER_ZERO } } } node { - calculator: "TfLiteTensorsToSegmentationCalculator" - input_stream: "TENSORS:output_tensors" - output_stream: "MASK:output_video" - node_options: { - [type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] { - tensor_width: 256 - tensor_height: 256 - tensor_channels: 3 - } + calculator: "InferenceCalculator" + input_stream: "TENSORS:input_tensor" + output_stream: "TENSORS:output_tensor" + options: { + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/models/model_float32.tflite" + delegate { xnnpack {} } + } + } +} + + +node { + calculator: "TensorsToSegmentationCalculator" + input_stream: "TENSORS:output_tensor" + output_stream: "MASK:output" + options: { + [mediapipe.TensorsToSegmentationCalculatorOptions.ext] { + activation: NONE + } } } +node{ + calculator: "FromImageCalculator" + input_stream: "IMAGE:output" + output_stream: "IMAGE_CPU:output_video" +} + diff --git a/mediapipe/graphs/image_style/image_style_gpu.pbtxt b/mediapipe/graphs/image_style/image_style_gpu.pbtxt index 7e48e800e..9ec991440 100644 --- a/mediapipe/graphs/image_style/image_style_gpu.pbtxt +++ b/mediapipe/graphs/image_style/image_style_gpu.pbtxt @@ -18,30 +18,18 @@ node { output_stream: "throttled_input_video" } -node: { - calculator: "ImageTransformationCalculator" - input_stream: "IMAGE_GPU:throttled_input_video" - output_stream: "IMAGE_GPU:transformed_input_video" - node_options: { - [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] { - output_width: 512 - output_height: 512 - } - } -} - node: { calculator: "ImageToTensorCalculator" - input_stream: "IMAGE_GPU:transformed_input_video" + input_stream: "IMAGE_GPU:throttled_input_video" output_stream: "TENSORS:input_tensors" options { [mediapipe.ImageToTensorCalculatorOptions.ext] { - output_tensor_width: 512 - output_tensor_height: 512 - keep_aspect_ratio: true + output_tensor_width: 256 + output_tensor_height: 256 + keep_aspect_ratio: false output_tensor_float_range { - min: 0.0 - max: 255.0 + min: -1.0 + max: 1.0 } gpu_origin: TOP_LEFT border_mode: BORDER_REPLICATE @@ -49,32 +37,42 @@ node: { } } + + node { calculator: "InferenceCalculator" - input_stream: "TENSORS_GPU:input_tensors" - output_stream: "TENSORS_GPU:output_tensors" + input_stream: "TENSORS:input_tensors" + output_stream: "TENSORS:output_tensors" options: { [mediapipe.InferenceCalculatorOptions.ext] { - model_path: "mediapipe/models/metaf-512-mobile3.tflite" - delegate { gpu {} } + model_path: "mediapipe/models/model_float32.tflite" + delegate { xnnpack {} } } } } +# Retrieves the size of the input image. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_video" + output_stream: "SIZE:input_size" +} + # Processes the output tensors into a segmentation mask that has the same size # as the input image into the graph. 
node { calculator: "TensorsToSegmentationCalculator" input_stream: "TENSORS:output_tensors" + input_stream: "OUTPUT_SIZE:input_size" output_stream: "MASK:mask_image" options: { [mediapipe.TensorsToSegmentationCalculatorOptions.ext] { activation: NONE + gpu_origin: TOP_LEFT } } } - node: { calculator: "FromImageCalculator" input_stream: "IMAGE:mask_image"