From 96a8ff922d560ecc9d98750a3d5cee896e477d00 Mon Sep 17 00:00:00 2001 From: Pierre Fenoll Date: Mon, 19 Oct 2020 17:54:25 +0200 Subject: [PATCH] wip: docs/tests/cleanup Signed-off-by: Pierre Fenoll --- mediapipe/calculators/pytorch/BUILD | 68 ++++++------ .../pytorch/pytorch_converter_calculator.cc | 92 +++++++++++----- .../pytorch_converter_calculator.proto | 2 +- .../pytorch_converter_calculator_test.cc | 101 ++++++++++++++++++ 4 files changed, 203 insertions(+), 60 deletions(-) create mode 100644 mediapipe/calculators/pytorch/pytorch_converter_calculator_test.cc diff --git a/mediapipe/calculators/pytorch/BUILD b/mediapipe/calculators/pytorch/BUILD index 8880fca5d..ca0b8b130 100644 --- a/mediapipe/calculators/pytorch/BUILD +++ b/mediapipe/calculators/pytorch/BUILD @@ -12,57 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_library") -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:private"]) -proto_library( +mediapipe_proto_library( name = "pytorch_converter_calculator_proto", srcs = ["pytorch_converter_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "pytorch_inference_calculator_proto", srcs = ["pytorch_inference_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) -proto_library( +mediapipe_proto_library( name = "pytorch_tensors_to_classification_calculator_proto", srcs = ["pytorch_tensors_to_classification_calculator.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework:calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "pytorch_converter_calculator_cc_proto", - srcs = ["pytorch_converter_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":pytorch_converter_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "pytorch_inference_calculator_cc_proto", - srcs = ["pytorch_inference_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":pytorch_inference_calculator_proto"], -) - -mediapipe_cc_proto_library( - name = "pytorch_tensors_to_classification_calculator_cc_proto", - srcs = ["pytorch_tensors_to_classification_calculator.proto"], - cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], - deps = [":pytorch_tensors_to_classification_calculator_proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], ) cc_library( @@ -136,3 +120,21 @@ cc_library( }), alwayslink = 1, ) + +cc_test( + name = "pytorch_converter_calculator_test", + srcs = ["pytorch_converter_calculator_test.cc"], + deps = [ + ":pytorch_converter_calculator", + ":pytorch_converter_calculator_cc_proto", + 
"//mediapipe/framework:calculator_framework", + # "//mediapipe/framework:calculator_runner", + # "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:gtest_main", + # "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:parse_text_proto", + # "//mediapipe/framework/tool:validate_type", + ], +) diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc b/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc index b2696f73f..ad0cad012 100644 --- a/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc +++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator.cc @@ -28,11 +28,19 @@ #if defined(MEDIAPIPE_IOS) #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/objc/util.h" -#endif // iOS +#endif // MEDIAPIPE_IOS namespace mediapipe { namespace { +::mediapipe::Status EnsureFormat(const ImageFrame& image_frame) { + const ImageFormat::Format format = image_frame.Format(); + if (!(format == ImageFormat::SRGB)) { + RET_CHECK_FAIL() << "Unsupported input format."; + } + return ::mediapipe::OkStatus(); +} + constexpr char kImageTag[] = "IMAGE"; constexpr char kImageGpuTag[] = "IMAGE_GPU"; constexpr char kTensorsTag[] = "TENSORS"; @@ -46,16 +54,16 @@ using Outputs = std::vector; // This calculator is designed to be used with the PyTorchInferenceCalculator, // as a pre-processing step for calculator inputs. // -// IMAGE inputs are normalized to [-1,1] (default) or [0,1], -// specified by options (unless outputting a quantized tensor). +// IMAGE and IMAGE_GPU inputs are normalized to [0;1]. // // Input: // One of the following tags: -// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data). +// IMAGE - ImageFrame. +// IMAGE_GPU - GpuBuffer. // // Output: // One of the following tags: -// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8. +// TENSORS - Vector of torch::jit::IValue residing on CPU. // // Example use: // node { @@ -64,12 +72,15 @@ using Outputs = std::vector; // output_stream: "TENSORS:image_tensor" // options: { // [mediapipe.PyTorchConverterCalculatorOptions.ext] { -// zero_center: true +// per_channel_normalizations: {sub:0.485 div:0.229} +// per_channel_normalizations: {sub:0.456 div:0.224} +// per_channel_normalizations: {sub:0.406 div:0.225} // } // } // } // // IMPORTANT Notes: +// If given an IMAGE_GPU, PyTorch will convert TENSORS to CPU. // This calculator uses FixedSizeInputStreamHandler by default. // class PyTorchConverterCalculator : public CalculatorBase { @@ -84,6 +95,7 @@ class PyTorchConverterCalculator : public CalculatorBase { ::mediapipe::PyTorchConverterCalculatorOptions options_; bool has_image_tag_; bool has_image_gpu_tag_; + bool has_tensors_tag_; }; REGISTER_CALCULATOR(PyTorchConverterCalculator); @@ -97,7 +109,9 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator); const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag); RET_CHECK(has_tensors_tag); - if (has_image_tag) cc->Inputs().Tag(kImageTag).Set(); + if (has_image_tag) { + cc->Inputs().Tag(kImageTag).Set(); + } if (has_image_gpu_tag) { #if defined(MEDIAPIPE_IOS) cc->Inputs().Tag(kImageGpuTag).Set(); @@ -106,7 +120,7 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator); #endif } - if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set(); + cc->Outputs().Tag(kTensorsTag).Set(); // Assign this calculator's default InputStreamHandler. 
  cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
@@ -121,11 +135,12 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
 
   has_image_tag_ = cc->Inputs().HasTag(kImageTag);
   has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag);
+  has_tensors_tag_ = cc->Outputs().HasTag(kTensorsTag);
 
   if (has_image_gpu_tag_) {
 #if !defined(MEDIAPIPE_IOS)
     RET_CHECK_FAIL() << "GPU processing not enabled.";
-#endif
+#endif  // !MEDIAPIPE_IOS
   }
 
   return ::mediapipe::OkStatus();
@@ -133,43 +148,68 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
 
 ::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) {
   cv::Mat image;
+  // Keeps a GpuBuffer-backed frame alive while `image` aliases its pixels.
+  std::unique_ptr<ImageFrame> gpu_frame;
+  // Acquire the input packet as an ImageFrame, skipping empty packets.
   if (has_image_gpu_tag_) {
 #if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+    if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) {
+      return ::mediapipe::OkStatus();
+    }
     const auto& input = cc->Inputs().Tag(kImageGpuTag).Get<mediapipe::GpuBuffer>();
-    std::unique_ptr<ImageFrame> frame =
-        CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
-    image = mediapipe::formats::MatView(frame.release());
+    gpu_frame = CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
+    MP_RETURN_IF_ERROR(EnsureFormat(*gpu_frame));
+    image = mediapipe::formats::MatView(gpu_frame.get());
 #else
     RET_CHECK_FAIL() << "GPU processing is not enabled.";
 #endif
   }
   if (has_image_tag_) {
-    auto& output_frame =
-        cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
-    image = mediapipe::formats::MatView(&output_frame);
+    if (cc->Inputs().Tag(kImageTag).IsEmpty()) {
+      return ::mediapipe::OkStatus();
+    }
+    const ImageFrame& image_frame =
+        cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
+    MP_RETURN_IF_ERROR(EnsureFormat(image_frame));
+    image = mediapipe::formats::MatView(&image_frame);
   }
+
+  const int num_channels = image.channels();
+  RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported";
+  const int width = image.cols;
+  const int height = image.rows;
+  // NOTE: inputs are SRGB, so this swap yields BGR channel order; the
+  // per_channel_normalizations below are applied in that swapped order.
   cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
-  const auto width = image.cols, height = image.rows,
-             num_channels = image.channels();
-  RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now";
 
   cv::Mat img_float;
+  // Normalize to [0;1]
   image.convertTo(img_float, CV_32F, 1.0 / 255);
-  auto img_tensor = torch::from_blob(img_float.data, {1, width, height, 3});
-  img_tensor = img_tensor.permute({0, 3, 1, 2});  // To NCWH
-  if (options_.per_channel_normalizations().size() > 0)
+  // FIXME: build the tensor in NCHW directly here
+  // NHWC: cv::Mat is row-major, i.e. `height` rows of `width` pixels each.
+  auto img_tensor = torch::from_blob(img_float.data, {1, height, width, 3});
+  // Permute from NHWC to NCHW
+  img_tensor = img_tensor.permute({0, 3, 1, 2});
+
+  if (options_.per_channel_normalizations().size() > 0) {
+    // The options must carry exactly one SubDiv per channel, or none.
+    RET_CHECK_EQ(options_.per_channel_normalizations().size(), num_channels);
+    // Apply the per-channel (x - sub) / div normalization.
     for (int i = 0; i < num_channels; ++i) {
       const auto& subdiv = options_.per_channel_normalizations(i);
-      img_tensor[0][i] = img_tensor[0][i].sub_(subdiv.sub()).div_(subdiv.div());
+      const float sub = subdiv.sub();
+      const float div = subdiv.div();
+      img_tensor[0][i] = img_tensor[0][i].sub_(sub).div_(div);
    }
+  }
 
-  auto output_tensors = absl::make_unique<Outputs>();
-  output_tensors->reserve(1);
-  output_tensors->emplace_back(img_tensor.cpu());
-  cc->Outputs()
-      .Tag(kTensorsTag)
-      .Add(output_tensors.release(), cc->InputTimestamp());
-  return ::mediapipe::OkStatus();
+  if (has_tensors_tag_) {
+    // FIXME: move to ctor
+    auto output_tensors = absl::make_unique<Outputs>();
+    output_tensors->reserve(1);
+    // Clone so the tensor owns its storage: from_blob() merely aliases
+    // img_float, whose data is freed when Process() returns.
+    output_tensors->emplace_back(img_tensor.cpu().clone());
+    cc->Outputs()
+        .Tag(kTensorsTag)
+        .Add(output_tensors.release(), cc->InputTimestamp());
+    return ::mediapipe::OkStatus();
+  } else {
+    RET_CHECK_FAIL() << "Unsupported output kind.";
+  }
 }
 
 ::mediapipe::Status
PyTorchConverterCalculator::Close(CalculatorContext* cc) {
diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto b/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto
index 22b3ac9c3..5b8fe97e4 100644
--- a/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto
+++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator.proto
@@ -27,7 +27,7 @@ message PyTorchConverterCalculatorOptions {
     optional float sub = 1 [default = 0];
     optional float div = 2 [default = 1];
   }
-  // Normalizations to apply per input image channel.
+  // Normalizations to apply per channel of the input image.
   // There must be exactly as many items in this list as there are channels
   // or none.
   repeated SubDiv per_channel_normalizations = 1;
diff --git a/mediapipe/calculators/pytorch/pytorch_converter_calculator_test.cc b/mediapipe/calculators/pytorch/pytorch_converter_calculator_test.cc
new file mode 100644
index 000000000..414541bcb
--- /dev/null
+++ b/mediapipe/calculators/pytorch/pytorch_converter_calculator_test.cc
@@ -0,0 +1,101 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <vector>
+
+#include "absl/memory/memory.h"
+// #include "absl/strings/substitute.h"
+#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+// #include "mediapipe/framework/calculator_runner.h"
+// #include "mediapipe/framework/formats/image_format.pb.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/gtest.h"
+// #include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"  // NOLINT
+#include "mediapipe/framework/tool/sink.h"
+// #include "mediapipe/framework/tool/validate_type.h"
+#include "torch/torch.h"
+
+namespace mediapipe {
+
+using Outputs = std::vector<torch::jit::IValue>;
+
+class PyTorchConverterCalculatorTest : public ::testing::Test {
+  std::unique_ptr<CalculatorGraph> graph_;
+};
+
+TEST_F(PyTorchConverterCalculatorTest, CustomDivAndSub) {
+  CalculatorGraph graph;
+  // Run the calculator and verify that one output is generated.
+  CalculatorGraphConfig graph_config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        input_stream: "input_image"
+        node {
+          calculator: "PyTorchConverterCalculator"
+          input_stream: "IMAGE:input_image"
+          output_stream: "TENSORS:tensor"
+          options {
+            [mediapipe.PyTorchConverterCalculatorOptions.ext] {
+              per_channel_normalizations: { sub: 0.485 div: 0.229 }
+              per_channel_normalizations: { sub: 0.456 div: 0.224 }
+              per_channel_normalizations: { sub: 0.406 div: 0.225 }
+            }
+          }
+        }
+      )");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("tensor", &graph_config, &output_packets);
+
+  // Run the graph.
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  MP_ASSERT_OK(graph.StartRun({}));
+  auto input_image = absl::make_unique<ImageFrame>(ImageFormat::SRGB, 1, 1);
+  cv::Mat mat = ::mediapipe::formats::MatView(input_image.get());
+  mat.at<cv::Vec3b>(0, 0)[0] = 200;
+  // mat.at<cv::Vec3b>(0, 0)[1] = 200;
+  // mat.at<cv::Vec3b>(0, 0)[2] = 200;
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input_image", Adopt(input_image.release()).At(Timestamp(0))));
+
+  // Wait until the calculator is done processing.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Get and process results.
+  ASSERT_EQ(1, output_packets.size());
+  const Outputs& tensor_vec = output_packets[0].Get<Outputs>();
+  EXPECT_EQ(1, tensor_vec.size());
+  const torch::Tensor& tensor = tensor_vec[0].toTensor();
+  EXPECT_EQ(4, tensor.dim());
+
+  // std::tuple<torch::Tensor, torch::Tensor> result =
+  //     tensor.sort(/*dim*/ -1, /*descending*/ true);
+  // const torch::Tensor result_tensor = std::get<0>(result)[0];
+  // auto results = result_tensor.accessor<float, 1>();
+  // EXPECT_EQ(1, results.size(0));
+  // const float r0 = results[0];
+  // EXPECT_FLOAT_EQ(67.0f, r0);
+  // const float r1 = results[1];
+  // EXPECT_FLOAT_EQ(67.0f, r1);
+  // const float r2 = results[2];
+  // EXPECT_FLOAT_EQ(67.0f, r2);
+
+  // Fully close graph at end, otherwise calculator+tensors are destroyed
+  // after calling WaitUntilDone().
+  MP_ASSERT_OK(graph.CloseInputStream("input_image"));
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
+}  // namespace mediapipe
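
Reviewer note (not part of the patch): the converter scales each 8-bit channel
to [0;1] via convertTo(..., CV_32F, 1.0 / 255) and then applies
(x - sub) / div per channel. The commented-out expectations in CustomDivAndSub
can be derived with a standalone sketch like the one below. Plain C++, under
two stated assumptions: the test writes 200 into channel 0 of the 1x1 SRGB
frame, and the cv::COLOR_BGR2RGB call in Process() swaps channels 0 and 2
before normalization. The printed values are illustrative, not taken from the
patch.

    // Expected per-channel values of img_tensor[0][c] for the test's 1x1
    // image, under the assumptions stated above.
    #include <cstdio>

    int main() {
      // After the BGR2RGB swap, the 200 written to channel 0 lands in
      // channel 2; channels 0 and 1 stay 0.
      const float pixel[3] = {0.0f, 0.0f, 200.0f};
      const float sub[3] = {0.485f, 0.456f, 0.406f};  // test config sub
      const float div[3] = {0.229f, 0.224f, 0.225f};  // test config div
      for (int c = 0; c < 3; ++c) {
        const float scaled = pixel[c] / 255.0f;       // [0;1] scaling
        const float normalized = (scaled - sub[c]) / div[c];
        // Prints roughly -2.11790, -2.03571, 1.68139.
        std::printf("channel %d: %.5f\n", c, normalized);
      }
      return 0;
    }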