CPU inference on desktop Linux with PyTorch v1.4.0

Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>
Pierre Fenoll 2020-10-17 21:03:22 +02:00
parent cccf6244d3
commit 1f121dd3eb
17 changed files with 2133 additions and 0 deletions

WORKSPACE

@@ -361,6 +361,19 @@ http_archive(
],
)
# PyTorch archives
# 2020-03-12
# On https://pytorch.org: Stable > Linux > LibTorch > C++ > (CUDA) None > cxx11 ABI
http_archive(
name = "linux_libtorch_cpu",
build_file = "@//third_party:libtorch_linux.BUILD",
sha256 = "33a9dd142d0497375db42b055bd90780f9d92047a19edc8891e6232e2b5bdba7",
strip_prefix = "libtorch",
type = "zip",
url = "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip",
)
# TensorFlow repo should always go after the other external dependencies.
# 2020-08-30
_TENSORFLOW_GIT_COMMIT = "57b009e31e59bd1a7ae85ef8c0232ed86c9b71db"

docs/solutions/object_classification.md

@@ -0,0 +1,79 @@
---
layout: default
title: Object Classification
parent: Solutions
nav_order: TODO
---
# MediaPipe Object Classification
{: .no_toc }
1. TOC
{:toc}
---
## Example Apps
Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see
[visualizer documentation](../tools/visualizer.md).
<!-- ### Mobile
Please first see general instructions for
[iOS](../getting_started/building_examples.md#ios) on how to build MediaPipe examples.
#### GPU Pipeline
* iOS target:
[`mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/objectdetectiongpu/BUILD)
-->
### Desktop
#### Live Camera Input
Please first see general instructions for
[desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe examples.
* Graph:
[`mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt)
* Target:
[`mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/object_classification/BUILD)
#### Video File Input
* With a PyTorch Model
This uses a MobileNetV2 model from PyTorch Hub, traced with TorchScript. To fetch and prepare it, run:
```bash
python mediapipe/models/trace_mobilenetv2.py
```
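Optionally, you can sanity-check the traced model before wiring it into the graph by loading it back and running a dummy input. A minimal sketch; it assumes the same `torch` package used for tracing is installed:
```python
import torch

# Load the trace written by trace_mobilenetv2.py and run a dummy input.
module = torch.jit.load("mediapipe/models/mobilenetv2.pt")
module.eval()
with torch.no_grad():
    logits = module(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) for the ImageNet-pretrained model
```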
The pipeline is implemented in this
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt).
To build the application, run:
```bash
bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu
```
To run the application, replace `<input video path>` and `<output video path>` in the command below with your own paths:
Tip: You can find a test video available in
`mediapipe/examples/desktop/object_detection`.
```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_classification/object_classification_pytorch_cpu \
--calculator_graph_config_file=mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt \
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
```
<!--
## Resources
* [Models and model cards](./models.md#object_detection)
-->

mediapipe/calculators/pytorch/BUILD

@@ -0,0 +1,138 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:private"])
proto_library(
name = "pytorch_converter_calculator_proto",
srcs = ["pytorch_converter_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "pytorch_inference_calculator_proto",
srcs = ["pytorch_inference_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
proto_library(
name = "pytorch_tensors_to_classification_calculator_proto",
srcs = ["pytorch_tensors_to_classification_calculator.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pytorch_converter_calculator_cc_proto",
srcs = ["pytorch_converter_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":pytorch_converter_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pytorch_inference_calculator_cc_proto",
srcs = ["pytorch_inference_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":pytorch_inference_calculator_proto"],
)
mediapipe_cc_proto_library(
name = "pytorch_tensors_to_classification_calculator_cc_proto",
srcs = ["pytorch_tensors_to_classification_calculator.proto"],
cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
visibility = ["//visibility:public"],
deps = [":pytorch_tensors_to_classification_calculator_proto"],
)
cc_library(
name = "pytorch_converter_calculator",
srcs = ["pytorch_converter_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":pytorch_converter_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//third_party:libtorch",
] + select({
"//mediapipe:ios": [
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/objc:util",
],
"//conditions:default": [],
}),
alwayslink = 1,
)
cc_library(
name = "pytorch_inference_calculator",
srcs = ["pytorch_inference_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":pytorch_inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/util:resource_util",
"//third_party:libtorch",
],
alwayslink = 1,
)
cc_library(
name = "pytorch_tensors_to_classification_calculator",
srcs = ["pytorch_tensors_to_classification_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":pytorch_tensors_to_classification_calculator_cc_proto",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"//mediapipe/framework/port:status",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:location",
"//mediapipe/framework/port:ret_check",
"//mediapipe/util:resource_util",
"//third_party:libtorch",
] + select({
"//mediapipe:android": [
"//mediapipe/util/android/file/base",
],
"//mediapipe:apple": [
"//mediapipe/util/android/file/base",
],
"//mediapipe:macos": [
"//mediapipe/framework/port:file_helpers",
],
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
],
}),
alwayslink = 1,
)

mediapipe/calculators/pytorch/pytorch_converter_calculator.cc

@@ -0,0 +1,179 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "torch/script.h"
#include "torch/torch.h"
#if defined(MEDIAPIPE_IOS)
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/objc/util.h"
#endif // iOS
namespace mediapipe {
namespace {
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kTensorsTag[] = "TENSORS";
using Outputs = std::vector<torch::jit::IValue>;
} // namespace
// Calculator for normalizing and converting an ImageFrame
// into a PyTorch tensor.
//
// This calculator is designed to be used with the PyTorchInferenceCalculator,
// as a pre-processing step for calculator inputs.
//
// IMAGE inputs are converted to float32, scaled to [0,1], and optionally
// normalized per channel as specified by the options.
//
// Input:
// One of the following tags:
// IMAGE - ImageFrame (assumed to be 8-bit, 3-channel data).
//
// Output:
// One of the following tags:
// TENSORS - Vector of torch::jit::IValue holding one float32 image tensor in NCHW layout.
//
// Example use:
// node {
// calculator: "PyTorchConverterCalculator"
// input_stream: "IMAGE:input_image"
// output_stream: "TENSORS:image_tensor"
// options: {
// [mediapipe.PyTorchConverterCalculatorOptions.ext] {
// per_channel_normalizations: {sub: 0.485 div: 0.229}
// }
// }
// }
//
// IMPORTANT Notes:
// This calculator uses FixedSizeInputStreamHandler by default.
//
class PyTorchConverterCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::PyTorchConverterCalculatorOptions options_;
bool has_image_tag_;
bool has_image_gpu_tag_;
};
REGISTER_CALCULATOR(PyTorchConverterCalculator);
::mediapipe::Status PyTorchConverterCalculator::GetContract(
CalculatorContract* cc) {
const bool has_image_tag = cc->Inputs().HasTag(kImageTag);
const bool has_image_gpu_tag = cc->Inputs().HasTag(kImageGpuTag);
// Confirm only one of the input streams is present.
RET_CHECK(has_image_tag ^ has_image_gpu_tag);
const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag);
RET_CHECK(has_tensors_tag);
if (has_image_tag) cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
if (has_image_gpu_tag) {
#if defined(MEDIAPIPE_IOS)
cc->Inputs().Tag(kImageGpuTag).Set<GpuBuffer>();
#else
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
}
if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set<Outputs>();
// Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchConverterCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<::mediapipe::PyTorchConverterCalculatorOptions>();
has_image_tag_ = cc->Inputs().HasTag(kImageTag);
has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag);
if (has_image_gpu_tag_) {
#if !defined(MEDIAPIPE_IOS)
RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) {
cv::Mat image;
if (has_image_gpu_tag_) {
#if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
const auto& input = cc->Inputs().Tag(kImageGpuTag).Get<GpuBuffer>();
std::unique_ptr<ImageFrame> frame =
CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
image = mediapipe::formats::MatView(frame.release());
#else
RET_CHECK_FAIL() << "GPU processing is not enabled.";
#endif
}
if (has_image_tag_) {
auto& output_frame =
cc->Inputs().Tag(kImageTag).Get<mediapipe::ImageFrame>();
image = mediapipe::formats::MatView(&output_frame);
}
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
const auto width = image.cols, height = image.rows,
num_channels = image.channels();
RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now";
cv::Mat img_float;
image.convertTo(img_float, CV_32F, 1.0 / 255);
// Copy the pixel data out of the temporary cv::Mat so the tensor owns its
// memory beyond this scope.
auto img_tensor =
torch::from_blob(img_float.data, {1, height, width, 3}).clone();
img_tensor = img_tensor.permute({0, 3, 1, 2});  // NHWC -> NCHW
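// Apply the optional per-channel normalization from the options:
// (v - sub) / div on values already scaled to [0,1].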
if (options_.per_channel_normalizations().size() > 0)
for (int i = 0; i < num_channels; ++i) {
const auto& subdiv = options_.per_channel_normalizations(i);
img_tensor[0][i] = img_tensor[0][i].sub_(subdiv.sub()).div_(subdiv.div());
}
auto output_tensors = absl::make_unique<Outputs>();
output_tensors->reserve(1);
output_tensors->emplace_back(img_tensor.cpu());
cc->Outputs()
.Tag(kTensorsTag)
.Add(output_tensors.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchConverterCalculator::Close(CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

mediapipe/calculators/pytorch/pytorch_converter_calculator.proto

@@ -0,0 +1,34 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message PyTorchConverterCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional PyTorchConverterCalculatorOptions ext = 245877797;
}
message SubDiv {
optional float sub = 1 [default = 0];
optional float div = 2 [default = 1];
}
// Normalizations to apply per input image channel.
// There must be exactly as many items in this list as there are channels
// or none.
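// Each channel value v (already scaled to [0,1]) is replaced by (v - sub) / div.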
repeated SubDiv per_channel_normalizations = 1;
}

mediapipe/calculators/pytorch/pytorch_inference_calculator.cc

@@ -0,0 +1,172 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/calculators/pytorch/pytorch_inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"
#include "torch/script.h"
#include "torch/torch.h"
namespace mediapipe {
namespace {
constexpr char kTensorsTag[] = "TENSORS";
using Inputs = std::vector<torch::jit::IValue>;
using Outputs = torch::Tensor;
} // namespace
// Calculator Header Section
// Runs inference on the provided input tensors with a PyTorch (TorchScript)
// model.
//
// Loads the traced module with torch::jit::load() and calls forward().
// Inference runs on CPU.
//
// This calculator is designed to be used with the PyTorchConverterCalculator,
// which produces the appropriate inputs.
//
// The model path, and whether the model carries a hidden state, are specified
// in the calculator options.
//
// Input:
//  TENSORS - Vector of torch::jit::IValue holding the model inputs.
//
// Output:
//  TENSORS - torch::Tensor holding the model output.
//
// Example use:
// node {
// calculator: "PyTorchInferenceCalculator"
// input_stream: "TENSORS:tensor_image"
// output_stream: "TENSORS:tensors"
// options: {
// [mediapipe.PyTorchInferenceCalculatorOptions.ext] {
// model_path: "modelname.pt"
// }
// }
// }
//
// IMPORTANT Notes:
// This calculator uses FixedSizeInputStreamHandler by default.
//
class PyTorchInferenceCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::PyTorchInferenceCalculatorOptions options_;
torch::jit::script::Module module_;
torch::jit::IValue hidden_state_;
};
REGISTER_CALCULATOR(PyTorchInferenceCalculator);
// Calculator Core Section
::mediapipe::Status PyTorchInferenceCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
RET_CHECK(cc->Outputs().HasTag(kTensorsTag));
if (cc->Inputs().HasTag(kTensorsTag))
cc->Inputs().Tag(kTensorsTag).Set<Inputs>();
if (cc->Outputs().HasTag(kTensorsTag))
cc->Outputs().Tag(kTensorsTag).Set<Outputs>();
// Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchInferenceCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<::mediapipe::PyTorchInferenceCalculatorOptions>();
std::string model_path = options_.model_path();
ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path));
try {
// https://github.com/pytorch/ios-demo-app/issues/8#issuecomment-612996683
auto qengines = at::globalContext().supportedQEngines();
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
qengines.end()) {
LOG(INFO) << "Using QEngine at::QEngine::QNNPACK";
at::globalContext().setQEngine(at::QEngine::QNNPACK);
}
#if defined(MEDIAPIPE_IOS)
else {
RET_CHECK_FAIL() << "QEngine::QNNPACK is required for iOS";
}
#endif
module_ = torch::jit::load(model_path);
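// eval() puts the module in inference mode (disables dropout and uses the
// stored batch-norm statistics).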
module_.eval();
} catch (const std::exception& e) {
LOG(ERROR) << e.what();
return ::mediapipe::UnknownError(e.what());
}
if (options_.model_has_hidden_state())
hidden_state_ = torch::zeros({1, 1, 10}); // TODO: read options
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchInferenceCalculator::Process(CalculatorContext* cc) {
const auto inputs = cc->Inputs().Tag(kTensorsTag).Get<Inputs>();
RET_CHECK_GT(inputs.size(), 0);
// Disables autograd
torch::autograd::AutoGradMode guard(false);
// Disables autograd even more? https://github.com/pytorch/pytorch/pull/26477
at::AutoNonVariableTypeMode non_var_type_mode(true);
Outputs out_tensor;
if (options_.model_has_hidden_state()) {
RET_CHECK_EQ(inputs.size(), 1) << "Not sure how to forward() hidden state";
// auto tuple = torch::ivalue::Tuple::create({inp,hidden_state_});
const auto result = module_.forward({{inputs[0], hidden_state_}});
const auto out = result.toTuple()->elements();
out_tensor = out[0].toTensor();
hidden_state_ = out[1].toTensor();
} else {
const auto result = module_.forward(std::move(inputs));
out_tensor = result.toTensor();
}
auto out = absl::make_unique<Outputs>(out_tensor);
cc->Outputs().Tag(kTensorsTag).Add(out.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchInferenceCalculator::Close(CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

mediapipe/calculators/pytorch/pytorch_inference_calculator.proto

@@ -0,0 +1,71 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message PyTorchInferenceCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional PyTorchInferenceCalculatorOptions ext = 233877213;
}
// message Delegate {
// // Default inference provided by tflite.
// message TfLite {}
// // Delegate to run GPU inference depending on the device.
// // (Can use OpenGl, OpenCl, Metal depending on the device.)
// message Gpu {}
// // Android only.
// message Nnapi {}
// oneof delegate {
// TfLite tflite = 1;
// Gpu gpu = 2;
// Nnapi nnapi = 3;
// }
// }
// Path to the traced PyTorch model (e.g. /path/to/modelname.pt).
optional string model_path = 1;
// Whether the model carries a hidden state across invocations.
optional bool model_has_hidden_state = 2;
// // Whether the TF Lite GPU or CPU backend should be used. Effective only
// when
// // input tensors are on CPU. For input tensors on GPU, GPU backend is
// always
// // used.
// // DEPRECATED: configure "delegate" instead.
// optional bool use_gpu = 2 [deprecated = true, default = false];
// // Android only. When true, an NNAPI delegate will be used for inference.
// // If NNAPI is not available, then the default CPU delegate will be used
// // automatically.
// // DEPRECATED: configure "delegate" instead.
// optional bool use_nnapi = 3 [deprecated = true, default = false];
// // The number of threads available to the interpreter. Effective only when
// // input tensors are on CPU and 'use_gpu' is false.
// optional int32 cpu_num_thread = 4 [default = -1];
// // TfLite delegate to run inference.
// // NOTE: calculator is free to choose delegate if not specified explicitly.
// // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes
// // precedence over use_* deprecated options.)
// optional Delegate delegate = 5;
}

mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc

@@ -0,0 +1,173 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"
#include "torch/torch.h"
#if defined(MEDIAPIPE_MOBILE)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else
#include "mediapipe/framework/port/file_helpers.h"
#endif
namespace mediapipe {
namespace {
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kClassificationListTag[] = "CLASSIFICATION_LIST";
using Inputs = torch::Tensor;
} // namespace
// Converts a result PyTorch tensor from a classification model into MediaPipe
// classifications.
//
// Input:
//  TENSORS - A float32 torch::Tensor of shape (1, num_classes) holding the
//            classification scores.
// Output:
//  CLASSIFICATION_LIST - Resulting MediaPipe ClassificationList. The score and
//                        index fields of each classification are set, while the
//                        label field is only set if label_map_path is provided.
//
// Usage example:
// node {
// calculator: "PyTorchTensorsToClassificationCalculator"
// input_stream: "TENSORS:tensors"
// output_stream: "CLASSIFICATIONS:classifications"
// options: {
// [mediapipe.PyTorchTensorsToClassificationCalculatorOptions.ext] {
// top_k: 3
// min_score_threshold: 0.1
// label_map_path: "labelmap.txt"
// }
// }
// }
class PyTorchTensorsToClassificationCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
::mediapipe::Status Close(CalculatorContext* cc) override;
private:
::mediapipe::PyTorchTensorsToClassificationCalculatorOptions options_;
std::unordered_map<int, std::string> label_map_;
bool label_map_loaded_ = false;
};
REGISTER_CALCULATOR(PyTorchTensorsToClassificationCalculator);
::mediapipe::Status PyTorchTensorsToClassificationCalculator::GetContract(
CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
cc->Inputs().Tag(kTensorsTag).Set<Inputs>();
RET_CHECK(!cc->Outputs().GetTags().empty());
if (cc->Outputs().HasTag(kClassificationListTag)) {
cc->Outputs().Tag(kClassificationListTag).Set<ClassificationList>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchTensorsToClassificationCalculator::Open(
CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<
::mediapipe::PyTorchTensorsToClassificationCalculatorOptions>();
if (options_.has_label_map_path()) {
std::string string_path;
ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options_.label_map_path()));
std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
std::istringstream stream(label_map_string);
std::string line;
int i = 0;
while (std::getline(stream, line)) label_map_[i++] = line;
label_map_loaded_ = true;
}
if (options_.has_top_k()) RET_CHECK_GT(options_.top_k(), 0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchTensorsToClassificationCalculator::Process(
CalculatorContext* cc) {
const auto& input_tensors = cc->Inputs().Tag(kTensorsTag).Get<Inputs>();
RET_CHECK_EQ(input_tensors.dim(), 2);
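// sort() returns (sorted values, original indices); [0] selects the single
// batch row of each.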
std::tuple<torch::Tensor, torch::Tensor> result =
input_tensors.sort(/*dim*/ -1, /*descending*/ true);
const torch::Tensor scores_tensor = std::get<0>(result)[0];
const torch::Tensor indices_tensor =
std::get<1>(result)[0].toType(torch::kInt32);
auto scores = scores_tensor.accessor<float, 1>();
auto indices = indices_tensor.accessor<int, 1>();
const auto indices_count = indices.size(0);
RET_CHECK_EQ(indices_count, scores.size(0));
if (label_map_loaded_)
RET_CHECK_EQ(indices_count, label_map_.size())
<< "need: " << indices_count << ", got: " << label_map_.size();
// RET_CHECK_GE(indices_count, options_.top_k());
auto top_k = indices.size(0);
if (options_.has_top_k()) top_k = options_.top_k();
auto classification_list = absl::make_unique<ClassificationList>();
for (int i = 0; i < indices_count; ++i) {
if (classification_list->classification_size() == top_k) break;
const float score = scores[i];
const int index = indices[i];
if (options_.has_min_score_threshold() &&
score < options_.min_score_threshold())
continue;
Classification* classification = classification_list->add_classification();
classification->set_score(score);
classification->set_index(index);
if (label_map_loaded_) classification->set_label(label_map_[index]);
}
cc->Outputs()
.Tag(kClassificationListTag)
.Add(classification_list.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
::mediapipe::Status PyTorchTensorsToClassificationCalculator::Close(
CalculatorContext* cc) {
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto

@@ -0,0 +1,33 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message PyTorchTensorsToClassificationCalculatorOptions {
extend .mediapipe.CalculatorOptions {
optional PyTorchTensorsToClassificationCalculatorOptions ext = 266379463;
}
// Score threshold for preserving the class.
optional float min_score_threshold = 1;
// Number of highest scoring labels to output. If top_k is not positive then
// all labels are used.
optional int32 top_k = 2;
// Path to a label map file for getting the actual name of class ids.
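// The file is expected to contain one label per line, in class-index order.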
optional string label_map_path = 3;
}

mediapipe/examples/desktop/object_classification/BUILD

@@ -0,0 +1,27 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:defs.bzl", "cc_binary")
licenses(["notice"])
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
cc_binary(
name = "object_classification_pytorch_cpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main",
"//mediapipe/graphs/object_classification:desktop_pytorch_calculators",
],
)

mediapipe/graphs/object_classification/BUILD

@@ -0,0 +1,32 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:defs.bzl", "cc_library")
licenses(["notice"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "desktop_pytorch_calculators",
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/pytorch:pytorch_converter_calculator",
"//mediapipe/calculators/pytorch:pytorch_inference_calculator",
"//mediapipe/calculators/pytorch:pytorch_tensors_to_classification_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:labels_to_render_data_calculator",
],
)

mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt

@@ -0,0 +1,120 @@
# MediaPipe graph that performs object classification with PyTorch on CPU.
# Used in the examples in
# mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu.
# Images on CPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# PyTorchTensorsToClassificationCalculator downstream in the graph to finish
# generating the corresponding classifications before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# PyTorchTensorsToClassificationCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessary computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# PyTorchConverterCalculator or PyTorchInferenceCalculator is still busy
# processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:output_classifications"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Transforms the input image on CPU to a 224x224 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:throttled_input_video"
output_stream: "IMAGE:transformed_input_video"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 224
output_height: 224
scale_mode: FIT
}
}
}
# Converts the transformed input image on CPU into an image tensor stored as a
# PyTorch tensor. Pixel values are normalized using mean = [0.485, 0.456, 0.406]
# and std = [0.229, 0.224, 0.225].
node {
calculator: "PyTorchConverterCalculator"
input_stream: "IMAGE:transformed_input_video"
output_stream: "TENSORS:image_tensor"
node_options: {
[type.googleapis.com/mediapipe.PyTorchConverterCalculatorOptions] {
per_channel_normalizations: {sub:0.485 div:0.229}
per_channel_normalizations: {sub:0.456 div:0.224}
per_channel_normalizations: {sub:0.406 div:0.225}
}
}
}
# Runs a PyTorch model on CPU that takes an image tensor and outputs a tensor
# of classification scores.
node {
calculator: "PyTorchInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:classification_tensors"
node_options: {
[type.googleapis.com/mediapipe.PyTorchInferenceCalculatorOptions] {
model_path: "mediapipe/models/mobilenetv2.pt"
}
}
}
# Decodes the classification tensors generated by the PyTorch model, based on
# the specification in the options, into a vector of classifications.
# Maps classification label IDs to the corresponding label text. The label map is
# provided in the label_map_path option.
node {
calculator: "PyTorchTensorsToClassificationCalculator"
input_stream: "TENSORS:classification_tensors"
output_stream: "CLASSIFICATION_LIST:output_classifications"
node_options: {
[type.googleapis.com/mediapipe.PyTorchTensorsToClassificationCalculatorOptions] {
top_k: 3
min_score_threshold: 0.1
label_map_path: "mediapipe/models/mobilenetv2.labelmap"
}
}
}
# Converts the classification labels to drawing primitives for annotation overlay.
node {
calculator: "LabelsToRenderDataCalculator"
input_stream: "CLASSIFICATIONS:output_classifications"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.LabelsToRenderDataCalculatorOptions] {
color { r: 255 g: 0 b: 0 }
color { r: 0 g: 255 b: 0 }
color { r: 0 g: 0 b: 255 }
thickness: 2.0
font_height_px: 20
font_face: 1
location: TOP_LEFT
}
}
}
# Draws annotations and overlays them on top of the input images.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "IMAGE:throttled_input_video"
input_stream: "render_data"
output_stream: "IMAGE:output_video"
}

mediapipe/models/.gitignore

@@ -0,0 +1 @@
/mobilenetv2.pt

mediapipe/models/mobilenetv2.labelmap (diff suppressed because the file is too large)

mediapipe/models/trace_mobilenetv2.py

@@ -0,0 +1,14 @@
# `torch` package known to work:
# python -m pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
import torch
# More about this model at https://pytorch.org/hub/pytorch_vision_mobilenet_v2
if __name__ == '__main__':
model = torch.hub.load('pytorch/vision:v0.5.0',
'mobilenet_v2',
pretrained=True)
model.eval()
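# The example input is an NCHW float tensor; 224x224 matches the
# ImageTransformationCalculator output size in the desktop graph.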
input_tensor = torch.rand(1, 3, 224, 224)
script_model = torch.jit.trace(model, input_tensor)
script_model.save("mediapipe/models/mobilenetv2.pt")

third_party/BUILD

@@ -209,6 +209,21 @@ cc_library(
}),
)
cc_library(
name = "libtorch",
visibility = ["//visibility:public"],
deps = select({
"//mediapipe:android_x86": [],
"//mediapipe:android_x86_64": [],
"//mediapipe:android_armeabi": [],
"//mediapipe:android_arm": [],
"//mediapipe:android_arm64": [],
"//mediapipe:ios": ["@ios_libtorch//:libtorch"],
"//mediapipe:macos": ["@macos_libtorch_cpu//:libtorch_cpu"],
"//conditions:default": ["@linux_libtorch_cpu//:libtorch_cpu"],
}),
)
android_library(
name = "androidx_annotation",
exports = [

third_party/libtorch_linux.BUILD

@@ -0,0 +1,32 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:defs.bzl", "cc_library")
cc_library(
name = "libtorch_cpu",
srcs = glob([
"lib/libc10.so",
"lib/libgomp-*.so.1",
"lib/libtorch.so",
]),
hdrs = glob(["include/**/*.h"]),
includes = [
"include",
"include/TH",
"include/THC",
"include/torch/csrc/api/include",
],
visibility = ["//visibility:public"],
)