wip: docs/tests/cleanup

Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>
This commit is contained in:
Pierre Fenoll 2020-10-19 17:54:25 +02:00
parent 93cf6d93a0
commit 96a8ff922d
4 changed files with 203 additions and 60 deletions

View File

@ -12,57 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
proto_library(
mediapipe_proto_library(
name = "pytorch_converter_calculator_proto",
srcs = ["pytorch_converter_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
mediapipe_proto_library(
name = "pytorch_inference_calculator_proto",
srcs = ["pytorch_inference_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
proto_library(
mediapipe_proto_library(
name = "pytorch_tensors_to_classification_calculator_proto",
srcs = ["pytorch_tensors_to_classification_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pytorch_converter_calculator_cc_proto",
srcs = ["pytorch_converter_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":pytorch_converter_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pytorch_inference_calculator_cc_proto",
srcs = ["pytorch_inference_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":pytorch_inference_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pytorch_tensors_to_classification_calculator_cc_proto",
srcs = ["pytorch_tensors_to_classification_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":pytorch_tensors_to_classification_calculator_proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
@ -136,3 +120,21 @@ cc_library(
}),
alwayslink = 1,
)
# Unit test for PyTorchConverterCalculator. Every header the test source
# includes directly is listed here, including absl/memory (for
# absl::make_unique) and tool:sink (for tool::AddVectorSink).
cc_test(
    name = "pytorch_converter_calculator_test",
    srcs = ["pytorch_converter_calculator_test.cc"],
    deps = [
        ":pytorch_converter_calculator",
        ":pytorch_converter_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:sink",
        "@com_google_absl//absl/memory",
    ],
)

View File

@ -28,11 +28,19 @@
#if defined(MEDIAPIPE_IOS)
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/objc/util.h"
#endif // iOS
#endif // MEDIAPIPE_IOS
namespace mediapipe {
namespace {
// Checks that `image_frame` uses a pixel format this calculator can convert.
// Only ImageFormat::SRGB is accepted; for any other format the function
// returns the status produced by RET_CHECK_FAIL() with the message
// "Unsupported input format.".
::mediapipe::Status EnsureFormat(const ImageFrame& image_frame) {
  if (image_frame.Format() == ImageFormat::SRGB) {
    return ::mediapipe::OkStatus();
  }
  RET_CHECK_FAIL() << "Unsupported input format.";
}
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kTensorsTag[] = "TENSORS";
@ -46,16 +54,16 @@ using Outputs = std::vector<torch::jit::IValue>;
// This calculator is designed to be used with the PyTorchInferenceCalculator,
// as a pre-processing step for calculator inputs.
//
// IMAGE inputs are normalized to [-1,1] (default) or [0,1],
// specified by options (unless outputting a quantized tensor).
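// IMAGE and IMAGE_GPU inputs are first scaled to [0;1] (divided by 255).
// If per_channel_normalizations is set in the options, each channel is then
// further normalized as (value - sub) / div.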
//
// Input:
// One of the following tags:
// IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data).
// IMAGE - ImageFrame.
// IMAGE_GPU - GpuBuffer.
//
// Output:
// One of the following tags:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8.
// TENSORS - Vector of torch::jit::IValue residing on CPU.
//
// Example use:
// node {
@ -64,12 +72,15 @@ using Outputs = std::vector<torch::jit::IValue>;
// output_stream: "TENSORS:image_tensor"
// options: {
// [mediapipe.PyTorchConverterCalculatorOptions.ext] {
// zero_center: true
// per_channel_normalizations: {sub:0.485 div:0.229}
// per_channel_normalizations: {sub:0.456 div:0.224}
// per_channel_normalizations: {sub:0.406 div:0.225}
// }
// }
// }
//
// IMPORTANT Notes:
// If given an IMAGE_GPU, PyTorch will convert TENSORS to CPU.
// This calculator uses FixedSizeInputStreamHandler by default.
//
class PyTorchConverterCalculator : public CalculatorBase {
@ -84,6 +95,7 @@ class PyTorchConverterCalculator : public CalculatorBase {
::mediapipe::PyTorchConverterCalculatorOptions options_;
bool has_image_tag_;
bool has_image_gpu_tag_;
bool has_tensors_tag_;
};
REGISTER_CALCULATOR(PyTorchConverterCalculator);
@ -97,7 +109,9 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag);
RET_CHECK(has_tensors_tag);
if (has_image_tag) cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
if (has_image_tag) {
cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
}
if (has_image_gpu_tag) {
#if defined(MEDIAPIPE_IOS)
cc->Inputs().Tag(kImageGpuTag).Set<GpuBuffer>();
@ -106,7 +120,7 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
#endif
}
if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set<Outputs>();
cc->Outputs().Tag(kTensorsTag).Set<Outputs>();
// Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
@ -121,11 +135,12 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
has_image_tag_ = cc->Inputs().HasTag(kImageTag);
has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag);
has_tensors_tag_ = cc->Outputs().HasTag(kTensorsTag);
if (has_image_gpu_tag_) {
#if !defined(MEDIAPIPE_IOS)
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
#endif // MEDIAPIPE_IOS
}
return ::mediapipe::OkStatus();
@ -133,43 +148,68 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
// Converts the current input image into a single float tensor.
//
// Reads an ImageFrame from IMAGE, or (iOS only) a GpuBuffer from IMAGE_GPU
// that is first turned into an ImageFrame. The pixels are divided by 255,
// laid out as NCHW, optionally normalized per channel as (x - sub) / div
// using the calculator options, and emitted on TENSORS as a one-element
// Outputs vector stamped with the input timestamp.
//
// Returns OkStatus() without emitting anything when the input packet is
// empty. Returns an error status when the input is not SRGB, when GPU input
// is requested on a build without GPU support, when the image does not have
// three channels, or when per_channel_normalizations is set but does not
// contain exactly one entry per channel.
::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) {
  // Owns the ImageFrame created from a GPU input so that `image`, which only
  // views those pixels, stays valid until they have been converted below.
  std::unique_ptr<ImageFrame> gpu_frame;
  cv::Mat image;

  if (has_image_gpu_tag_) {
#if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
    if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) {
      return ::mediapipe::OkStatus();
    }
    const auto& input = cc->Inputs().Tag(kImageGpuTag).Get<GpuBuffer>();
    gpu_frame = CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
    RET_CHECK(gpu_frame != nullptr) << "Could not read the GPU buffer.";
    MP_RETURN_IF_ERROR(EnsureFormat(*gpu_frame));
    image = mediapipe::formats::MatView(gpu_frame.get());
#else
    RET_CHECK_FAIL() << "GPU processing is not enabled.";
#endif
  }

  if (has_image_tag_) {
    if (cc->Inputs().Tag(kImageTag).IsEmpty()) {
      return ::mediapipe::OkStatus();
    }
    const ImageFrame& image_frame =
        cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
    MP_RETURN_IF_ERROR(EnsureFormat(image_frame));
    // `image` aliases the packet's pixels; it must only be read from here on.
    image = mediapipe::formats::MatView(&image_frame);
  }

  const int num_channels = image.channels();
  RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now";
  const int width = image.cols;
  const int height = image.rows;

  // Normalize to [0;1]. convertTo() writes into a newly allocated matrix, so
  // the input packet's pixels are left untouched. No channel swap is applied:
  // EnsureFormat() already guarantees the data is in RGB order.
  cv::Mat img_float;
  image.convertTo(img_float, CV_32F, 1.0 / 255);

  // cv::Mat stores pixels row by row, i.e. as NHWC. from_blob() merely wraps
  // the buffer without owning it, so after permuting to NCHW the data is
  // cloned into a contiguous tensor that owns its memory and therefore
  // outlives `img_float`.
  auto img_tensor =
      torch::from_blob(img_float.data, {1, height, width, num_channels},
                       torch::kFloat32)
          .permute({0, 3, 1, 2})
          .clone(at::MemoryFormat::Contiguous);

  const int num_normalizations = options_.per_channel_normalizations_size();
  if (num_normalizations > 0) {
    RET_CHECK_EQ(num_normalizations, num_channels)
        << "per_channel_normalizations needs exactly one entry per channel";
    // Further normalize each channel of input image, in place.
    for (int i = 0; i < num_channels; ++i) {
      const auto& subdiv = options_.per_channel_normalizations(i);
      const float sub = subdiv.sub();
      const float div = subdiv.div();
      img_tensor[0][i].sub_(sub).div_(div);
    }
  }

  if (!has_tensors_tag_) {
    RET_CHECK_FAIL() << "Unsupported output kind.";
  }

  // FIXME: move to ctor
  auto output_tensors = absl::make_unique<Outputs>();
  output_tensors->reserve(1);
  // The tensor was built from host memory, so it already resides on CPU.
  output_tensors->emplace_back(img_tensor);
  cc->Outputs()
      .Tag(kTensorsTag)
      .Add(output_tensors.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchConverterCalculator::Close(CalculatorContext* cc) {

View File

@ -27,7 +27,7 @@ message PyTorchConverterCalculatorOptions {
optional float sub = 1 [default = 0];
optional float div = 2 [default = 1];
}
// Normalizations to apply per input image channel.
// Normalizations to apply per channel of input image.
// There must be exactly as many items in this list as there are channels
// or none.
repeated SubDiv per_channel_normalizations = 1;

View File

@ -0,0 +1,101 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
// #include "absl/strings/substitute.h"
#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
// #include "mediapipe/framework/calculator_runner.h"
// #include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/gtest.h"
// #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"  // NOLINT
#include "mediapipe/framework/tool/sink.h"
// #include "mediapipe/framework/tool/validate_type.h"
#include "torch/torch.h"
namespace mediapipe {
using Outputs = std::vector<torch::jit::IValue>;
class PyTorchConverterCalculatorTest : public ::testing::Test {
std::unique_ptr<CalculatorGraph> graph_;
};
// Sends one 1x1 SRGB pixel through a PyTorchConverterCalculator configured
// with three per-channel sub/div normalizations, then checks that exactly one
// tensor comes out, that it is shaped {1, 3, 1, 1}, and that every channel
// equals (pixel / 255 - sub) / div.
TEST_F(PyTorchConverterCalculatorTest, CustomDivAndSub) {
  CalculatorGraph graph;
  CalculatorGraphConfig graph_config =
      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
        input_stream: "input_image"
        node {
          calculator: "PyTorchConverterCalculator"
          input_stream: "IMAGE:input_image"
          output_stream: "TENSORS:tensor"
          options {
            [mediapipe.PyTorchConverterCalculatorOptions.ext] {
              per_channel_normalizations: { sub: 0.485 div: 0.229 }
              per_channel_normalizations: { sub: 0.456 div: 0.224 }
              per_channel_normalizations: { sub: 0.406 div: 0.225 }
            }
          }
        }
      )");
  std::vector<Packet> output_packets;
  tool::AddVectorSink("tensor", &graph_config, &output_packets);

  // Run the graph.
  MP_ASSERT_OK(graph.Initialize(graph_config));
  MP_ASSERT_OK(graph.StartRun({}));

  // An SRGB frame stores three 8-bit channels per pixel, so the pixel has to
  // be written as a cv::Vec3b rather than through a float accessor.
  constexpr int kPixelValue = 200;
  auto input_image = absl::make_unique<ImageFrame>(ImageFormat::SRGB, 1, 1);
  cv::Mat mat = ::mediapipe::formats::MatView(input_image.get());
  mat.at<cv::Vec3b>(0, 0) = cv::Vec3b(kPixelValue, kPixelValue, kPixelValue);
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_image", Adopt(input_image.release()).At(Timestamp(0))));

  // Wait until the calculator is done processing.
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Exactly one packet, holding exactly one tensor, is expected.
  ASSERT_EQ(1u, output_packets.size());
  const Outputs& tensor_vec = output_packets[0].Get<Outputs>();
  ASSERT_EQ(1u, tensor_vec.size());
  const torch::Tensor tensor = tensor_vec[0].toTensor();

  // Shape: batch of one, three channels, 1x1 pixels.
  ASSERT_EQ(4, tensor.dim());
  EXPECT_EQ(1, tensor.size(0));
  ASSERT_EQ(3, tensor.size(1));
  EXPECT_EQ(1, tensor.size(2));
  EXPECT_EQ(1, tensor.size(3));

  // Values: the same sub/div pairs as in the graph config above.
  const float kSub[] = {0.485f, 0.456f, 0.406f};
  const float kDiv[] = {0.229f, 0.224f, 0.225f};
  const float scaled = kPixelValue / 255.0f;
  for (int c = 0; c < 3; ++c) {
    EXPECT_NEAR((scaled - kSub[c]) / kDiv[c],
                tensor[0][c][0][0].item<float>(), 1e-5)
        << "channel " << c;
  }

  // Close the input stream and wait for the graph to finish last, after the
  // results have been inspected.
  MP_ASSERT_OK(graph.CloseInputStream("input_image"));
  MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace mediapipe