CPU inference on desktop Linux with PyTorch v1.4.0

Signed-off-by: Pierre Fenoll <pierrefenoll@gmail.com>

parent cccf6244d3
commit 1f121dd3eb
WORKSPACE (+13 lines)
@@ -361,6 +361,19 @@ http_archive(
    ],
)

# PyTorch archives

# 2020-03-12
# On https://pytorch.org: Stable > Linux > LibTorch > C++ > (CUDA) None > cxx11 ABI
http_archive(
    name = "linux_libtorch_cpu",
    build_file = "@//third_party:libtorch_linux.BUILD",
    sha256 = "33a9dd142d0497375db42b055bd90780f9d92047a19edc8891e6232e2b5bdba7",
    strip_prefix = "libtorch",
    type = "zip",
    url = "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip",
)

# The TensorFlow repo should always go after the other external dependencies.
# 2020-08-30
_TENSORFLOW_GIT_COMMIT = "57b009e31e59bd1a7ae85ef8c0232ed86c9b71db"
docs/solutions/object_classification.md (new file, +79 lines)
@@ -0,0 +1,79 @@
---
layout: default
title: Object Classification
parent: Solutions
nav_order: TODO
---

# MediaPipe Object Classification
{: .no_toc }

1. TOC
{:toc}
---

## Example Apps

Note: To visualize a graph, copy the graph and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev/). For more information on how
to visualize its associated subgraphs, please see the
[visualizer documentation](../tools/visualizer.md).

<!-- ### Mobile

Please first see general instructions for
[iOS](../getting_started/building_examples.md#ios) on how to build MediaPipe examples.

#### GPU Pipeline

*   iOS target:
    [`mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/objectdetectiongpu/BUILD)
-->

### Desktop

#### Live Camera Input

Please first see general instructions for
[desktop](../getting_started/building_examples.md#desktop) on how to build MediaPipe examples.

*   Graph:
    [`mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt)
*   Target:
    [`mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/object_classification/BUILD)

#### Video File Input

*   With a PyTorch Model

This uses a MobileNetV2 model traced from PyTorch Hub. To fetch and prepare it, run:

```bash
python mediapipe/models/trace_mobilenetv2.py
```

The pipeline is implemented in this
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt).

To build the application, run:

```bash
bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu
```

To run the application, replace `<input video path>` and `<output video
path>` in the command below with your own paths:

Tip: You can find a test video in
`mediapipe/examples/desktop/object_detection`.

```
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/object_classification/object_classification_pytorch_cpu \
  --calculator_graph_config_file=mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt \
  --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
```
<!--
## Resources

*   [Models and model cards](./models.md#object_detection)
-->
mediapipe/calculators/pytorch/BUILD (new file, +138 lines)
@@ -0,0 +1,138 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_proto//proto:defs.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:private"])

proto_library(
    name = "pytorch_converter_calculator_proto",
    srcs = ["pytorch_converter_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = ["//mediapipe/framework:calculator_proto"],
)

proto_library(
    name = "pytorch_inference_calculator_proto",
    srcs = ["pytorch_inference_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = ["//mediapipe/framework:calculator_proto"],
)

proto_library(
    name = "pytorch_tensors_to_classification_calculator_proto",
    srcs = ["pytorch_tensors_to_classification_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = ["//mediapipe/framework:calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "pytorch_converter_calculator_cc_proto",
    srcs = ["pytorch_converter_calculator.proto"],
    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
    visibility = ["//visibility:public"],
    deps = [":pytorch_converter_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "pytorch_inference_calculator_cc_proto",
    srcs = ["pytorch_inference_calculator.proto"],
    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
    visibility = ["//visibility:public"],
    deps = [":pytorch_inference_calculator_proto"],
)

mediapipe_cc_proto_library(
    name = "pytorch_tensors_to_classification_calculator_cc_proto",
    srcs = ["pytorch_tensors_to_classification_calculator.proto"],
    cc_deps = ["//mediapipe/framework:calculator_cc_proto"],
    visibility = ["//visibility:public"],
    deps = [":pytorch_tensors_to_classification_calculator_proto"],
)

cc_library(
    name = "pytorch_converter_calculator",
    srcs = ["pytorch_converter_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":pytorch_converter_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:image_frame_opencv",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
        "//third_party:libtorch",
    ] + select({
        "//mediapipe:ios": [
            "//mediapipe/gpu:gpu_buffer",
            "//mediapipe/objc:util",
        ],
        "//conditions:default": [],
    }),
    alwayslink = 1,
)

cc_library(
    name = "pytorch_inference_calculator",
    srcs = ["pytorch_inference_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":pytorch_inference_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
        "//mediapipe/util:resource_util",
        "//third_party:libtorch",
    ],
    alwayslink = 1,
)

cc_library(
    name = "pytorch_tensors_to_classification_calculator",
    srcs = ["pytorch_tensors_to_classification_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":pytorch_tensors_to_classification_calculator_cc_proto",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:location",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/util:resource_util",
        "//third_party:libtorch",
    ] + select({
        "//mediapipe:android": [
            "//mediapipe/util/android/file/base",
        ],
        "//mediapipe:apple": [
            "//mediapipe/util/android/file/base",
        ],
        "//mediapipe:macos": [
            "//mediapipe/framework/port:file_helpers",
        ],
        "//conditions:default": [
            "//mediapipe/framework/port:file_helpers",
        ],
    }),
    alwayslink = 1,
)
mediapipe/calculators/pytorch/pytorch_converter_calculator.cc (new file, +179 lines)
@@ -0,0 +1,179 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "mediapipe/calculators/pytorch/pytorch_converter_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "torch/script.h"
#include "torch/torch.h"

#if defined(MEDIAPIPE_IOS)
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/objc/util.h"
#endif  // iOS

namespace mediapipe {

namespace {
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kTensorsTag[] = "TENSORS";

using Outputs = std::vector<torch::jit::IValue>;
}  // namespace

// Calculator for normalizing and converting an ImageFrame
// into a PyTorch tensor.
//
// This calculator is designed to be used with the PyTorchInferenceCalculator,
// as a pre-processing step for calculator inputs.
//
// IMAGE inputs are scaled to [0,1]; an optional per-channel (x - sub) / div
// normalization can then be applied via the calculator options.
//
// Input:
//  One of the following tags:
//  IMAGE - ImageFrame (assumed to be 8-bit data).
//
// Output:
//  One of the following tags:
//  TENSORS - Vector of torch::jit::IValue holding a single float image tensor.
//
// Example use:
// node {
//   calculator: "PyTorchConverterCalculator"
//   input_stream: "IMAGE:input_image"
//   output_stream: "TENSORS:image_tensor"
//   options: {
//     [mediapipe.PyTorchConverterCalculatorOptions.ext] {
//       per_channel_normalizations: {sub: 0.485 div: 0.229}
//       per_channel_normalizations: {sub: 0.456 div: 0.224}
//       per_channel_normalizations: {sub: 0.406 div: 0.225}
//     }
//   }
// }
//
// IMPORTANT Notes:
//  This calculator uses FixedSizeInputStreamHandler by default.
//
class PyTorchConverterCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  ::mediapipe::PyTorchConverterCalculatorOptions options_;
  bool has_image_tag_;
  bool has_image_gpu_tag_;
};
REGISTER_CALCULATOR(PyTorchConverterCalculator);

::mediapipe::Status PyTorchConverterCalculator::GetContract(
    CalculatorContract* cc) {
  const bool has_image_tag = cc->Inputs().HasTag(kImageTag);
  const bool has_image_gpu_tag = cc->Inputs().HasTag(kImageGpuTag);
  // Confirm exactly one of the input streams is present.
  RET_CHECK(has_image_tag ^ has_image_gpu_tag);

  const bool has_tensors_tag = cc->Outputs().HasTag(kTensorsTag);
  RET_CHECK(has_tensors_tag);

  if (has_image_tag) cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
  if (has_image_gpu_tag) {
#if defined(MEDIAPIPE_IOS)
    cc->Inputs().Tag(kImageGpuTag).Set<GpuBuffer>();
#else
    RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
  }

  if (has_tensors_tag) cc->Outputs().Tag(kTensorsTag).Set<Outputs>();

  // Assign this calculator's default InputStreamHandler.
  cc->SetInputStreamHandler("FixedSizeInputStreamHandler");

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchConverterCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));

  options_ = cc->Options<::mediapipe::PyTorchConverterCalculatorOptions>();

  has_image_tag_ = cc->Inputs().HasTag(kImageTag);
  has_image_gpu_tag_ = cc->Inputs().HasTag(kImageGpuTag);

  if (has_image_gpu_tag_) {
#if !defined(MEDIAPIPE_IOS)
    RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif
  }

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchConverterCalculator::Process(CalculatorContext* cc) {
  cv::Mat image;
  if (has_image_gpu_tag_) {
#if defined(MEDIAPIPE_IOS) && MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
    const auto& input = cc->Inputs().Tag(kImageGpuTag).Get<GpuBuffer>();
    std::unique_ptr<ImageFrame> frame =
        CreateImageFrameForCVPixelBuffer(input.GetCVPixelBufferRef());
    image = mediapipe::formats::MatView(frame.release());
#else
    RET_CHECK_FAIL() << "GPU processing is not enabled.";
#endif
  }
  if (has_image_tag_) {
    auto& output_frame =
        cc->Inputs().Tag(kImageTag).Get<mediapipe::ImageFrame>();
    image = mediapipe::formats::MatView(&output_frame);
  }
  cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
  const auto width = image.cols, height = image.rows,
             num_channels = image.channels();
  RET_CHECK_EQ(num_channels, 3) << "Only RGB images are supported for now";

  cv::Mat img_float;
  image.convertTo(img_float, CV_32F, 1.0 / 255);
  // cv::Mat data is row-major, so the blob is laid out as NHWC.
  auto img_tensor = torch::from_blob(img_float.data, {1, height, width, 3});
  img_tensor = img_tensor.permute({0, 3, 1, 2});  // NHWC -> NCHW
  if (options_.per_channel_normalizations().size() > 0) {
    RET_CHECK_EQ(options_.per_channel_normalizations().size(), num_channels)
        << "Provide exactly one normalization per channel, or none";
    for (int i = 0; i < num_channels; ++i) {
      const auto& subdiv = options_.per_channel_normalizations(i);
      img_tensor[0][i] = img_tensor[0][i].sub_(subdiv.sub()).div_(subdiv.div());
    }
  }

  auto output_tensors = absl::make_unique<Outputs>();
  output_tensors->reserve(1);
  output_tensors->emplace_back(img_tensor.cpu());
  cc->Outputs()
      .Tag(kTensorsTag)
      .Add(output_tensors.release(), cc->InputTimestamp());
  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchConverterCalculator::Close(CalculatorContext* cc) {
  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe
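For reference, here is a minimal Python sketch (not part of the commit) of the preprocessing this calculator performs: BGR-to-RGB conversion, scaling to [0,1], NHWC-to-NCHW permutation, then the optional per-channel `(x - sub) / div` normalization. The function name and arguments are illustrative only.

```python
import numpy as np
import torch

def convert(image_bgr: np.ndarray, subs=None, divs=None) -> torch.Tensor:
    """image_bgr: uint8 HxWx3 array, as OpenCV decodes frames."""
    rgb = image_bgr[..., ::-1].astype(np.float32) / 255.0  # BGR -> RGB, [0,1]
    t = torch.from_numpy(rgb).unsqueeze(0)                 # NHWC, batch of 1
    t = t.permute(0, 3, 1, 2)                              # NHWC -> NCHW
    if subs is not None:
        for c, (sub, div) in enumerate(zip(subs, divs)):   # per-channel (x-sub)/div
            t[0][c] = (t[0][c] - sub) / div
    return t

# e.g. with the ImageNet statistics the example graph configures:
# convert(frame, subs=[0.485, 0.456, 0.406], divs=[0.229, 0.224, 0.225])
```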
mediapipe/calculators/pytorch/pytorch_converter_calculator.proto (new file, +34 lines)
@@ -0,0 +1,34 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message PyTorchConverterCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional PyTorchConverterCalculatorOptions ext = 245877797;
  }

  message SubDiv {
    optional float sub = 1 [default = 0];
    optional float div = 2 [default = 1];
  }
  // Normalizations to apply per input image channel, as (x - sub) / div.
  // This list must either be empty or contain exactly one entry per channel.
  repeated SubDiv per_channel_normalizations = 1;
}
mediapipe/calculators/pytorch/pytorch_inference_calculator.cc (new file, +172 lines)
@@ -0,0 +1,172 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "mediapipe/calculators/pytorch/pytorch_inference_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"
#include "torch/script.h"
#include "torch/torch.h"

namespace mediapipe {

namespace {
constexpr char kTensorsTag[] = "TENSORS";

using Inputs = std::vector<torch::jit::IValue>;
using Outputs = torch::Tensor;
}  // namespace

// Calculator Header Section

// Runs inference on the provided input tensors and PyTorch trace model.
//
// Loads the traced module with torch::jit::load() and calls forward().
// Inference runs on the CPU.
//
// This calculator is designed to be used with the PyTorchConverterCalculator,
// to get the appropriate inputs.
//
// Input:
//  TENSORS - Vector of torch::jit::IValue
//
// Output:
//  TENSORS - A single torch::Tensor
//
// Example use:
// node {
//   calculator: "PyTorchInferenceCalculator"
//   input_stream: "TENSORS:tensor_image"
//   output_stream: "TENSORS:tensors"
//   options: {
//     [mediapipe.PyTorchInferenceCalculatorOptions.ext] {
//       model_path: "modelname.pt"
//     }
//   }
// }
//
// IMPORTANT Notes:
//  This calculator uses FixedSizeInputStreamHandler by default.
//
class PyTorchInferenceCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  ::mediapipe::PyTorchInferenceCalculatorOptions options_;
  torch::jit::script::Module module_;
  torch::jit::IValue hidden_state_;
};
REGISTER_CALCULATOR(PyTorchInferenceCalculator);

// Calculator Core Section

::mediapipe::Status PyTorchInferenceCalculator::GetContract(
    CalculatorContract* cc) {
  RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
  RET_CHECK(cc->Outputs().HasTag(kTensorsTag));

  cc->Inputs().Tag(kTensorsTag).Set<Inputs>();
  cc->Outputs().Tag(kTensorsTag).Set<Outputs>();

  // Assign this calculator's default InputStreamHandler.
  cc->SetInputStreamHandler("FixedSizeInputStreamHandler");

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchInferenceCalculator::Open(CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));

  options_ = cc->Options<::mediapipe::PyTorchInferenceCalculatorOptions>();

  std::string model_path = options_.model_path();
  ASSIGN_OR_RETURN(model_path, mediapipe::PathToResourceAsFile(model_path));
  try {
    // https://github.com/pytorch/ios-demo-app/issues/8#issuecomment-612996683
    auto qengines = at::globalContext().supportedQEngines();
    if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
        qengines.end()) {
      LOG(INFO) << "Using QEngine at::QEngine::QNNPACK";
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
    }
#if defined(MEDIAPIPE_IOS)
    else {
      RET_CHECK_FAIL() << "QEngine::QNNPACK is required for iOS";
    }
#endif

    module_ = torch::jit::load(model_path);
    module_.eval();
  } catch (const std::exception& e) {
    LOG(ERROR) << e.what();
    return ::mediapipe::UnknownError(e.what());
  }

  if (options_.model_has_hidden_state())
    hidden_state_ = torch::zeros({1, 1, 10});  // TODO: read dims from options

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchInferenceCalculator::Process(CalculatorContext* cc) {
  auto inputs = cc->Inputs().Tag(kTensorsTag).Get<Inputs>();
  RET_CHECK_GT(inputs.size(), 0);

  // Disables autograd.
  torch::autograd::AutoGradMode guard(false);
  // Disables autograd even more? https://github.com/pytorch/pytorch/pull/26477
  at::AutoNonVariableTypeMode non_var_type_mode(true);

  Outputs out_tensor;
  if (options_.model_has_hidden_state()) {
    RET_CHECK_EQ(inputs.size(), 1) << "Not sure how to forward() hidden state";
    // auto tuple = torch::ivalue::Tuple::create({inp,hidden_state_});
    const auto result = module_.forward({inputs[0], hidden_state_});
    const auto out = result.toTuple()->elements();
    out_tensor = out[0].toTensor();
    hidden_state_ = out[1].toTensor();
  } else {
    const auto result = module_.forward(std::move(inputs));
    out_tensor = result.toTensor();
  }

  auto out = absl::make_unique<Outputs>(out_tensor);
  cc->Outputs().Tag(kTensorsTag).Add(out.release(), cc->InputTimestamp());

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchInferenceCalculator::Close(CalculatorContext* cc) {
  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe
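The calculator's load-and-forward flow maps directly onto the TorchScript Python API; a minimal sketch (not part of the commit; the model path assumes trace_mobilenetv2.py below has been run):

```python
import torch

# Load the traced module once, as the calculator does in Open().
module = torch.jit.load("mediapipe/models/mobilenetv2.pt")
module.eval()

# Per-packet inference, as in Process(): autograd disabled, CPU tensors in/out.
with torch.no_grad():
    image_tensor = torch.rand(1, 3, 224, 224)  # stand-in for the converter's output
    scores = module(image_tensor)               # shape (1, 1000) for MobileNetV2
print(scores.shape)
```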
mediapipe/calculators/pytorch/pytorch_inference_calculator.proto (new file, +71 lines)
@@ -0,0 +1,71 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message PyTorchInferenceCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional PyTorchInferenceCalculatorOptions ext = 233877213;
  }

  // message Delegate {
  //   // Default inference provided by tflite.
  //   message TfLite {}
  //   // Delegate to run GPU inference depending on the device.
  //   // (Can use OpenGl, OpenCl, Metal depending on the device.)
  //   message Gpu {}
  //   // Android only.
  //   message Nnapi {}

  //   oneof delegate {
  //     TfLite tflite = 1;
  //     Gpu gpu = 2;
  //     Nnapi nnapi = 3;
  //   }
  // }

  // Path to the PyTorch trace model (ex: /path/to/modelname.pt).
  optional string model_path = 1;

  // Whether model has hidden state.
  optional bool model_has_hidden_state = 2;

  // // Whether the TF Lite GPU or CPU backend should be used. Effective only
  // // when input tensors are on CPU. For input tensors on GPU, GPU backend is
  // // always used.
  // // DEPRECATED: configure "delegate" instead.
  // optional bool use_gpu = 2 [deprecated = true, default = false];

  // // Android only. When true, an NNAPI delegate will be used for inference.
  // // If NNAPI is not available, then the default CPU delegate will be used
  // // automatically.
  // // DEPRECATED: configure "delegate" instead.
  // optional bool use_nnapi = 3 [deprecated = true, default = false];

  // // The number of threads available to the interpreter. Effective only when
  // // input tensors are on CPU and 'use_gpu' is false.
  // optional int32 cpu_num_thread = 4 [default = -1];

  // // TfLite delegate to run inference.
  // // NOTE: calculator is free to choose delegate if not specified explicitly.
  // // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes
  // // precedence over use_* deprecated options.)
  // optional Delegate delegate = 5;
}
mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.cc (new file, +173 lines)
@@ -0,0 +1,173 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/resource_util.h"
#include "torch/torch.h"
#if defined(MEDIAPIPE_MOBILE)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else
#include "mediapipe/framework/port/file_helpers.h"
#endif

namespace mediapipe {

namespace {
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kClassificationListTag[] = "CLASSIFICATION_LIST";

using Inputs = torch::Tensor;
}  // namespace

// Converts the result PyTorch tensor from a classification model into
// MediaPipe classifications.
//
// Input:
//  TENSORS - A single float torch::Tensor of shape (1, num_classes).
// Output:
//  CLASSIFICATION_LIST - Result MediaPipe ClassificationList. The score and
//                        index fields of each classification are set, while
//                        the label field is only set if label_map_path is
//                        provided.
//
// Usage example:
// node {
//   calculator: "PyTorchTensorsToClassificationCalculator"
//   input_stream: "TENSORS:tensors"
//   output_stream: "CLASSIFICATION_LIST:classifications"
//   options: {
//     [mediapipe.PyTorchTensorsToClassificationCalculatorOptions.ext] {
//       top_k: 3
//       min_score_threshold: 0.1
//       label_map_path: "labelmap.txt"
//     }
//   }
// }
class PyTorchTensorsToClassificationCalculator : public CalculatorBase {
 public:
  static ::mediapipe::Status GetContract(CalculatorContract* cc);

  ::mediapipe::Status Open(CalculatorContext* cc) override;
  ::mediapipe::Status Process(CalculatorContext* cc) override;
  ::mediapipe::Status Close(CalculatorContext* cc) override;

 private:
  ::mediapipe::PyTorchTensorsToClassificationCalculatorOptions options_;
  std::unordered_map<int, std::string> label_map_;
  bool label_map_loaded_ = false;
};
REGISTER_CALCULATOR(PyTorchTensorsToClassificationCalculator);

::mediapipe::Status PyTorchTensorsToClassificationCalculator::GetContract(
    CalculatorContract* cc) {
  RET_CHECK(cc->Inputs().HasTag(kTensorsTag));
  cc->Inputs().Tag(kTensorsTag).Set<Inputs>();

  RET_CHECK(!cc->Outputs().GetTags().empty());
  if (cc->Outputs().HasTag(kClassificationListTag)) {
    cc->Outputs().Tag(kClassificationListTag).Set<ClassificationList>();
  }

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchTensorsToClassificationCalculator::Open(
    CalculatorContext* cc) {
  cc->SetOffset(TimestampDiff(0));

  options_ = cc->Options<
      ::mediapipe::PyTorchTensorsToClassificationCalculatorOptions>();

  if (options_.has_label_map_path()) {
    std::string string_path;
    ASSIGN_OR_RETURN(string_path,
                     PathToResourceAsFile(options_.label_map_path()));
    std::string label_map_string;
    MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));

    // One label per line, in class-id order.
    std::istringstream stream(label_map_string);
    std::string line;
    int i = 0;
    while (std::getline(stream, line)) label_map_[i++] = line;
    label_map_loaded_ = true;
  }

  if (options_.has_top_k()) RET_CHECK_GT(options_.top_k(), 0);

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchTensorsToClassificationCalculator::Process(
    CalculatorContext* cc) {
  const auto& input_tensors = cc->Inputs().Tag(kTensorsTag).Get<Inputs>();
  RET_CHECK_EQ(input_tensors.dim(), 2);

  // Sort all class scores in descending order.
  std::tuple<torch::Tensor, torch::Tensor> result =
      input_tensors.sort(/*dim*/ -1, /*descending*/ true);
  const torch::Tensor scores_tensor = std::get<0>(result)[0];
  const torch::Tensor indices_tensor =
      std::get<1>(result)[0].toType(torch::kInt32);

  auto scores = scores_tensor.accessor<float, 1>();
  auto indices = indices_tensor.accessor<int, 1>();

  const auto indices_count = indices.size(0);
  RET_CHECK_EQ(indices_count, scores.size(0));
  if (label_map_loaded_)
    RET_CHECK_EQ(indices_count, label_map_.size())
        << "need: " << indices_count << ", got: " << label_map_.size();

  // RET_CHECK_GE(indices_count, options_.top_k());
  auto top_k = indices.size(0);
  if (options_.has_top_k()) top_k = options_.top_k();

  auto classification_list = absl::make_unique<ClassificationList>();
  for (int i = 0; i < indices_count; ++i) {
    if (classification_list->classification_size() == top_k) break;
    const float score = scores[i];
    const int index = indices[i];
    if (options_.has_min_score_threshold() &&
        score < options_.min_score_threshold())
      continue;

    Classification* classification = classification_list->add_classification();
    classification->set_score(score);
    classification->set_index(index);
    if (label_map_loaded_) classification->set_label(label_map_[index]);
  }

  cc->Outputs()
      .Tag(kClassificationListTag)
      .Add(classification_list.release(), cc->InputTimestamp());

  return ::mediapipe::OkStatus();
}

::mediapipe::Status PyTorchTensorsToClassificationCalculator::Close(
    CalculatorContext* cc) {
  return ::mediapipe::OkStatus();
}

}  // namespace mediapipe
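The decode step is easy to mirror in Python; a small sketch (not part of the commit) of the same logic: descending sort, optional score threshold, top_k cut-off, label lookup. `labels` stands in for the list loaded from label_map_path.

```python
import torch

def decode(scores: torch.Tensor, labels, top_k=3, min_score=0.1):
    """scores: float tensor of shape (1, num_classes)."""
    sorted_scores, sorted_indices = scores.sort(dim=-1, descending=True)
    results = []
    for score, index in zip(sorted_scores[0].tolist(), sorted_indices[0].tolist()):
        if len(results) == top_k:
            break
        if score < min_score:   # skip low-confidence classes
            continue
        results.append((index, labels[index], score))
    return results

# e.g. decode(module(image_tensor), labels)
```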
mediapipe/calculators/pytorch/pytorch_tensors_to_classification_calculator.proto (new file, +33 lines)
@@ -0,0 +1,33 @@
// Copyright 2020 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message PyTorchTensorsToClassificationCalculatorOptions {
  extend .mediapipe.CalculatorOptions {
    optional PyTorchTensorsToClassificationCalculatorOptions ext = 266379463;
  }

  // Score threshold for preserving the class.
  optional float min_score_threshold = 1;
  // Number of highest scoring labels to output. If top_k is not positive then
  // all labels are used.
  optional int32 top_k = 2;
  // Path to a label map file for getting the actual name of class ids.
  optional string label_map_path = 3;
}
mediapipe/examples/desktop/object_classification/BUILD (new file, +27 lines)
@@ -0,0 +1,27 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:defs.bzl", "cc_binary")

licenses(["notice"])

package(default_visibility = ["//mediapipe/examples:__subpackages__"])

cc_binary(
    name = "object_classification_pytorch_cpu",
    deps = [
        "//mediapipe/examples/desktop:demo_run_graph_main",
        "//mediapipe/graphs/object_classification:desktop_pytorch_calculators",
    ],
)
mediapipe/graphs/object_classification/BUILD (new file, +32 lines)
@@ -0,0 +1,32 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:defs.bzl", "cc_library")

licenses(["notice"])

package(default_visibility = ["//visibility:public"])

cc_library(
    name = "desktop_pytorch_calculators",
    deps = [
        "//mediapipe/calculators/core:flow_limiter_calculator",
        "//mediapipe/calculators/image:image_transformation_calculator",
        "//mediapipe/calculators/pytorch:pytorch_converter_calculator",
        "//mediapipe/calculators/pytorch:pytorch_inference_calculator",
        "//mediapipe/calculators/pytorch:pytorch_tensors_to_classification_calculator",
        "//mediapipe/calculators/util:annotation_overlay_calculator",
        "//mediapipe/calculators/util:labels_to_render_data_calculator",
    ],
)
mediapipe/graphs/object_classification/object_classification_desktop_live.pbtxt (new file, +120 lines)
@@ -0,0 +1,120 @@
# MediaPipe graph that performs object classification with PyTorch on CPU.
# Used in the examples in
# mediapipe/examples/desktop/object_classification:object_classification_pytorch_cpu.

# Images on CPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"

# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# PyTorchTensorsToClassificationCalculator downstream in the graph to finish
# generating the corresponding classifications before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# PyTorchTensorsToClassificationCalculator to 1. This prevents the nodes in
# between from queuing up incoming images and data excessively, which leads to
# increased latency and memory usage, unwanted in real-time applications. It
# also eliminates unnecessary computation, e.g., a transformed image produced
# by ImageTransformationCalculator may get dropped downstream if the subsequent
# PyTorchConverterCalculator or PyTorchInferenceCalculator is still busy
# processing previous inputs.
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "input_video"
  input_stream: "FINISHED:output_classifications"
  input_stream_info: {
    tag_index: "FINISHED"
    back_edge: true
  }
  output_stream: "throttled_input_video"
}

# Transforms the input image on CPU to a 224x224 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:throttled_input_video"
  output_stream: "IMAGE:transformed_input_video"
  output_stream: "LETTERBOX_PADDING:letterbox_padding"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 224
      output_height: 224
      scale_mode: FIT
    }
  }
}

# Converts the transformed input image on CPU into an image tensor stored as a
# PyTorch tensor. Pixel values are scaled to [0,1], then normalized using
# mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
node {
  calculator: "PyTorchConverterCalculator"
  input_stream: "IMAGE:transformed_input_video"
  output_stream: "TENSORS:image_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.PyTorchConverterCalculatorOptions] {
      per_channel_normalizations: {sub: 0.485 div: 0.229}
      per_channel_normalizations: {sub: 0.456 div: 0.224}
      per_channel_normalizations: {sub: 0.406 div: 0.225}
    }
  }
}

# Runs a PyTorch model on CPU that takes an image tensor and outputs a
# tensor of classification scores.
node {
  calculator: "PyTorchInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:classification_tensors"
  node_options: {
    [type.googleapis.com/mediapipe.PyTorchInferenceCalculatorOptions] {
      model_path: "mediapipe/models/mobilenetv2.pt"
    }
  }
}

# Decodes the classification tensor generated by the PyTorch model, based on
# the specification in the options, into a vector of classifications. Maps
# classification label IDs to the corresponding label text. The label map is
# provided in the label_map_path option.
node {
  calculator: "PyTorchTensorsToClassificationCalculator"
  input_stream: "TENSORS:classification_tensors"
  output_stream: "CLASSIFICATION_LIST:output_classifications"
  node_options: {
    [type.googleapis.com/mediapipe.PyTorchTensorsToClassificationCalculatorOptions] {
      top_k: 3
      min_score_threshold: 0.1
      label_map_path: "mediapipe/models/mobilenetv2.labelmap"
    }
  }
}

# Converts the classification labels to drawing primitives for annotation overlay.
node {
  calculator: "LabelsToRenderDataCalculator"
  input_stream: "CLASSIFICATIONS:output_classifications"
  output_stream: "RENDER_DATA:render_data"
  node_options: {
    [type.googleapis.com/mediapipe.LabelsToRenderDataCalculatorOptions] {
      color { r: 255 g: 0 b: 0 }
      color { r: 0 g: 255 b: 0 }
      color { r: 0 g: 0 b: 255 }
      thickness: 2.0
      font_height_px: 20
      font_face: 1
      location: TOP_LEFT
    }
  }
}

# Draws annotations and overlays them on top of the input images.
node {
  calculator: "AnnotationOverlayCalculator"
  input_stream: "IMAGE:throttled_input_video"
  input_stream: "render_data"
  output_stream: "IMAGE:output_video"
}
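The per_channel_normalizations values above are the standard torchvision ImageNet statistics; a quick Python sketch (not part of the commit) checking that the converter's sub/div step matches torchvision's Normalize, assuming torchvision is installed:

```python
import torch
from torchvision import transforms

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean, std=std)

x = torch.rand(3, 224, 224)  # an image already scaled to [0,1], CHW
manual = x.clone()
for c in range(3):
    manual[c] = (manual[c] - mean[c]) / std[c]  # what the converter's sub/div does

# x.clone() guards against in-place behavior in older torchvision releases.
assert torch.allclose(normalize(x.clone()), manual)
```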
mediapipe/models/.gitignore (vendored, new file, +1 line)
@@ -0,0 +1 @@
/mobilenetv2.pt

mediapipe/models/mobilenetv2.labelmap (new file, +1000 lines)
(File diff suppressed because it is too large.)
mediapipe/models/trace_mobilenetv2.py (new file, +14 lines)
@@ -0,0 +1,14 @@
# `torch` package known to work:
#   python -m pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
import torch

# More about this model at https://pytorch.org/hub/pytorch_vision_mobilenet_v2

if __name__ == '__main__':
    model = torch.hub.load('pytorch/vision:v0.5.0',
                           'mobilenet_v2',
                           pretrained=True)
    model.eval()
    input_tensor = torch.rand(1, 3, 224, 224)
    script_model = torch.jit.trace(model, input_tensor)
    script_model.save("mediapipe/models/mobilenetv2.pt")
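A sanity check one might run after tracing, to confirm the saved trace reloads and agrees with the eager model on fresh inputs (not part of the commit, just a hedged verification sketch):

```python
import torch

eager = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', pretrained=True)
eager.eval()
traced = torch.jit.load("mediapipe/models/mobilenetv2.pt")

x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    # The trace should reproduce the eager forward pass on new inputs.
    assert torch.allclose(eager(x), traced(x), atol=1e-5)
```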
third_party/BUILD (vendored, +15 lines)
@@ -209,6 +209,21 @@ cc_library(
    }),
)

cc_library(
    name = "libtorch",
    visibility = ["//visibility:public"],
    deps = select({
        "//mediapipe:android_x86": [],
        "//mediapipe:android_x86_64": [],
        "//mediapipe:android_armeabi": [],
        "//mediapipe:android_arm": [],
        "//mediapipe:android_arm64": [],
        "//mediapipe:ios": ["@ios_libtorch//:libtorch"],
        "//mediapipe:macos": ["@macos_libtorch_cpu//:libtorch_cpu"],
        "//conditions:default": ["@linux_libtorch_cpu//:libtorch_cpu"],
    }),
)

android_library(
    name = "androidx_annotation",
    exports = [
third_party/libtorch_linux.BUILD (vendored, new file, +32 lines)
@@ -0,0 +1,32 @@
# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:defs.bzl", "cc_library")

cc_library(
    name = "libtorch_cpu",
    srcs = glob([
        "lib/libc10.so",
        "lib/libgomp-*.so.1",
        "lib/libtorch.so",
    ]),
    hdrs = glob(["include/**/*.h"]),
    includes = [
        "include",
        "include/TH",
        "include/THC",
        "include/torch/csrc/api/include",
    ],
    visibility = ["//visibility:public"],
)