wip: docs/tests/cleanup
Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>
parent 93cf6d93a0
commit 96a8ff922d
mediapipe/calculators/pytorch/BUILD

@@ -12,57 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-load("@rules_proto//proto:defs.bzl", "proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_library")
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

 licenses(["notice"])  # Apache 2.0

 package(default_visibility = ["//visibility:private"])

-proto_library(
+mediapipe_proto_library(
     name = "pytorch_converter_calculator_proto",
     srcs = ["pytorch_converter_calculator.proto"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:calculator_proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
 )

-proto_library(
+mediapipe_proto_library(
     name = "pytorch_inference_calculator_proto",
     srcs = ["pytorch_inference_calculator.proto"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:calculator_proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
 )

-proto_library(
+mediapipe_proto_library(
     name = "pytorch_tensors_to_classification_calculator_proto",
     srcs = ["pytorch_tensors_to_classification_calculator.proto"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework:calculator_proto"],
-)
-
-mediapipe_cc_proto_library(
-    name = "pytorch_converter_calculator_cc_proto",
-    srcs = ["pytorch_converter_calculator.proto"],
-    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
-    visibility = ["//visibility:public"],
-    deps = [":pytorch_converter_calculator_proto"],
-)
-
-mediapipe_cc_proto_library(
-    name = "pytorch_inference_calculator_cc_proto",
-    srcs = ["pytorch_inference_calculator.proto"],
-    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
-    visibility = ["//visibility:public"],
-    deps = [":pytorch_inference_calculator_proto"],
-)
-
-mediapipe_cc_proto_library(
-    name = "pytorch_tensors_to_classification_calculator_cc_proto",
-    srcs = ["pytorch_tensors_to_classification_calculator.proto"],
-    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
-    visibility = ["//visibility:public"],
-    deps = [":pytorch_tensors_to_classification_calculator_proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
 )

 cc_library(

@@ -136,3 +120,21 @@ cc_library(
     }),
     alwayslink = 1,
 )
+
+cc_test(
+    name = "pytorch_converter_calculator_test",
+    srcs = ["pytorch_converter_calculator_test.cc"],
+    deps = [
+        ":pytorch_converter_calculator",
+        ":pytorch_converter_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        # "//mediapipe/framework:calculator_runner",
+        # "//mediapipe/framework/formats:image_format_cc_proto",
+        "//mediapipe/framework/formats:image_frame",
+        "//mediapipe/framework/formats:image_frame_opencv",
+        "//mediapipe/framework/port:gtest_main",
+        # "//mediapipe/framework/port:integral_types",
+        "//mediapipe/framework/port:parse_text_proto",
+        # "//mediapipe/framework/tool:validate_type",
+    ],
+)

mediapipe/calculators/pytorch/pytorch_converter_calculator.cc

@@ -28,11 +28,19 @@
 #if defined(MEDIAPIPE_IOS)
 #include "mediapipe/gpu/gpu_buffer.h"
 #include "mediapipe/objc/util.h"
-#endif  // iOS
+#endif  // MEDIAPIPE_IOS

 namespace mediapipe {

 namespace {
+::mediapipe::Status EnsureFormat(const ImageFrame& image_frame) {
+  const ImageFormat::Format format = image_frame.Format();
+  if (!(format == ImageFormat::SRGB)) {
+    RET_CHECK_FAIL() << "Unsupported input format.";
+  }
+  return ::mediapipe::OkStatus();
+}
+
 constexpr char kImageTag[] = "IMAGE";
 constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kTensorsTag[] = "TENSORS";

@@ -46,16 +54,16 @@ using Outputs = std::vector<torch::jit::IValue>;
 // This calculator is designed to be used with the PyTorchInferenceCalculator,
 // as a pre-processing step for calculator inputs.
 //
-// IMAGE inputs are normalized to [-1,1] (default) or [0,1],
-// specified by options (unless outputting a quantized tensor).
+// IMAGE and IMAGE_GPU inputs are normalized to [0;1].
 //
 // Input:
 //   One of the following tags:
-//   IMAGE - ImageFrame (assumed to be 8-bit or 32-bit data).
+//   IMAGE - ImageFrame.
 //   IMAGE_GPU - GpuBuffer.
 //
 // Output:
 //   One of the following tags:
-//   TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8.
+//   TENSORS - Vector of torch::jit::IValue residing on CPU.
 //
 // Example use:
 // node {

@@ -64,12 +72,15 @@ using Outputs = std::vector<torch::jit::IValue>;
 //   output_stream: "TENSORS:image_tensor"
 //   options: {
 //     [mediapipe.PyTorchConverterCalculatorOptions.ext] {
-//       zero_center: true
+//       per_channel_normalizations: {sub:0.485 div:0.229}
+//       per_channel_normalizations: {sub:0.456 div:0.224}
+//       per_channel_normalizations: {sub:0.406 div:0.225}
 //     }
 //   }
 // }
 //
 // IMPORTANT Notes:
+//   If given an IMAGE_GPU, PyTorch will convert TENSORS to CPU.
 //   This calculator uses FixedSizeInputStreamHandler by default.
 //
 class PyTorchConverterCalculator : public CalculatorBase {

@@ -84,6 +95,7 @@ class PyTorchConverterCalculator : public CalculatorBase {
   ::mediapipe::PyTorchConverterCalculatorOptions options_;
   bool has_image_tag_;
   bool has_image_gpu_tag_;
+  bool has_tensors_tag_;
 };
 REGISTER_CALCULATOR(PyTorchConverterCalculator);

@@ -97,7 +109,9 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
   const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag);
   RET_CHECK(has_tensors_tag);

-  if (has_image_tag) cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
+  if (has_image_tag) {
+    cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
+  }
   if (has_image_gpu_tag) {
 #if defined(MEDIAPIPE_IOS)
     cc->Inputs().Tag(kImageGpuTag).Set<GpuBuffer>();

@@ -106,7 +120,7 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);
 #endif
   }

-  if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set<Outputs>();
+  cc->Outputs().Tag(kTensorsTag).Set<Outputs>();

   // Assign this calculator's default InputStreamHandler.
   cc->SetInputStreamHandler("FixedSizeInputStreamHandler");

@@ -121,11 +135,12 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);

   has_image_tag_ = cc->Inputs().HasTag(kImageTag);
   has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag);
+  has_tensors_tag_ = cc->Outputs().HasTag(kTensorsTag);

   if (has_image_gpu_tag_) {
 #if !defined(MEDIAPIPE_IOS)
     RET_CHECK_FAIL() << "GPU processing not enabled.";
-#endif
+#endif  // MEDIAPIPE_IOS
   }

   return ::mediapipe::OkStatus();

@@ -133,43 +148,68 @@ REGISTER_CALCULATOR(PyTorchConverterCalculator);

 ::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) {
   cv::Mat image;
+  // Acquire input packet as ImageFrame image, if packet is not empty,
   if (has_image_gpu_tag_) {
 #if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+    if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) {
+      return ::mediapipe::OkStatus();
+    }
     const auto& input = cc->Inputs().Tag(kImageGpuTag).Get<GpuBuffer>();
     std::unique_ptr<ImageFrame> frame =
         CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
-    image = mediapipe::formats::MatView(frame.release());
+    ImageFrame image_frame = frame.release();
+    MP_RETURN_IF_ERROR(EnsureFormat(image_frame));
+    image = mediapipe::formats::MatView(image_frame);
 #else
     RET_CHECK_FAIL() << "GPU processing is not enabled.";
 #endif
   }
   if (has_image_tag_) {
-    auto& output_frame =
-        cc->Inputs().Tag(kImageTag).Get<mediapipe::ImageFrame>();
-    image = mediapipe::formats::MatView(&output_frame);
+    if (cc->Inputs().Tag(kImageTag).IsEmpty()) {
+      return ::mediapipe::OkStatus();
+    }
+    const ImageFrame& image_frame =
+        cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
+    MP_RETURN_IF_ERROR(EnsureFormat(image_frame));
+    image = mediapipe::formats::MatView(&image_frame);
   }

-  const int num_channels = image.channels();
-  RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported";
-  const int width = image.cols;
-  const int height = image.rows;
-  cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
+  const auto width = image.cols, height = image.rows,
+             num_channels = image.channels();
+  RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now";

   cv::Mat img_float;
+  // Normalize to [0;1]
   image.convertTo(img_float, CV_32F, 1.0 / 255);
+  // FIXME: try NCWH directly here
   auto img_tensor = torch::from_blob(img_float.data, {1, width, height, 3});
-  img_tensor = img_tensor.permute({0, 3, 1, 2});  // To NCWH
-  if (options_.per_channel_normalizations().size() > 0)
+  // Permute from NWHC to NCWH
+  img_tensor = img_tensor.permute({0, 3, 1, 2});
+
+  if (options_.per_channel_normalizations().size() > 0) {
     // Further normalize each channel of input image
     for (int i = 0; i < num_channels; ++i) {
       const auto& subdiv = options_.per_channel_normalizations(i);
-      img_tensor[0][i] = img_tensor[0][i].sub_(subdiv.sub()).div_(subdiv.div());
+      const float sub = subdiv.sub();
+      const float div = subdiv.div();
+      img_tensor[0][i] = img_tensor[0][i].sub_(sub).div_(div);
     }
+  }

-  auto output_tensors = absl::make_unique<Outputs>();
-  output_tensors->reserve(1);
-  output_tensors->emplace_back(img_tensor.cpu());
-  cc->Outputs()
-      .Tag(kTensorsTag)
-      .Add(output_tensors.release(), cc->InputTimestamp());
-  return ::mediapipe::OkStatus();
+  if (has_tensors_tag_) {
+    // FIXME: move to ctor
+    auto output_tensors = absl::make_unique<Outputs>();
+    output_tensors->reserve(1);
+    output_tensors->emplace_back(img_tensor.cpu());
+    cc->Outputs()
+        .Tag(kTensorsTag)
+        .Add(output_tensors.release(), cc->InputTimestamp());
+    return ::mediapipe::OkStatus();
+  } else {
+    RET_CHECK_FAIL() << "Unsupported output kind.";
+  }
 }

 ::mediapipe::Status PyTorchConverterCalculator::Close(CalculatorContext* cc) {
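
For reference, the transform Process() applies is: scale the 8-bit pixel value to [0;1] (convertTo with 1.0/255), permute to NCWH, then optionally compute (x - sub) / div per channel. A minimal standalone sketch of that arithmetic (not part of the patch; the pixel value 200 and the tiny driver are illustrative only), using the constants from the example options above:

// Minimal sketch of the per-channel normalization performed in Process(),
// with the ImageNet-style constants from the example options. Illustrative only.
#include <cstdio>

int main() {
  const float pixel = 200.0f;                      // example 8-bit pixel value
  const float subs[3] = {0.485f, 0.456f, 0.406f};  // per-channel sub
  const float divs[3] = {0.229f, 0.224f, 0.225f};  // per-channel div
  for (int c = 0; c < 3; ++c) {
    const float scaled = pixel / 255.0f;                    // convertTo(..., CV_32F, 1.0 / 255)
    const float normalized = (scaled - subs[c]) / divs[c];  // sub_()/div_() per channel
    std::printf("channel %d: %.4f\n", c, normalized);       // ~1.3070, ~1.4657, ~1.6814
  }
  return 0;
}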

mediapipe/calculators/pytorch/pytorch_converter_calculator.proto

@@ -27,7 +27,7 @@ message PyTorchConverterCalculatorOptions {
     optional float sub = 1 [default = 0];
     optional float div = 2 [default = 1];
   }
-  // Normalizations to apply per input image channel.
+  // Normalizations to apply per channel of input image.
   // There must be exactly as many items in this list as there are channels
   // or none.
   repeated SubDiv per_channel_normalizations = 1;

mediapipe/calculators/pytorch/pytorch_converter_calculator_test.cc (new file)

@@ -0,0 +1,101 @@
+// Copyright 2020 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+// #include "absl/memory/memory.h"
+// #include "absl/strings/substitute.h"
+#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h"
+#include "mediapipe/framework/calculator_framework.h"
+// #include "mediapipe/framework/calculator_runner.h"
+// #include "mediapipe/framework/formats/image_format.pb.h"
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/gtest.h"
+// #include "mediapipe/framework/port/integral_types.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"  // NOLINT
+// #include "mediapipe/framework/tool/validate_type.h"
+#include "torch/torch.h"
+
+namespace mediapipe {
+
+using Outputs = std::vector<torch::jit::IValue>;
+
+class PyTorchConverterCalculatorTest : public ::testing::Test {
+  std::unique_ptr<CalculatorGraph> graph_;
+};
+
+TEST_F(PyTorchConverterCalculatorTest, CustomDivAndSub) {
+  CalculatorGraph graph;
+  // Run the calculator and verify that one output is generated.
+  CalculatorGraphConfig graph_config =
+      ::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
+        input_stream: "input_image"
+        node {
+          calculator: "PyTorchConverterCalculator"
+          input_stream: "IMAGE:input_image"
+          output_stream: "TENSORS:tensor"
+          options {
+            [mediapipe.PyTorchConverterCalculatorOptions.ext] {
+              per_channel_normalizations: { sub: 0.485 div: 0.229 }
+              per_channel_normalizations: { sub: 0.456 div: 0.224 }
+              per_channel_normalizations: { sub: 0.406 div: 0.225 }
+            }
+          }
+        }
+      )");
+  std::vector<Packet> output_packets;
+  tool::AddVectorSink("tensor", &graph_config, &output_packets);
+
+  // Run the graph.
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  MP_ASSERT_OK(graph.StartRun({}));
+  auto input_image = absl::make_unique<ImageFrame>(ImageFormat::SRGB, 1, 1);
+  cv::Mat mat = ::mediapipe::formats::MatView(input_image.get());
+  // mat.at<uint8>(0, 0) = 200;
+  mat.at<float>(0, 0, 0) = 200;
+  // mat.at<uint8>(0, 0, 1) = 200;
+  // mat.at<uint8>(0, 0, 2) = 200;
+  MP_ASSERT_OK(graph.AddPacketToInputStream(
+      "input_image", Adopt(input_image.release()).At(Timestamp(0))));
+
+  // Wait until the calculator done processing.
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Get and process results.
+  const Outputs& tensor_vec = output_packets[0].Get<Outputs>();
+  EXPECT_EQ(1, tensor_vec.size());
+  const torch::Tensor& tensor = tensor_vec[0].toTensor();
+  EXPECT_EQ(4, tensor.dim());
+
+  // std::tuple<torch::Tensor, torch::Tensor> result =
+  //     tensor.sort(/*dim*/ -1, /*descending*/ true);
+  // const torch::Tensor result_tensor = std::get<0>(result)[0];
+  // auto results = result_tensor.accessor<float, 1>();
+  // EXPECT_EQ(1, results.size(0));
+  // const float r0 = results[0];
+  // EXPECT_FLOAT_EQ(67.0f, r0);
+  // const float r1 = results[1];
+  // EXPECT_FLOAT_EQ(67.0f, r1);
+  // const float r2 = results[2];
+  // EXPECT_FLOAT_EQ(67.0f, r2);
+
+  // Fully close graph at end, otherwise calculator+tensors are destroyed
+  // after calling WaitUntilDone().
+  MP_ASSERT_OK(graph.CloseInputStream("input_image"));
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
+}  // namespace mediapipe
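
The commented-out expectations in the test could eventually check the normalized values directly. A hedged sketch of such assertions for the test body above (not part of the commit), assuming the 1x1 SRGB pixel is actually written as (200, 200, 200) on the 8-bit Mat view, for example with mat.at<cv::Vec3b>(0, 0) = cv::Vec3b(200, 200, 200), rather than the float write shown in the test:

// Possible assertions continuing the CustomDivAndSub test body, assuming the
// input pixel is (200, 200, 200); expected values are (200/255 - sub) / div.
const torch::Tensor& t = tensor_vec[0].toTensor();
ASSERT_EQ(4, t.dim());    // NCWH layout produced by the permute in Process()
ASSERT_EQ(3, t.size(1));  // three channels
const auto acc = t.accessor<float, 4>();
EXPECT_NEAR((200.0f / 255 - 0.485f) / 0.229f, acc[0][0][0][0], 1e-4);  // ~1.3070
EXPECT_NEAR((200.0f / 255 - 0.456f) / 0.224f, acc[0][1][0][0], 1e-4);  // ~1.4657
EXPECT_NEAR((200.0f / 255 - 0.406f) / 0.225f, acc[0][2][0][0], 1e-4);  // ~1.6814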